机器学习实战 决策树

xiaoxiao  2021-02-28  54

决策树的构造

信息增益

计算给定数据集的香农熵

from collections import Counter
from math import log


def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    Args:
        dataSet: Sequence of records. The last element of each record is
            used as that record's class label.

    Returns:
        The entropy as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the small demo data set and the names of its two features.

    Returns:
        A (dataSet, labels) tuple. Each record in dataSet holds two feature
        values followed by a class label; labels names the feature columns.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


if __name__ == '__main__':
    myDat, labels = createDataSet()
    print(myDat)
    print(calcShannonEnt(myDat))
    # Relabel one record with a third class and print the entropy again.
    myDat[0][-1] = 'maybe'
    print(myDat)
    print(calcShannonEnt(myDat))

划分数据集

按照给定的特征划分数据集

def splitDataSet(dataSet, axis, value):
    """Select the records whose column `axis` equals `value`, minus that column.

    Args:
        dataSet: Sequence of records (lists). The input is not modified.
        axis: Index of the column to test and remove. Negative indices
            count from the end of each record.
        value: Value the column must equal for a record to be kept.

    Returns:
        A new list holding a copy of each matching record with the `axis`
        column removed.
    """
    # `featVec[axis:][1:]` rather than `featVec[axis + 1:]`: for axis == -1
    # the latter becomes `featVec[0:]` and would duplicate the whole record
    # instead of dropping the last column.
    return [
        featVec[:axis] + featVec[axis:][1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]

选择最好的数据集划分方式

def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest information gain.

    Args:
        dataSet: Sequence of records; every element except the last is a
            feature value and the last element is the class label.

    Returns:
        The column index of the best feature, or -1 when the data set is
        empty or no feature yields a strictly positive information gain.
    """
    if len(dataSet) == 0:
        # Nothing to split; report "no feature" instead of failing on dataSet[0].
        return -1
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    bestInfoGain = 0.0  # largest information gain seen so far
    bestFeature = -1  # index of the feature that produced bestInfoGain
    for i in range(numFeatures):
        # Distinct values taken by the i-th feature.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0  # weighted entropy after splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain = reduction in entropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

递归构建决策树

from collections import Counter
from math import log
import operator


def calcShannonEnt(dataSet):
    """Return the Shannon entropy, in bits, of the class labels in dataSet.

    Args:
        dataSet: Sequence of records. The last element of each record is
            used as that record's class label.

    Returns:
        The entropy as a float. An empty data set yields 0.0.
    """
    numEntries = len(dataSet)
    labelCounts = Counter(featVec[-1] for featVec in dataSet)
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


def createDataSet():
    """Return the small demo data set and the names of its two features.

    Returns:
        A (dataSet, labels) tuple. Each record in dataSet holds two feature
        values followed by a class label; labels names the feature columns.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels


def splitDataSet(dataSet, axis, value):
    """Select the records whose column `axis` equals `value`, minus that column.

    Args:
        dataSet: Sequence of records (lists). The input is not modified.
        axis: Index of the column to test and remove. Negative indices
            count from the end of each record.
        value: Value the column must equal for a record to be kept.

    Returns:
        A new list holding a copy of each matching record with the `axis`
        column removed.
    """
    # `featVec[axis:][1:]` rather than `featVec[axis + 1:]`: for axis == -1
    # the latter becomes `featVec[0:]` and would duplicate the whole record
    # instead of dropping the last column.
    return [
        featVec[:axis] + featVec[axis:][1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]


def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split gives the largest information gain.

    Args:
        dataSet: Sequence of records; every element except the last is a
            feature value and the last element is the class label.

    Returns:
        The column index of the best feature, or -1 when the data set is
        empty or no feature yields a strictly positive information gain.
    """
    if len(dataSet) == 0:
        # Nothing to split; report "no feature" instead of failing on dataSet[0].
        return -1
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    bestInfoGain = 0.0  # largest information gain seen so far
    bestFeature = -1  # index of the feature that produced bestInfoGain
    for i in range(numFeatures):
        # Distinct values taken by the i-th feature.
        uniqueVals = set(example[i] for example in dataSet)
        newEntropy = 0.0  # weighted entropy after splitting on feature i
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy  # information gain = reduction in entropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    """Return the class label that occurs most often in classList.

    Args:
        classList: Non-empty sequence of hashable class labels.

    Returns:
        The most frequent label. When several labels tie, the one that
        appears first in classList is returned.
    """
    classCount = Counter(classList)
    # max() keeps the first maximal item, and the counter preserves
    # first-seen order, so ties resolve to the earliest label.
    return max(classCount.items(), key=operator.itemgetter(1))[0]


def createTree(dataSet, labels):
    """Recursively build a decision tree as nested dictionaries.

    Args:
        dataSet: Non-empty sequence of records; every element except the
            last is a feature value and the last element is the class label.
        labels: Names of the feature columns, in column order. The list is
            not modified.

    Returns:
        A class label when the branch is a leaf; otherwise a dict of the
        form {feature_name: {feature_value: subtree, ...}}.
    """
    classList = [example[-1] for example in dataSet]
    # Stop splitting when every record carries the same class.
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # All features consumed: fall back to the most common class.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    if bestFeat == -1:
        # No remaining feature reduces entropy. Using -1 as a column index
        # would split on the class label itself, so emit a majority leaf.
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Build a new list instead of deleting in place so the caller's
    # `labels` stays intact.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    uniqueVals = set(example[bestFeat] for example in dataSet)
    for value in uniqueVals:
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels
        )
    return myTree


if __name__ == '__main__':
    myDat, labels = createDataSet()
    print(myDat)
    print(chooseBestFeatureToSplit(myDat))
    myTree = createTree(myDat, labels)
    print(myTree)
转载请注明原文地址: https://www.6miu.com/read-2629187.html

最新回复(0)