import os
from collections import Counter
from math import log

import numpy as np
def createTree(dataSet, labels):
    """Recursively build an ID3-style decision tree.

    Args:
        dataSet: non-empty list of rows (lists). Every column except the
            last is a feature value; the last column is the class label.
        labels: feature names, where ``labels[i]`` names column ``i``.
            This list is NOT modified; each subtree receives a new list
            with the feature used at this node removed.

    Returns:
        A class label for a leaf, otherwise a nested dict of the form
        ``{feature_name: {feature_value: subtree, ...}}``.

    Raises:
        IndexError: if ``dataSet`` is empty.
    """
    classList = [row[-1] for row in dataSet]
    # Every row carries the same class: this is a pure leaf.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # Only the class column remains, so there is nothing left to split on.
    # Vote over the class labels (not the rows: rows are unhashable lists).
    if len(dataSet[0]) == 1:
        return majorityLabel(classList)
    bestFeat = chooseBestFeature(dataSet)
    # chooseBestFeature returns -1 when no feature has a positive information
    # gain (e.g. identical feature vectors with different classes). Using -1
    # as an index would split on the class column under the last feature's
    # name, so make a majority-vote leaf instead.
    if bestFeat < 0:
        return majorityLabel(classList)
    bestFeatLabel = labels[bestFeat]
    # Build a trimmed copy rather than deleting from `labels`, so the
    # caller's list is left intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    branches = {}
    for value in set(row[bestFeat] for row in dataSet):
        branches[value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return {bestFeatLabel: branches}
def splitDataSet(dataSet, feat, value):
    """Select the rows whose column ``feat`` equals ``value``, dropping that column.

    Args:
        dataSet: list of rows (lists).
        feat: index of the column to test and remove.
        value: value the column must equal for a row to be kept.

    Returns:
        A new list of new row lists; ``dataSet`` itself is not modified.
    """
    return [row[:feat] + row[feat + 1:]
            for row in dataSet
            if row[feat] == value]
def majorityLabel(classList):
    """Return the most frequent label in ``classList``.

    The previous version sorted counts in ascending order and so returned
    the *least* common label; it also relied on the Python-2-only
    ``dict.iteritems``.

    Args:
        classList: iterable of hashable class labels.

    Returns:
        The label with the highest count. On Python 3.7+ ties go to the
        label encountered first, because ``Counter`` keeps first-seen order.

    Raises:
        IndexError: if ``classList`` is empty.
    """
    # Counter counts in one pass (the old per-label list.count was O(n*k)).
    return Counter(classList).most_common(1)[0][0]
def calcEntropy(dataSet):
    """Shannon entropy, in bits, of the class column of ``dataSet``.

    The class of each row is its last element. An empty ``dataSet`` yields
    0.0.
    """
    column = [row[-1] for row in dataSet]
    size = len(column)
    # Relative frequency of each distinct class, in set-iteration order.
    shares = [column.count(c) * 1.0 / size for c in set(column)]
    # H = -sum(p * log2(p)); starting from 0.0 keeps the result a float
    # (and +0.0, never -0.0) when there are no classes or just one.
    return 0.0 - sum((s * log(s, 2) for s in shares), 0.0)
def chooseBestFeature(dataSet):
    """Return the index of the feature column with the largest information gain.

    Every column except the last is a feature; the last is the class label.
    Gain is ``calcEntropy(dataSet)`` minus the size-weighted entropy of the
    subsets produced by ``splitDataSet`` for each distinct value.

    Returns:
        The winning column index (the lowest index on ties), or -1 when no
        feature has a strictly positive gain.
    """
    baseEntropy = calcEntropy(dataSet)
    total = float(len(dataSet))
    bestGain, bestIndex = 0.0, -1
    for col in range(len(dataSet[0]) - 1):
        conditional = 0.0
        for v in {row[col] for row in dataSet}:
            part = splitDataSet(dataSet, col, v)
            conditional += len(part) / total * calcEntropy(part)
        gain = baseEntropy - conditional
        # Strict comparison: a zero gain never displaces the -1 sentinel,
        # and ties keep the earlier column.
        if gain > bestGain:
            bestGain, bestIndex = gain, col
    return bestIndex
def storeTree(inputTree, filename):
    """Serialize ``inputTree`` to ``filename`` using :mod:`pickle`.

    The file is opened in binary mode, because pickle writes bytes (text
    mode raises TypeError on Python 3). The ``with`` block closes the file
    even if pickling fails.
    """
    import pickle
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
def grabTree(filename):
    """Load and return an object pickled into ``filename`` (e.g. by storeTree).

    The file is opened in binary mode and closed after loading.

    Security: ``pickle.load`` can execute arbitrary code embedded in the
    file. Only load files from a trusted source.
    """
    import pickle
    with open(filename, 'rb') as fr:
        return pickle.load(fr)
# Toy data set: two 0/1 feature columns followed by a 'yes'/'no' class column.
dataSet = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no']]
labels = ['no surfacing', 'flippers']

# Load the tab-separated lenses data. os.path.join replaces the hard-coded
# backslash path, which only resolved on Windows and relied on '\l' not
# being an escape sequence.
with open(os.path.join('Ch03', 'lenses.txt')) as fr:
    # Skip blank lines (e.g. a trailing empty line) so that no ragged ['']
    # row reaches createTree.
    lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
# NOTE(review): createTree only looks up labels for feature columns (all but
# the last column of each row); confirm this list lines up with the columns
# of lenses.txt.
lenseLabel = ['age', 'pres', 'asti', 'tear', 'ui', 'iu']
lensetree = createTree(lenses, lenseLabel)
# Function-call form prints the same single value on Python 2 and 3.
print(lensetree)