Matplotlib绘制树形图

xiaoxiao2021-02-27  332

Matplotlib绘制树形图

树的信息存储为“字典”（dict）对象，

例如 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}，绘制出的图形为：

# Draw a decision tree stored as a nested dict using Matplotlib.
#
# Tree format: {feature: {branch_value: subtree_or_leaf_label, ...}}, e.g.
#   {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
import matplotlib.pyplot as plt

# Box style for internal (decision) nodes.
decision_node = dict(boxstyle="sawtooth", fc="0.8")
# Box style for leaf nodes.
leaf_node = dict(boxstyle="round4", fc="0.8")
# Arrow properties for the parent-to-child edges drawn by plotNode.
arrow_args = dict(arrowstyle="<-")


def _is_subtree(node):
    """Return True if *node* is a nested subtree rather than a leaf label."""
    # isinstance (instead of comparing type(node).__name__ to "dict") also
    # accepts dict subclasses such as OrderedDict or defaultdict, which the
    # name comparison would misclassify as leaves.
    return isinstance(node, dict)


def get_leaf_num(tree):
    """Return the number of leaves in *tree*; this sets the plot width.

    *tree* is a dict whose first key is the feature name and whose value maps
    branch values to either a subtree (dict) or a leaf label.
    """
    first_key = list(tree)[0]
    branches = tree[first_key]
    return sum(
        get_leaf_num(child) if _is_subtree(child) else 1
        for child in branches.values()
    )


def get_tree_depth(tree):
    """Return the number of decision levels in *tree*; this sets the plot height.

    A tree whose children are all leaves has depth 1. A tree with an empty
    branch dict has depth 0.
    """
    first_key = list(tree)[0]
    branches = tree[first_key]
    return max(
        (1 + get_tree_depth(child) if _is_subtree(child) else 1
         for child in branches.values()),
        default=0,
    )


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a box labelled *nodeTxt* at *centerPt*, with an arrow to *parentPt*.

    Coordinates are axes fractions; *nodeType* is the bbox style dict.
    Draws on ``createPlot.ax1``, so ``createPlot`` must have set up the axes.
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)


def plotMidText(cntrPt, parentPt, txtString):
    """Write *txtString* (the branch value) midway between child and parent."""
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString,
                        va="center", ha="center", rotation=30)


def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw *myTree* below *parentPt*, labelling the edge *nodeTxt*.

    Layout state lives on function attributes initialised by ``createPlot``:
    ``plotTree.totalW`` and ``plotTree.totalD`` are the leaf count and depth of
    the whole tree. ``plotTree.xOff`` is the x position of the most recently
    placed leaf, and ``plotTree.yOff`` is the y position of the current level.
    """
    numLeafs = get_leaf_num(myTree)
    firstStr = list(myTree)[0]
    # Centre this decision node horizontally over the slots its leaves occupy.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decision_node)
    secondDict = myTree[firstStr]
    # Move down one level for this node's children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if _is_subtree(child):
            plotTree(child, cntrPt, str(key))
        else:
            # Each leaf takes the next 1/totalW-wide slot to the right.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore this level's y position before returning to the parent.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    """Draw *inTree* in Matplotlib figure 1 and display it with ``plt.show()``."""
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    # Axes shared with plotNode / plotMidText through a function attribute.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(get_leaf_num(inTree))
    plotTree.totalD = float(get_tree_depth(inTree))
    # Start half a slot left of 0 so the first leaf lands at 0.5 / totalW
    # and the root ends up centred at x = 0.5.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
转载请注明原文地址: https://www.6miu.com/read-4698.html

最新回复(0)