《机器学习实战》第三章 3.2 在Python中使用Matplotlib注解绘制树形图
自己按照书上的代码敲了一遍(Spyder),调试之后可以正常运行,在此分享给大家。另外,代码中加了很多注解,希望对大家有帮助。

代码

源代码
"""
Created on Mon Dec 18 10:53:31 2017
@author: XU LI
"""
import matplotlib.pyplot as plt
# Box styles handed to annotate() as ``bbox``: decision (internal) nodes get a
# saw-tooth outline, leaves a rounded one; fc="0.8" is a grey-level face colour.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow properties handed to annotate() as ``arrowprops``.
arrow_args = {"arrowstyle": "<-"}
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed tree node and an arrow linking it to its parent.

    Draws on the axes stored in ``createPlot.ax1``, which createPlot must
    have set up beforehand.

    :param nodeTxt: text shown inside the box.
    :param centerPt: (x, y) of the box centre, in axes-fraction coordinates.
    :param parentPt: (x, y) of the other end of the arrow, same coordinates.
    :param nodeType: bbox style dict (decisionNode or leafNode).
    """
    # Both points use axes-fraction coordinates: (0, 0) is the lower-left and
    # (1, 1) the upper-right corner of the axes.
    frac = 'axes fraction'
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords=frac,
        xytext=centerPt,
        textcoords=frac,
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def createPlot():
    """Demo: draw one sample decision node and one sample leaf node.

    NOTE: this module later redefines ``createPlot(inTree)``. Once the module
    has finished loading, that later definition replaces this demo version
    under the same name.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    # plotNode draws on whatever axes is stored here.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    samples = [
        # (text, box centre, arrow start, style)
        ('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        ('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode),
    ]
    for text, center, parent, style in samples:
        plotNode(text, center, parent, style)
    plt.show()
def getNumLeafs(myTree):
    """Count the leaf nodes of a decision tree stored as nested dicts.

    The tree has the form ``{label: {value: subtree_or_leaf, ...}}``. Every
    child that is a dict (including dict subclasses) is a sub-tree and is
    counted recursively. Any other child counts as one leaf.

    :param myTree: nested-dict tree; only its first top-level key is examined.
    :returns: number of leaves (int).
    :raises IndexError: if myTree is empty.
    """
    # list(d)[0] works on Python 2 and 3. d.keys()[0] fails on Python 3
    # because keys() returns a non-indexable view there.
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    A node whose children are all leaves has depth 1. Each nested dict
    (including dict subclasses) below it adds one level.

    :param myTree: nested-dict tree; only its first top-level key is examined.
    :returns: depth (int); 0 when the root has no children.
    :raises IndexError: if myTree is empty.
    """
    # list(d)[0] is Python 2/3 portable, unlike d.keys()[0].
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth
def retrieveTree(i):
    """Return one of two hard-coded sample trees for exercising the plot code.

    The dicts are rebuilt on every call, so each call returns a new object.

    :param i: index into the sample list (0 = two levels, 1 = three levels).
    :raises IndexError: for an out-of-range index.
    """
    shallow = {'no surfacing': {0: 'no',
                                1: {'flippers': {0: 'no', 1: 'yes'}}}}
    deeper = {'no surfacing': {0: 'no',
                               1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}},
                                                1: 'no'}}}}
    samples = [shallow, deeper]
    return samples[i]
def plotMidText(cntrpt, parentPt, txtString):
    """Write txtString at the midpoint of the edge between a node and its parent.

    :param cntrpt: (x, y) of the child node, axes-fraction coordinates.
    :param parentPt: (x, y) of the parent node, same coordinates.
    :param txtString: label to draw (the feature value on that edge).
    """
    # Midpoint per axis = child coordinate + half the offset to the parent.
    mid = [(p - c) / 2.0 + c for p, c in zip(parentPt[:2], cntrpt[:2])]
    createPlot.ax1.text(mid[0], mid[1], txtString)
def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw myTree below parentPt, labelling the incoming edge.

    Layout state is kept in function attributes that createPlot(inTree)
    initialises before the first call:

    - plotTree.totalW: number of leaves in the whole tree (float).
    - plotTree.totalD: depth of the whole tree (float).
    - plotTree.x0ff: x of the most recently placed leaf (axes fraction).
    - plotTree.y0ff: y of the level currently being drawn (axes fraction).

    :param myTree: nested-dict (sub)tree to draw.
    :param parentPt: (x, y) of the parent node the incoming arrow starts from.
    :param nodeTxt: label written on the incoming edge ('' for the root).
    """
    numLeafs = getNumLeafs(myTree)
    # Python 2/3 portable "first key" (d.keys()[0] fails on Python 3).
    firstStr = list(myTree)[0]
    # Centre this decision node horizontally over the leaf slots it spans.
    cntrpt = (plotTree.x0ff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.y0ff)
    plotMidText(cntrpt, parentPt, nodeTxt)
    plotNode(firstStr, cntrpt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Step down one level for this node's children...
    plotTree.y0ff = plotTree.y0ff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrpt, str(key))
        else:
            # Leaves are placed left-to-right, one slot of width 1/totalW each.
            plotTree.x0ff = plotTree.x0ff + 1.0 / plotTree.totalW
            leafPt = (plotTree.x0ff, plotTree.y0ff)
            plotNode(child, leafPt, cntrpt, leafNode)
            plotMidText(leafPt, cntrpt, str(key))
    # ...and restore the level before returning to the caller.
    plotTree.y0ff = plotTree.y0ff + 1.0 / plotTree.totalD
def createPlot(inTree):
    """Draw the nested-dict decision tree inTree in matplotlib figure 1.

    Clears the figure, creates a frameless axes without ticks (stored as
    createPlot.ax1 for plotNode/plotMidText), initialises plotTree's layout
    attributes, draws the tree from the top centre and shows the figure.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Layout state shared with the recursive plotTree.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start half a leaf slot left of 0, so the first leaf lands at the
    # centre of the first slot once plotTree adds a full slot width.
    plotTree.x0ff = -0.5 / plotTree.totalW
    plotTree.y0ff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
3.结果分析代码
"""
Created on Sat Oct 28 20:52:14 2017
@author: XU LI
"""
# Result-analysis script: exercises the ID3 helpers and the tree-plotting helpers.
# NOTE(review): createDataSet, calcShannonEnt, splitDataSet,
# chooseBestFeatureToSplit and createTree are not defined in this file;
# presumably they come from the chapter's trees.py and are already in scope.
# TODO confirm.
# Every print(...) below takes a single argument, so it behaves the same on
# Python 2 and Python 3.
SEPARATOR = '---------------------------------'

myData, labels = createDataSet()
print(myData)
print(calcShannonEnt(myData))
print(SEPARATOR)
# Split on feature 0 and feature 1 for each value, in the original order.
for axis, value in ((0, 1), (0, 0), (1, 1), (1, 0)):
    print(splitDataSet(myData, axis, value))
print(SEPARATOR)
myData, labels = createDataSet()
print(myData)
print('第 %s 个特征是最好的用于划分数据集的特征' % (chooseBestFeatureToSplit(myData),))
print(SEPARATOR)
# Fresh data/labels for createTree, as in the original script.
myData, labels = createDataSet()
myTree = createTree(myData, labels)
print('myTree= %s' % (myTree,))
print(SEPARATOR)
print(SEPARATOR)
# retrieveTree / getNumLeafs / getTreeDepth / createPlot are defined earlier in
# this file. They are called directly, not via a `treePlotter` module that is
# never imported here.
print(retrieveTree(1))
myTree = retrieveTree(0)
print(myTree)
print(getNumLeafs(myTree))
print(getTreeDepth(myTree))
print(SEPARATOR)
myTree = retrieveTree(0)
createPlot(myTree)
4.结果