- Spark source code analysis: Random Forest (Part 1)
- Spark source code analysis: Random Forest (Part 2)
- Spark source code analysis: Random Forest (Part 4)
- Spark source code analysis: Random Forest (Part 5)
Each node in a tree is a Node structure:
```scala
class Node @Since("1.2.0") (
    @Since("1.0.0") val id: Int,
    @Since("1.0.0") var predict: Predict,
    @Since("1.2.0") var impurity: Double,
    @Since("1.0.0") var isLeaf: Boolean,
    @Since("1.0.0") var split: Option[Split],
    @Since("1.0.0") var leftNode: Option[Node],
    @Since("1.0.0") var rightNode: Option[Node],
    @Since("1.0.0") var stats: Option[InformationGainStats])
```

emptyNode initializes only nodeIndex; everything else is left at a default value:
```scala
def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue),
  -1.0, false, None, None, None, None)
```

Child node ids are derived from the node's own id:
```scala
/**
 * Return the index of the left child of this node.
 */
def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1

/**
 * Return the index of the right child of this node.
 */
def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1
```

The left child's id is the current id * 2, and the right child's id is id * 2 + 1.
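This is the standard binary-heap numbering with the root at id 1, so no parent or child pointers are needed to navigate by index. A quick check of the arithmetic, assuming these methods sit on the Node companion object as in the excerpt above:

```scala
// The root (id = 1) has children 2 and 3; node 2 has children 4 and 5.
assert(Node.leftChildIndex(1) == 2 && Node.rightChildIndex(1) == 3)
assert(Node.leftChildIndex(2) == 4 && Node.rightChildIndex(2) == 5)
```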
Entropy is an object; its most important member is the calculate function:
```scala
/**
 * :: DeveloperApi ::
 * information calculation for multiclass classification
 * @param counts Array[Double] with counts for each label
 * @param totalCount sum of counts for all labels
 * @return information value, or 0 if totalCount = 0
 */
@Since("1.1.0")
@DeveloperApi
override def calculate(counts: Array[Double], totalCount: Double): Double = {
  if (totalCount == 0) {
    return 0
  }
  val numClasses = counts.length
  var impurity = 0.0
  var classIndex = 0
  while (classIndex < numClasses) {
    val classCount = counts(classIndex)
    if (classCount != 0) {
      val freq = classCount / totalCount
      impurity -= freq * log2(freq)
    }
    classIndex += 1
  }
  impurity
}
```

The formula for entropy is:
$$H = E[-\log p_i] = -\sum_{i=1}^{n} p_i \log p_i$$

So the counts argument holds the occurrence count of each class; calculate first turns each count into a frequency, then accumulates -freq * log2(freq).
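As a quick numeric check, here is a self-contained sketch of the same computation (log2 assumed to be ln x / ln 2, matching MLlib's helper):

```scala
def log2(x: Double): Double = math.log(x) / math.log(2)

// Same loop as Entropy.calculate, written with collection operations.
def entropy(counts: Array[Double]): Double = {
  val totalCount = counts.sum
  if (totalCount == 0) 0.0
  else counts.filter(_ != 0).map { c =>
    val freq = c / totalCount
    -freq * log2(freq)
  }.sum
}

entropy(Array(3.0, 7.0)) // ≈ 0.881, the entropy of a 0.3 / 0.7 class split
```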
The EntropyAggregator has a single member variable, the number of classes; its key function is update:

```scala
/**
 * Update stats for one (node, feature, bin) with the given label.
 * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
 * @param offset Start index of stats for this (node, feature, bin).
 */
def update(allStats: Array[Double], offset: Int, label: Double,
    instanceWeight: Double): Unit = {
  if (label >= statsSize) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s" but requires label < numClasses (= $statsSize).")
  }
  if (label < 0) {
    throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
      s" but requires label is non-negative.")
  }
  allStats(offset + label.toInt) += instanceWeight
}
```

offset is the (feature, bin) offset; adding the label to it gives this class's position in allStats, where its occurrence count is accumulated.
```scala
/**
 * Get an [[ImpurityCalculator]] for a (node, feature, bin).
 * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous.
 * @param offset Start index of stats for this (node, feature, bin).
 */
def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
  new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
}
```

This slices out of allStats the sub-array belonging to this split; its length is statsSize, i.e. the number of classes.
Putting these functions together, the path for computing entropy is: call Entropy's getCalculator, which slices out the part of allStats belonging to the split, then call Entropy's calculate on it to compute the entropy. The prob function is also overridden here; it returns the probability of a given label: for example, if label 0 was counted 3 times and label 1 was counted 7 times, the probability of label 0 is 0.3.
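A minimal sketch of what prob computes, assuming stats holds the per-class counts that the calculator was constructed with (the real EntropyCalculator keeps them as its constructor argument):

```scala
// Hypothetical standalone version: probability of a label given class counts.
def prob(label: Double, stats: Array[Double]): Double = {
  val totalCount = stats.sum
  if (totalCount == 0) 0.0 else stats(label.toInt) / totalCount
}

prob(0.0, Array(3.0, 7.0)) // 0.3, matching the example above
```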
A few words on the statistics a node split needs, since they drive the design of DTStatsAggregator. Take entropy as the impurity measure: when a node splits, we iterate over every split of every feature, and each split partitions the samples into two parts. To compute the entropy we need the class distribution of the left and right parts separately, from which we get the probabilities and then the entropy. That is why statsSize in the aggregator equals numClasses, and why allStats records all the counts: the counts are exactly the class distributions.
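For concreteness, a hedged sketch of the statistic one split needs, reusing the entropy helper from the sketch above on made-up left/right class counts:

```scala
// Hypothetical: a split sends 4 samples of class 0 and 1 of class 1 left,
// and 2 of class 0 and 3 of class 1 right.
val leftCounts = Array(4.0, 1.0)
val rightCounts = Array(2.0, 3.0)
val leftEntropy = entropy(leftCounts)   // ≈ 0.722
val rightEntropy = entropy(rightCounts) // ≈ 0.971

// Weighted impurity after the split, the quantity compared when scoring splits.
val total = leftCounts.sum + rightCounts.sum
val weighted = (leftCounts.sum / total) * leftEntropy +
  (rightCounts.sum / total) * rightEntropy
```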
Here is the relevant part of DTStatsAggregator:

```scala
class DTStatsAggregator(
    val metadata: DecisionTreeMetadata,
    featureSubset: Option[Array[Int]]) extends Serializable {

  /**
   * [[ImpurityAggregator]] instance specifying the impurity type.
   */
  val impurityAggregator: ImpurityAggregator = metadata.impurity match {
    case Gini => new GiniAggregator(metadata.numClasses)
    case Entropy => new EntropyAggregator(metadata.numClasses)
    case Variance => new VarianceAggregator()
    case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}")
  }

  /**
   * Number of elements (Double values) used for the sufficient statistics of each bin.
   */
  private val statsSize: Int = impurityAggregator.statsSize

  /**
   * Number of bins for each feature. This is indexed by the feature index.
   */
  private val numBins: Array[Int] = {
    if (featureSubset.isDefined) {
      featureSubset.get.map(metadata.numBins(_))
    } else {
      metadata.numBins
    }
  }

  /**
   * Offset for each feature for calculating indices into the [[allStats]] array.
   */
  private val featureOffsets: Array[Int] = {
    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
  }

  /**
   * Total number of elements stored in this aggregator
   */
  private val allStatsSize: Int = featureOffsets.last

  /**
   * Flat array of elements.
   * Index for start of stats for a (feature, bin) is:
   *   index = featureOffsets(featureIndex) + binIndex * statsSize
   * Note: For unordered features,
   *   the left child stats have binIndex in [0, numBins(featureIndex) / 2))
   *   and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
   */
  private val allStats: Array[Double] = new Array[Double](allStatsSize)
```

Each node has one DTStatsAggregator. The constructor takes two parameters: the metadata and the feature subset used by the node. The other class members, with a worked example in the sketch after this list:

- impurityAggregator: Gini, Entropy and Variance are currently supported; we use Entropy as the running example, the others are similar.
- statsSize: the number of statistics each bin needs. For classification it equals numClasses, because each class is counted separately; for regression it is 3, holding the count, the sum, and the sum of squares of the label values, for computing the variance.
- numBins: the entries of metadata.numBins for the features this node uses.
- featureOffsets: the index of each feature into allStats, determined by each feature's bin count and statsSize. For example, with 3 features whose bin counts are 3, 2 and 2, and statsSize 2, the features need 3 * 2 = 6, 2 * 2 = 4 and 2 * 2 = 4 slots respectively, so featureOffsets is 0, 6, 10, 14: the left-to-right running total.
- allStatsSize: the total number of buckets needed.
- allStats: the buckets holding the statistics.

Concretely, let f0, f1, f2 be the 3 features: f0 has 3 feature values (really bin indices) 0/1/2, f1 has 2 (0/1), and f2 has 2 (0/1). Each feature value owns statsSize buckets for the per-class counts, so there are 14 buckets in total and allStatsSize = 14. To locate class c1 of value v1 of f1: take f1's offset featureOffsets(1) = 6; v1's binIndex is 1, statsSize is 2, and the label is 1, so the bucket index is 6 + 1 * 2 + 1 = 9, which is exactly the bucket for c1 of f1's v1 in this layout.
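A minimal sketch of the offset arithmetic on the example above (3 features with 3/2/2 bins and statsSize = 2; the names mirror the fields of DTStatsAggregator):

```scala
val statsSize = 2
val numBins = Array(3, 2, 2)

// Running total over bins * statsSize, exactly as in DTStatsAggregator.
val featureOffsets: Array[Int] =
  numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
// featureOffsets == Array(0, 6, 10, 14); allStatsSize == featureOffsets.last == 14

// Bucket index for (featureIndex = 1, binIndex = 1, label = 1):
val i = featureOffsets(1) + 1 * statsSize + 1 // 6 + 2 + 1 = 9
```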
Let's walk through its key functions.
```scala
/**
 * Update the stats for a given (feature, bin) for ordered features, using the given label.
 */
def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = {
  // The first term is the feature offset; binIndex is the offset of the feature value
  // within the feature, and each feature value owns statsSize buckets, so their sum is
  // the bucket of this feature value. Entropy's update then adds label.toInt on top of
  // this to reach the bucket of that label.
  // This offset arithmetic implies that the values of an ordered feature should be
  // contiguous and start at 0; gaps are allowed but waste buckets.
  val i = featureOffsets(featureIndex) + binIndex * statsSize
  impurityAggregator.update(allStats, i, label, instanceWeight)
}

/**
 * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
 * @param featureOffset For ordered features, this is a pre-computed (node, feature) offset
 *                      from [[getFeatureOffset]].
 *                      For unordered features, this is a pre-computed
 *                      (node, feature, left/right child) offset from
 *                      [[getLeftRightFeatureOffsets]].
 */
def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
  // Same offset arithmetic as in update, except the feature offset is passed in
  // and need not be recomputed.
  impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
}
```

Training constructs numTrees Nodes with the default emptyNode; these will serve as the root node of each tree in the training that follows. Each node, wrapped together with its treeIndex, is added to the queue nodeQueue; later, every node awaiting a split is added to this queue and split in turn, until every node hits a stopping condition, i.e. until the queue in the while loop below is empty.
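A hedged sketch of that initialization, following the shape of RandomForest.run (numTrees stands in for the configured tree count):

```scala
import scala.collection.mutable

val numTrees = 10 // hypothetical tree count
// One empty root per tree; nodeIndex 1 is the root id used by the child-index arithmetic.
val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
val nodeQueue = new mutable.Queue[(Int, Node)]()
Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
```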
This logic lives in selectNodesToSplit, which pulls the nodes to be split this round off nodeQueue and computes their parameters.
```scala
/**
 * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration.
 * This tracks the memory usage for aggregates and stops adding nodes when too much memory
 * will be needed; this allows an adaptive number of nodes since different nodes may require
 * different amounts of memory (if featureSubsetStrategy is not "all").
 *
 * @param nodeQueue Queue of nodes to split.
 * @param maxMemoryUsage Bound on size of aggregate statistics.
 * @return (nodesForGroup, treeToNodeToIndexInfo).
 *         nodesForGroup holds the nodes to split: treeIndex --> nodes in tree.
 *
 *         treeToNodeToIndexInfo holds indices selected features for each node:
 *         treeIndex --> (global) node index --> (node index in group, feature indices).
 *         The (global) node index is the index in the tree; the node index in group is the
 *         index in [0, numNodesInGroup) of the node in this group.
 *         The feature indices are None if not subsampling features.
 */
private[tree] def selectNodesToSplit(
    nodeQueue: mutable.Queue[(Int, Node)],
    maxMemoryUsage: Long,
    metadata: DecisionTreeMetadata,
    rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = {
  // Collect some nodes to split:
  //   nodesForGroup(treeIndex) = nodes to split
  val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]()
  val mutableTreeToNodeToIndexInfo =
    new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]()
  var memUsage: Long = 0L
  var numNodesInGroup = 0
  while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) {
    val (treeIndex, node) = nodeQueue.head
    // Sample the node's feature set with reservoir sampling (covered in an earlier post).
    // Choose subset of features for node (if subsampling).
    val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
      Some(SamplingUtils.reservoirSampleAndCount(Range(0, metadata.numFeatures).iterator,
        metadata.numFeaturesPerNode, rng.nextLong)._1)
    } else {
      None
    }
    // Check if enough memory remains to add this node to the group.
    val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    if (memUsage + nodeMemUsage <= maxMemoryUsage) {
      nodeQueue.dequeue()
      mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node
      mutableTreeToNodeToIndexInfo
        .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) =
        new NodeIndexInfo(numNodesInGroup, featureSubset)
    }
    numNodesInGroup += 1
    memUsage += nodeMemUsage
  }
  // Convert mutable maps to immutable ones.
  val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap
  val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap
  (nodesForGroup, treeToNodeToIndexInfo)
}
```

The code is straightforward: bounded by the memory budget, it dequeues the nodes that can be handled this round from nodeQueue and records them in nodesForGroup and treeToNodeToIndexInfo. Whether the feature set is subsampled depends on whether the metadata's numFeatures equals numFeaturesPerNode; both are constructor parameters of the metadata, determined from featureSubsetStrategy in buildMetadata (see the earlier posts). nodesForGroup is a Map[Int, Array[Node]]: its key is the treeIndex and its value is the array of that tree's nodes to be split this round. treeToNodeToIndexInfo has type Map[Int, Map[Int, NodeIndexInfo]]: the outer key is the treeIndex, the inner key is node.id (the first constructor argument of Node, so in the first round every node's id is 1), and the inner value is a NodeIndexInfo:
```scala
class NodeIndexInfo(
    val nodeIndexInGroup: Int,
    val featureSubset: Option[Array[Int]])
```

The first member is this node's index within the while loop of this round's node selection, called the group index; the second member is the feature subset.
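A hypothetical shape of the result in the first round, with 2 trees and feature subsampling on (feature indices made up for illustration; topNodes is from the initialization sketch above):

```scala
// Each tree's root (id = 1) is selected; tree 0's root is node 0 in the group,
// tree 1's root is node 1.
val nodesForGroup: Map[Int, Array[Node]] =
  Map(0 -> Array(topNodes(0)), 1 -> Array(topNodes(1)))
val treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]] =
  Map(0 -> Map(1 -> new NodeIndexInfo(0, Some(Array(2, 5)))),
      1 -> Map(1 -> new NodeIndexInfo(1, Some(Array(0, 3)))))
```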