决策树实现

xiaoxiao2021-02-28  72

# -*- coding: utf-8 -*-
"""ID3 decision tree built from a list-of-rows data set.

Each row is a sequence whose last element is the class label and whose
preceding elements are categorical feature values.  A tree is either a bare
class label (leaf) or a nested dict ``{feature_name: {value: subtree}}``.

The code is written to run unchanged on Python 2.7 and Python 3.
"""
import os
import pickle
from collections import Counter
from math import log

import numpy as np


def createTree(dataSet, labels):
    """Recursively build a decision tree.

    Args:
        dataSet: non-empty list of rows; the last item of each row is the
            class label.
        labels: feature names, positionally aligned with the feature
            columns of ``dataSet``.  The list is not modified.

    Returns:
        A class label when the node is a leaf, otherwise a dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.
    """
    classList = [row[-1] for row in dataSet]

    # Leaf: every remaining row carries the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Leaf: only the class column is left, so vote on the class labels
    # (not on the rows themselves, which are unhashable lists).
    if len(dataSet[0]) == 1:
        return majorityLabel(classList)

    bestFeat = chooseBestFeature(dataSet)
    # No feature yields positive information gain; splitting further would
    # index column -1 (the class label), so stop and vote instead.
    if bestFeat == -1:
        return majorityLabel(classList)

    bestFeatLabel = labels[bestFeat]
    # Build a fresh list without the chosen feature rather than deleting
    # from ``labels`` in place, so neither the caller's list nor sibling
    # branches are affected.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]

    branches = {}
    for value in set(row[bestFeat] for row in dataSet):
        subset = splitDataSet(dataSet, bestFeat, value)
        branches[value] = createTree(subset, subLabels)
    return {bestFeatLabel: branches}


def splitDataSet(dataSet, feat, value):
    """Return the rows whose column ``feat`` equals ``value``, minus that column.

    Args:
        dataSet: list of rows (any sliceable sequences).
        feat: index of the feature column to test and remove.
        value: feature value to select.

    Returns:
        A new list of new list rows; the input rows are left untouched.
    """
    return [
        list(row[:feat]) + list(row[feat + 1:])
        for row in dataSet
        if row[feat] == value
    ]


def majorityLabel(classList):
    """Return the most frequent item of ``classList``.

    Args:
        classList: non-empty list of hashable class labels.

    Returns:
        The label with the highest count.

    Raises:
        IndexError: if ``classList`` is empty.
    """
    return Counter(classList).most_common(1)[0][0]


def calcEntropy(dataSet):
    """Compute the base-2 Shannon entropy of the class column of ``dataSet``.

    Args:
        dataSet: list of rows; the last item of each row is the class label.

    Returns:
        The entropy as a float; ``0.0`` for an empty data set.
    """
    classCounts = Counter(row[-1] for row in dataSet)
    total = len(dataSet)
    entropy = 0.0
    for count in classCounts.values():
        prob = count * 1.0 / total
        entropy -= prob * log(prob, 2)
    return entropy


def chooseBestFeature(dataSet):
    """Pick the feature column with the largest information gain.

    Args:
        dataSet: non-empty list of rows; the last item of each row is the
            class label.

    Returns:
        Index of the best feature, or ``-1`` when no feature has a gain
        strictly greater than zero (including when there are no features).
    """
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcEntropy(dataSet)
    total = float(len(dataSet))

    bestGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        # Weighted entropy of the partitions induced by feature i.
        newEntropy = 0.0
        for value in set(row[i] for row in dataSet):
            subset = splitDataSet(dataSet, i, value)
            newEntropy += len(subset) / total * calcEntropy(subset)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestGain:
            bestGain = infoGain
            bestFeature = i
    return bestFeature


def storeTree(inputTree, filename):
    """Pickle ``inputTree`` to ``filename``.

    The file is opened in binary mode, which pickle requires on Python 3,
    and is closed even if dumping fails.
    """
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)


def grabTree(filename):
    """Load and return a tree previously written by ``storeTree``.

    NOTE: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


# Small sample data set and its feature names.
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']


def main():
    """Build and print a tree from the tab-separated lenses data file."""
    # os.path.join keeps the original Windows path while also working on
    # platforms where the separator is '/'.
    with open(os.path.join('Ch03', 'lenses.txt')) as fr:
        lenses = [inst.strip().split('\t') for inst in fr]
    lenseLabel = ['age', 'pres', 'asti', 'tear', 'ui', 'iu']
    lensetree = createTree(lenses, lenseLabel)
    print(lensetree)


if __name__ == '__main__':
    main()
转载请注明原文地址: https://www.6miu.com/read-69413.html

最新回复(0)