机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。 数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。
一决策树与ID3概述 决策树2ID3算法ID3算法与决策树的流程 二Python算法实现 构造函数createDataSet1 计算信息熵2 按照最大信息增益划分数据集3 创建决策树构造函数createTree4 决策树运用于分类5 决策树的存储 三使用Matplotlib绘制决策树四实例使用决策树预测隐形眼镜类型决策树,其结构和树非常相似,因此得其名决策树。决策树具有树形的结构,其中每个内部节点表示一个属性上的测试,每个分支代表一个测试输出,每个叶节点代表一种类别。
例如:
按照豆腐脑的冷热、甜咸和是否含有大蒜构建决策树,对其属性的测试,在最终的叶节点决定该豆腐脑吃还是不吃。
分类树(决策树)是一种十分常用的将决策树应用于分类的机器学习方法。他是一种监管学习,所谓监管学习就是给定一堆样本,每个样本都有一组属性(特征)和一个类别(分类信息/目标),这些类别是事先确定的,那么通过学习得到一个分类器,这个分类器能够对新出现的对象给出正确的分类。 其原理在于,每个决策树都表述了一种树型结构,它由它的分支来对该类型的对象依靠属性进行分类。每个决策树可以依靠对源数据库的分割进行数据测试。这个过程可以递归式的对树进行修剪。 当不能再进行分割或一个单独的类可以被应用于某一分支时,递归过程就完成了。
机器学习中,决策树是一个预测模型;他代表的是对象属性与对象值之间的一种映射关系。树中每个节点表示某个对象,而每个分叉路径则代表的某个可能的属性值,而每个叶结点则对应从根节点到该叶节点所经历的路径所表示的对象的值。决策树仅有单一输出,若欲有复数输出,可以建立独立的决策树以处理不同输出。数据挖掘中决策树是一种经常要用到的技术,可以用于分析数据,同样也可以用来作预测。从数据产生决策树的机器学习技术叫做决策树学习, 通俗说就是决策树。
目前常用的决策树算法有ID3算法、改进的C4.5算法和CART算法。
决策树的特点 1.多层次的决策树形式易于理解;2.只适用于标称型数据,对连续性数据处理得不好;ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。
信息熵(Entropy):
H(x)=−∑i=1np(xi)log2p(xi)=∑i=1np(xi)log21p(xi) ,其中 p(xi) 是选择i的概率。 熵越高,表示混合的数据越多。信息增益(Information Gain):
IG=H−∑t∈Tp(t)H(t) T是划分之后的分支集合,p(t)是该分支集合在原本的父集合中出现的概率,H(t)是该子集合的信息熵。(1)数据准备:需要对数值型数据进行离散化 (2)ID3算法构建决策树:
如果数据集类别完全相同,则停止划分否则,继续划分决策树: 计算信息熵和信息增益来选择最好的数据集划分方法;划分数据集创建分支节点:对每个分支进行判定是否类别相同,如果相同停止划分,不同按照上述方法进行划分。创建 trees.py文件,在其中创建构建决策树的函数。 首先构建一组测试数据:
序号不浮出水面是否可以生存是否有脚蹼是否属于鱼类1是是是2是是是3是否否4否是否5否是否在Python控制台测试构造函数
#测试下构造的数据 import trees myDat,labels = trees.createDataSet() myDat Out[4]: [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']] labels Out[5]: ['no surfacing', 'flippers']利用构造的数据测试calcShannonEnt:
#Python console In [6]: trees.calcShannonEnt(myDat) ...: Out[6]: 0.9709505944546686在控制台中测试这两个函数:
#测试按照特征划分数据集的函数 In [8]: from imp import reload In [9]: reload(trees) Out[9]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'> In [10]: myDat,labels=trees.createDataSet() ...: In [11]: trees.splitDataSet(myDat,0,0) ...: Out[11]: [[1, 'no'], [1, 'no']] In [12]: trees.splitDataSet(myDat,0,1) ...: Out[12]: [[1, 'yes'], [1, 'yes'], [0, 'no']] #测试chooseBestFeatureToSplit函数 In [13]: reload(trees) ...: Out[13]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'> In [14]: trees.chooseBestFeatureToSplit(myDat) ...: Out[14]: 0以之前构造的测试数据为例,对决策树构造函数进行测试,在Python控制台进行输入:
#决策树构造函数测试 In [15]: reload(trees) ...: Out[15]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'> In [16]: myTree=trees.createTree(myDat,labels) ...: In [17]: myTree Out[17]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}可以看到,最后生成的决策树myTree是一个多层嵌套的字典。
运用决策树进行分类,首先构建一个决策树分类函数:
#输入三个变量(决策树,属性特征标签,测试的数据) def classify(inputTree,featLables,testVec): firstStr=list(inputTree.keys())[0] #获取树的第一个特征属性 secondDict=inputTree[firstStr] #树的分支,子集合Dict featIndex=featLables.index(firstStr) #获取决策树第一层在featLables中的位置 for key in secondDict.keys(): if testVec[featIndex]==key: if type(secondDict[key]).__name__=='dict': classLabel=classify(secondDict[key],featLables,testVec) else:classLabel=secondDict[key] return classLabel对决策树分类函数进行测试:
In [29]: reload(trees) ...: Out[29]: <module 'trees' from 'G:\\Workspaces\\MachineLearning\\trees.py'> In [30]: myDat,labels=trees.createDataSet() ...: In [31]: labels ...: Out[31]: ['no surfacing', 'flippers'] In [32]: myTree=treeplotter.retrieveTree(0) ...: In [33]: myTree ...: Out[33]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}} In [34]: trees.classify(myTree,labels,[1,0]) ...: Out[34]: 'no' In [35]: trees.classify(myTree,labels,[1,1]) ...: Out[35]: 'yes'如果每次都需要训练样本集来构建决策树,费时费力,特别是数据很大的时候,每次重新构建决策树浪费时间。因此可以将已经创建的决策树(如字典形式)保存在硬盘上,需要使用的时候直接读取就好。 (1)存储函数
def storeTree(inputTree,filename): import pickle fw=open(filename,'wb') #pickle默认方式是二进制,需要制定'wb' pickle.dump(inputTree,fw) fw.close()(2)读取函数
def grabTree(filename): import pickle fr=open(filename,'rb')#需要制定'rb',以byte形式读取 return pickle.load(fr)对这两个函数进行测试(Python console):
In [36]: myTree Out[36]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}} In [37]: trees.storeTree(myTree,'classifierStorage.txt') In [38]: trees.grabTree('classifierStorage.txt') Out[38]: {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}在工作目录下存在一个名为’classifierStorage.txt’的txt文档,该文档 保存了myTree的决策树信息,需要使用的时候直接调出使用。
对绘制决策树图的函数进行测试(控制台):
In [26]: reload(treeplotter) ...: Out[26]: <module 'treeplotter' from 'G:\\Workspaces\\MachineLearning\\treeplotter.py'> In [27]: myTree=treeplotter.retrieveTree(0) ...: In [28]: treeplotter.createPlot(myTree) ...:得到决策树图:
隐形眼镜的数据集包含了患者的四个属性age,prescript,stigmatic,tearRate,利用这些数据构建决策树,并通过Matplotlib绘制出决策树的树状图。 附lenses.txt数据:
young myope no reduced no lenses young myope no normal soft young myope yes reduced no lenses young myope yes normal hard young hyper no reduced no lenses young hyper no normal soft young hyper yes reduced no lenses young hyper yes normal hard pre myope no reduced no lenses pre myope no normal soft pre myope yes reduced no lenses pre myope yes normal hard pre hyper no reduced no lenses pre hyper no normal soft pre hyper yes reduced no lenses pre hyper yes normal no lenses presbyopic myope no reduced no lenses presbyopic myope no normal no lenses presbyopic myope yes reduced no lenses presbyopic myope yes normal hard presbyopic hyper no reduced no lenses presbyopic hyper no normal soft presbyopic hyper yes reduced no lenses presbyopic hyper yes normal no lenses In [40]: fr=open('machinelearninginaction/Ch03/lenses.txt') In [41]: lenses=[inst.strip().split('\t') for inst in fr.readlines()] In [42]: lensesLabels=['age','prescript','astigmatic','tearRate'] In [43]: lensesTree=trees.createTree(lenses,lensesLabels) In [44]: lensesTree Out[44]: {'tearRate': {'normal': {'astigmatic': {'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}, 'yes': {'prescript': {'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}}}, 'reduced': 'no lenses'}} In [45]: treeplotter.createPlot(lensesTree)得到图: