Spark MLlib Source Code Analysis: L-BFGS (Part 2)

xiaoxiao2021-02-27

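Related articles: Spark source code analysis: L-BFGS (Part 1); Line search; Spark regularization; Spark MLlib source code analysis: OWLQN. Other source code analysis articles: Spark source code analysis: DecisionTree and GBDT; Spark source code analysis: Random Forest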

4.4. optimize

Our optimizer is LBFGS, whose optimize function is:

override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
  val (weights, _) = LBFGS.runLBFGS(
    data,
    gradient,          // LogisticGradient
    updater,           // SquaredL2Updater
    numCorrections,    // default 10
    convergenceTol,    // default 1E-6
    maxNumIterations,  // default 100
    regParam,          // 0.0
    initialWeights)
  weights
}

Its default parameters are all defined in MLlib's LBFGS class; the actual training happens in the runLBFGS function of object LBFGS.

4.4.1. Data structures used in training

4.4.1.1. Loss function

First, the loss and gradient computation is wrapped in a CostFun class so that it can be evaluated conveniently during the L-BFGS iterations:

/**
 * CostFun implements Breeze's DiffFunction[T], which returns the loss and gradient
 * at a particular point (weights). It's used in Breeze's convex optimization routines.
 */
private class CostFun(
    data: RDD[(Double, Vector)],
    gradient: Gradient,
    updater: Updater,
    regParam: Double,
    numExamples: Long) extends DiffFunction[BDV[Double]] {

  override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
    // Have a local copy to avoid the serialization of CostFun object which is not serializable.
    val w = Vectors.fromBreeze(weights)
    val n = w.size
    val bcW = data.context.broadcast(w)
    val localGradient = gradient

    val (gradientSum, lossSum) = data.treeAggregate((Vectors.zeros(n), 0.0))(
      // seqOp: runs on the executors, accumulating gradient and loss over each partition
      // (see the treeAggregate documentation for details)
      seqOp = (c, v) => (c, v) match {
        case ((grad, loss), (label, features)) =>
          val l = localGradient.compute(features, label, bcW.value, grad)
          (grad, loss + l)
      },
      // combOp: merges the partial (gradient, loss) results of the partitions;
      // treeAggregate runs it at intermediate levels and finally on the driver
      combOp = (c1, c2) => (c1, c2) match {
        case ((grad1, loss1), (grad2, loss2)) =>
          axpy(1.0, grad2, grad1)
          (grad1, loss1 + loss2)
      })

    /**
     * regVal is sum of weight squares if it's L2 updater;
     * for other updater, the same logic is followed.
     */
    // Compute the loss
    val regVal = updater.compute(w, Vectors.zeros(n), 0, 1, regParam)._2

    val loss = lossSum / numExamples + regVal
    /**
     * It will return the gradient part of regularization using updater.
     *
     * Given the input parameters, the updater basically does the following,
     *
     *   w' = w - thisIterStepSize * (gradient + regGradient(w))
     * Note that regGradient is function of w
     *
     * If we set gradient = 0, thisIterStepSize = 1, then
     *
     *   regGradient(w) = w - w'
     *
     * TODO: We need to clean it up by separating the logic of regularization out
     *       from updater to regularizer.
     */
    // The following gradientTotal is actually the regularization part of gradient.
    // Will add the gradientSum computed from the data with weights in the next step.
    // Compute the gradient
    val gradientTotal = w.copy
    axpy(-1.0, updater.compute(w, Vectors.zeros(n), 1, 1, regParam)._1, gradientTotal)

    // gradientTotal = gradientSum / numExamples + gradientTotal
    axpy(1.0 / numExamples, gradientSum, gradientTotal)

    (loss, gradientTotal.asBreeze.asInstanceOf[BDV[Double]])
  }
}
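To make the structure of CostFun.calculate concrete, here is a minimal local sketch in plain Scala (no Spark, no Breeze) for binary logistic regression with L2 regularization. It is an illustration under assumptions, not MLlib's code: labels are assumed to be in {0, 1}, a Seq stands in for the RDD and a plain loop for treeAggregate, and the names LocalCostFun, logisticLoss and l2Update are hypothetical. l2Update is modelled on the rule quoted in the comment above, w' = w - step * (gradient + regGradient(w)) with regGradient(w) = regParam * w, returning 0.5 * regParam * ||w'||^2 as the regularization value, so the same zero-gradient trick recovers both the regularization value and its gradient.

```scala
object LocalCostFun {
  type Vec = Array[Double]

  def dot(a: Vec, b: Vec): Double = a.indices.map(i => a(i) * b(i)).sum

  // Numerically stable log(1 + exp(x))
  def log1pExp(x: Double): Double =
    if (x > 0) x + math.log1p(math.exp(-x)) else math.log1p(math.exp(x))

  // Binary logistic loss for one example (label in {0, 1}); like Gradient.compute,
  // it adds this example's gradient into cumGrad and returns the loss
  def logisticLoss(features: Vec, label: Double, w: Vec, cumGrad: Vec): Double = {
    val z = dot(w, features)
    val multiplier = 1.0 / (1.0 + math.exp(-z)) - label
    for (i <- features.indices) cumGrad(i) += multiplier * features(i)
    if (label > 0) log1pExp(-z) else log1pExp(z)
  }

  // Hypothetical L2 updater: with s = stepSize / sqrt(iter),
  // w' = (1 - s * regParam) * w - s * grad; returns (w', 0.5 * regParam * ||w'||^2)
  def l2Update(w: Vec, grad: Vec, stepSize: Double, iter: Int, regParam: Double): (Vec, Double) = {
    val s = stepSize / math.sqrt(iter.toDouble)
    val newW = w.indices.map(i => (1.0 - s * regParam) * w(i) - s * grad(i)).toArray
    (newW, 0.5 * regParam * dot(newW, newW))
  }

  def calculate(data: Seq[(Double, Vec)], w: Vec, regParam: Double): (Double, Vec) = {
    val n = w.length
    val zeros = new Array[Double](n)
    // Local stand-in for treeAggregate: sum gradients and losses over all examples
    val gradSum = new Array[Double](n)
    var lossSum = 0.0
    for ((label, features) <- data) lossSum += logisticLoss(features, label, w, gradSum)

    // stepSize = 0 leaves w unchanged, so ._2 is just the regularization value
    val regVal = l2Update(w, zeros, 0.0, 1, regParam)._2
    val loss = lossSum / data.size + regVal

    // stepSize = 1 with a zero gradient gives w' = w - regGradient(w)
    val wPrime = l2Update(w, zeros, 1.0, 1, regParam)._1
    val grad = w.indices.map(i => (w(i) - wPrime(i)) + gradSum(i) / data.size).toArray
    (loss, grad)
  }
}
```

Comparing the returned gradient against central finite differences of the loss is a cheap sanity check for any hand-written cost function of this kind.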

4.4.1.2. State

The quantities tracked during the iterations are wrapped in a simple State:

/**
 * Tracks the information about the optimizer, including the current point, its value, gradient, and then any history.
 * Also includes information for checking convergence.
 * @param x the current point being considered
 * @param value f(x)
 * @param grad f.gradientAt(x)
 * @param adjustedValue f(x) + r(x), where r is any regularization added to the objective. For LBFGS, this is f(x).
 * @param adjustedGradient f'(x) + r'(x), where r is any regularization added to the objective. For LBFGS, this is f'(x).
 * @param iter what iteration number we are on.
 * @param initialAdjVal f(x_0) + r(x_0), used for checking convergence
 * @param history any information needed by the optimizer to do updates.
 * @param fVals the sequence of the last minImprovementWindow values, used for checking if the "value" isn't improving
 * @param numImprovementFailures the number of times in a row the objective hasn't improved, mostly for SGD
 * @param searchFailed did the line search fail?
 */
case class State(x: T,
                 value: Double, grad: T,
                 adjustedValue: Double, adjustedGradient: T,
                 iter: Int,
                 initialAdjVal: Double,
                 history: History,
                 fVals: IndexedSeq[Double] = Vector(Double.PositiveInfinity),
                 numImprovementFailures: Int = 0,
                 searchFailed: Boolean = false)

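Here x is the weight vector, value is the loss, grad is the gradient, and history holds the approximate inverse Hessian (in L-BFGS form, i.e. the most recent (s, y) pairs rather than an explicit matrix).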

4.4.1.3. ApproximateInverseHessian

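By default m = 10, i.e. the approximation uses the 10 most recent corrections (values from 3 to 7 are often recommended). memStep and memGradDelta both start out empty.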

case class ApproximateInverseHessian[T](m: Int,
    private[LBFGS] val memStep: IndexedSeq[T] = IndexedSeq.empty,
    private[LBFGS] val memGradDelta: IndexedSeq[T] = IndexedSeq.empty)
    (implicit space: MutableInnerProductModule[T, Double])

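The core of how L-BFGS computes the search direction is the * operator defined on this class. I covered it in an earlier article but misunderstood it at the time, so here it is again: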

def *(grad: T) = {
  // Compute the initial scaling D0
  val diag = if (historyLength > 0) {
    val prevStep = memStep.head
    val prevGradStep = memGradDelta.head
    val sy = prevStep dot prevGradStep
    val yy = prevGradStep dot prevGradStep
    if (sy < 0 || sy.isNaN) throw new NaNHistory
    sy / yy
  } else {
    1.0
  }

  val dir = space.copy(grad)
  val as = new Array[Double](m)
  val rho = new Array[Double](m)

  for (i <- 0 until historyLength) {
    rho(i) = (memStep(i) dot memGradDelta(i))
    as(i) = (memStep(i) dot dir) / rho(i)
    if (as(i).isNaN) {
      throw new NaNHistory
    }
    axpy(-as(i), memGradDelta(i), dir)
  }

  dir *= diag

  for (i <- (historyLength - 1) to 0 by (-1)) {
    val beta = (memGradDelta(i) dot dir) / rho(i)
    axpy(as(i) - beta, memStep(i), dir)
  }

  dir *= -1.0
  dir
}

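Here memStep corresponds to s_i and memGradDelta to y_i; diag is the initial scaling used in each iteration (its formula, s·y / y·y for the newest pair, is given in the algorithm description); as holds the α values; and rho(i) here is s_i·y_i, the reciprocal of ρ_i in the algorithm. dir is the returned variable: during the first (backward) loop it plays the role of q, during the second (forward) loop the role of r. Note that the first loop here runs from index 0 up to historyLength - 1 and the second from historyLength - 1 down to 0, which looks like the reverse of the algorithm's order. The reason is that memStep and memGradDelta store the newest values first (they are prepended: s(k), s(k-1), …, s(0), as can be seen in updated), whereas the algorithm appends newer values at the end (s(0), s(1), …, s(k)). The algorithm's first loop goes from s_k back to s_0, i.e. newest to oldest, which here means index 0 up to historyLength - 1; hence the index order is reversed.

Updating the matrix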

def updated(step: T, gradDelta: T) = {
  val memStep = (step +: this.memStep) take m
  val memGradDelta = (gradDelta +: this.memGradDelta) take m

  new ApproximateInverseHessian(m, memStep, memGradDelta)
}

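As you can see, the new pair is prepended and only the first m entries are kept, so the newest pair is always at index 0.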
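The two pieces read above, the two-loop * operator and the newest-first updated, fit in a short plain-Scala sketch. It assumes dense Array[Double] vectors in place of Breeze's generic T, stores the history with index 0 as the newest pair like memStep/memGradDelta, and uses hypothetical names (LbfgsHistory, History, direction); unlike Breeze it does not throw NaNHistory when s·y is negative or NaN.

```scala
object LbfgsHistory {
  type Vec = Array[Double]

  def dot(a: Vec, b: Vec): Double = a.indices.map(i => a(i) * b(i)).sum

  // dir += a * x, in place
  def axpy(a: Double, x: Vec, dir: Vec): Unit =
    for (i <- dir.indices) dir(i) += a * x(i)

  /** memStep(i) = s, memGradDelta(i) = y; index 0 holds the newest pair. */
  final case class History(m: Int,
                           memStep: IndexedSeq[Vec] = IndexedSeq.empty,
                           memGradDelta: IndexedSeq[Vec] = IndexedSeq.empty) {

    // Prepend the newest (s, y) pair and keep only the m most recent ones
    def updated(step: Vec, gradDelta: Vec): History =
      History(m, (step +: memStep).take(m), (gradDelta +: memGradDelta).take(m))

    /** Two-loop recursion: returns -H * grad for the implicit inverse Hessian H. */
    def direction(grad: Vec): Vec = {
      val k = memStep.length
      // Initial scaling D0 = s.y / y.y from the newest pair, or 1.0 without history
      val diag =
        if (k > 0) dot(memStep.head, memGradDelta.head) / dot(memGradDelta.head, memGradDelta.head)
        else 1.0
      val dir = grad.clone()            // plays the role of q in the first loop
      val alpha = new Array[Double](k)
      val rho = new Array[Double](k)    // s.y, the reciprocal of the textbook rho
      for (i <- 0 until k) {            // newest to oldest
        rho(i) = dot(memStep(i), memGradDelta(i))
        alpha(i) = dot(memStep(i), dir) / rho(i)
        axpy(-alpha(i), memGradDelta(i), dir)
      }
      for (j <- dir.indices) dir(j) *= diag  // plays the role of r from here on
      for (i <- (k - 1) to 0 by -1) {   // oldest back to newest
        val beta = dot(memGradDelta(i), dir) / rho(i)
        axpy(alpha(i) - beta, memStep(i), dir)
      }
      dir.map(x => -x)
    }
  }
}
```

A handy property to test an implementation against is the secant condition: the approximation maps the newest y exactly onto the newest s, so direction applied to the newest y returns minus the newest s.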

4.4.2. Training

4.4.2.1. adjustFunction

Its input is the CostFun, and it returns a CachedDiffFunction, whose calculate method is:

/** Calculates both the value and the gradient at a point */
def calculate(x: T): (Double, T) = {
  var ld = lastData
  if (ld == null || x != ld._1) {
    val newData = obj.calculate(x)
    ld = (copy(x), newData._1, newData._2)
    lastData = ld
  }

  val (_, v, g) = ld
  v -> g
}

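It simply remembers the last result: if x is the same as in the previous call, the cached loss and gradient are returned directly instead of being recomputed.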
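The same caching idea as a self-contained sketch: CachedFn is a hypothetical wrapper that keeps the last (x, value, gradient) triple and only re-evaluates when x changes. As in the code above, the key is compared by value and stored as a copy; since Scala arrays compare by reference, the sketch uses java.util.Arrays.equals.

```scala
final class CachedFn(f: Array[Double] => (Double, Array[Double])) {
  private var count = 0
  private var last: Option[(Array[Double], Double, Array[Double])] = None

  // Number of times the wrapped function was actually evaluated
  def evaluations: Int = count

  def calculate(x: Array[Double]): (Double, Array[Double]) = last match {
    // Same point as last time (compared by value): reuse the cached result
    case Some((lastX, v, g)) if java.util.Arrays.equals(lastX, x) => (v, g)
    case _ =>
      count += 1
      val (v, g) = f(x)
      // Keep a copy so later in-place changes to x cannot corrupt the cache key
      last = Some((x.clone(), v, g))
      (v, g)
  }
}
```

In this setting every real evaluation of CostFun is a full treeAggregate pass over the data, so skipping a repeated evaluation at the same point saves a distributed job.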

4.4.2.2. initialState

Initialize the State:

protected def initialState(f: DF, init: T) = {
  // x is the initial weight vector
  val x = init
  // LBFGS.ApproximateInverseHessian
  val history = initialHistory(f, init)
  // Evaluate the loss function at the initial weights to get the loss and gradient
  val (value, grad) = calculateObjective(f, x, history)
  // adjust returns the loss and gradient unchanged here
  val (adjValue, adjGrad) = adjust(x, grad, value)
  // State after the first evaluation: iter = 0, initialAdjVal = adjValue
  State(x, value, grad, adjValue, adjGrad, 0, adjValue, history)
}

4.4.2.3. iterations

4.4.2.3.1. chooseDescentDirection

Compute the descent direction. This actually calls the * operator of ApproximateInverseHessian described above:

protected def chooseDescentDirection(state: State, fn: DiffFunction[T]):T = { state.history * state.grad }
4.4.2.3.2. determineStepSize

A line search determines the step size (covered in an earlier article):

protected def determineStepSize(state: State, f: DiffFunction[T], dir: T) = {
  val x = state.x
  val grad = state.grad

  // Partial application: x and dir are fixed here, so later calls only pass alpha
  // and get f(x + alpha * dir)
  val ff = LineSearch.functionFromSearchDirection(f, x, dir)
  // Strong Wolfe line search, introduced in the earlier article
  val search = new StrongWolfeLineSearch(maxZoomIter = 10, maxLineSearchIter = 10) // TODO: Need good default values here.
  val alpha = search.minimize(ff, if (state.iter == 0.0) 1.0 / norm(dir) else 1.0)

  if (alpha * norm(grad) < 1E-10)
    throw new StepSizeUnderflow
  alpha
}
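What functionFromSearchDirection hands to the line search is the one-dimensional function φ(α) = f(x + α·d) together with its derivative φ'(α) = ∇f(x + α·d)·d. A minimal sketch, assuming f returns (value, gradient) on Array[Double]; LineFn, fromSearchDirection and initialStep are hypothetical names, and the initial trial step mirrors the code above (1/‖d‖ on the first iteration, 1.0 afterwards). The strong Wolfe search itself is not reproduced.

```scala
object LineFn {
  type Vec = Array[Double]

  def dot(a: Vec, b: Vec): Double = a.indices.map(i => a(i) * b(i)).sum

  // phi(alpha) = f(x + alpha * d), phi'(alpha) = grad f(x + alpha * d) . d
  def fromSearchDirection(f: Vec => (Double, Vec), x: Vec, d: Vec): Double => (Double, Double) =
    alpha => {
      val point = x.indices.map(i => x(i) + alpha * d(i)).toArray
      val (value, grad) = f(point)
      (value, dot(grad, d))
    }

  // Initial trial step: 1 / ||d|| on the first iteration, 1.0 afterwards
  def initialStep(iter: Int, d: Vec): Double =
    if (iter == 0) 1.0 / math.sqrt(dot(d, d)) else 1.0
}
```

For a descent direction φ'(0) is negative; a strong Wolfe search looks for an α that gives a sufficient decrease in φ and a small |φ'(α)|.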
4.4.2.3.3. Adjustment

Compute the new weights from the descent direction and the step size:

val x = takeStep(state, dir, stepSize)

protected def takeStep(state: State, dir: T, stepSize: Double) = state.x + dir * stepSize

Using the new weights, evaluate the loss function to get the loss and gradient:

val (value, grad) = calculateObjective(adjustedFun, x, state.history)

protected def calculateObjective(f: DF, x: T, history: History): (Double, T) = {
  f.calculate(x)
}

The adjust function simply returns the newly computed loss and gradient, so adjValue equals the loss above and adjGrad equals the gradient:

val (adjValue, adjGrad) = adjust(x, grad, value)

def adjust(newX: T, newGrad: T, newVal: Double): (Double, T) = (newVal, newGrad)

Using the new loss, compare with the previous iteration to compute the relative improvement:

val oneOffImprovement = (state.adjustedValue - adjValue)/ (state.adjustedValue.abs max adjValue.abs max 1E-6 * state.initialAdjVal.abs)
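In formula form this is (f_prev - f_new) / max(|f_prev|, |f_new|, 1e-6 * |f(x_0)|): the relative decrease of the loss, with the denominator floored by a small fraction of the initial loss so it does not collapse to zero when both losses approach 0. A one-function sketch (Improvement.oneOff is a hypothetical name):

```scala
object Improvement {
  // Relative decrease of the objective between two iterations.
  // `max` binds more loosely than `*`, so the last term is 1e-6 * |initialValue|.
  def oneOff(prevValue: Double, newValue: Double, initialValue: Double): Double =
    (prevValue - newValue) /
      (math.abs(prevValue) max math.abs(newValue) max 1e-6 * math.abs(initialValue))
}
```

For example, going from a loss of 10.0 to 8.0 with an initial loss of 20.0 gives 2.0 / 10.0 = 0.2.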
4.4.2.3.4. Update

Update the (inverse) Hessian approximation:

protected def updateHistory(newX: T, newGrad: T, newVal: Double, f: DiffFunction[T], oldState: State): History = {
  // (s_i, y_i)
  oldState.history.updated(newX - oldState.x, newGrad :- oldState.grad)
}

Record the most recent losses in the state; how many are kept is determined by minImprovementWindow:

val newAverage = updateFValWindow(state, adjValue)

protected def updateFValWindow(oldState: State, newAdjVal: Double): IndexedSeq[Double] = {
  val interm = oldState.fVals :+ newAdjVal
  if (interm.length > minImprovementWindow) interm.drop(1)
  else interm
}

Construct the new state:

var s = State(x,value,grad,adjValue,adjGrad,state.iter + 1, state.initialAdjVal, history, newAverage, 0)

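x is the weight vector computed in this iteration, value is the new loss and grad the new gradient; adjValue is always equal to value here and adjGrad always equal to grad. The iteration count is incremented by 1 and corresponds to k in the algorithm. state.initialAdjVal never changes: it is the initial loss f(x_0) set in initialState (the 0 passed there is the iteration count, not initialAdjVal). history is the newly updated inverse Hessian approximation, newAverage holds the losses of the most recent minImprovementWindow iterations, and numImprovementFailures is set to 0 here.

Check whether this iteration improved: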

val improvementFailure = (state.fVals.length >= minImprovementWindow &&
  state.fVals.nonEmpty &&
  state.fVals.last > state.fVals.head * (1 - improvementTol))
if (improvementFailure)
  s = s.copy(fVals = IndexedSeq.empty, numImprovementFailures = state.numImprovementFailures + 1)

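An improvement failure requires that the previous window already holds at least minImprovementWindow losses and that its newest loss is not at least a relative improvementTol below its oldest one (fVals.last > fVals.head * (1 - improvementTol)); the comparison is against the oldest value in the window, not against the first iteration. On failure, the recorded losses are cleared and numImprovementFailures becomes state.numImprovementFailures + 1. Since the new state above was constructed with numImprovementFailures = 0, the count only carries over across back-to-back failures and is reset to 0 by any iteration that does not fail. Read literally, because the window is cleared on failure, the next iteration starts with an empty window and cannot fail, so the counter in fact never exceeds 1.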
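The window bookkeeping and the failure test can be replayed on a plain sequence of losses. ImprovementTracking, Track, updateWindow and step are hypothetical names; the window size and tolerance are passed explicitly and the values used below are illustrative, not Breeze's defaults. As in the code above, the failure test looks at the previous window, and a failure clears the window.

```scala
object ImprovementTracking {
  final case class Track(fVals: IndexedSeq[Double], failures: Int)

  // Append the new loss and keep at most `window` of the most recent losses
  def updateWindow(fVals: IndexedSeq[Double], newVal: Double, window: Int): IndexedSeq[Double] = {
    val interm = fVals :+ newVal
    if (interm.length > window) interm.drop(1) else interm
  }

  // One iteration: build the next track with failures = 0, then test the *old* window
  def step(old: Track, newVal: Double, window: Int, tol: Double): Track = {
    val next = Track(updateWindow(old.fVals, newVal, window), 0)
    val failed = old.fVals.length >= window && old.fVals.nonEmpty &&
      old.fVals.last > old.fVals.head * (1 - tol)
    if (failed) next.copy(fVals = IndexedSeq.empty, failures = old.failures + 1) else next
  }
}
```

With window = 3 and tol = 0.1, the losses 10.0, 9.0, 8.9, 8.8, 8.7, 8.6 produce a single failure at 8.7 (the previous window 9.0, 8.9, 8.8 improved by less than 10%), and the counter is back to 0 at 8.6.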

4.4.2.3.5. Termination

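The iteration itself has no fixed end: new states keep being produced until convergedReason (below) reports a reason. A FirstOrderException (for example a failed line search) is handled in two stages: the first one resets the history (the inverse Hessian approximation) and continues; the second one sets searchFailed = true, which ends the run.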

} catch {
  case x: FirstOrderException if !failedOnce =>
    failedOnce = true
    logger.error("Failure! Resetting history: " + x)
    state.copy(history = initialHistory(adjustedFun, state.x))
  case x: FirstOrderException =>
    logger.error("Failure again! Giving up and returning. Maybe the objective is just poorly behaved?")
    state.copy(searchFailed = true)
}

From the final state, the reason for stopping is determined:

def convergedReason: Option[ConvergenceReason] = {
  if (iter >= maxIter && maxIter >= 0)
    Some(FirstOrderMinimizer.MaxIterations)
  else if (!fVals.isEmpty && (adjustedValue - fVals.max).abs <= tolerance * initialAdjVal)
    Some(FirstOrderMinimizer.FunctionValuesConverged)
  else if (numImprovementFailures >= numberOfImprovementFailures)
    Some(FirstOrderMinimizer.ObjectiveNotImproving)
  else if (norm(adjustedGradient) <= math.max(tolerance * adjustedValue.abs, 1E-8))
    Some(FirstOrderMinimizer.GradientConverged)
  else if (searchFailed)
    Some(FirstOrderMinimizer.SearchFailed)
  else
    None
}
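Because the conditions are tested in order, the first match wins; for example, a converged gradient is reported as GradientConverged even if searchFailed is also set. A simplified sketch of the same decision chain over plain values: the reason names mirror FirstOrderMinimizer's, but Convergence, Snapshot and reason are hypothetical stand-ins rather than Breeze's API, and the gradient norm is passed in precomputed.

```scala
object Convergence {
  sealed trait Reason
  case object MaxIterations extends Reason
  case object FunctionValuesConverged extends Reason
  case object ObjectiveNotImproving extends Reason
  case object GradientConverged extends Reason
  case object SearchFailed extends Reason

  final case class Snapshot(iter: Int, adjustedValue: Double, gradNorm: Double,
                            initialAdjVal: Double, fVals: IndexedSeq[Double],
                            numImprovementFailures: Int, searchFailed: Boolean)

  // Same order of checks as convergedReason; maxIter < 0 means "no iteration limit"
  def reason(s: Snapshot, maxIter: Int, tolerance: Double, maxFailures: Int): Option[Reason] =
    if (s.iter >= maxIter && maxIter >= 0) Some(MaxIterations)
    else if (s.fVals.nonEmpty &&
             math.abs(s.adjustedValue - s.fVals.max) <= tolerance * s.initialAdjVal)
      Some(FunctionValuesConverged)
    else if (s.numImprovementFailures >= maxFailures) Some(ObjectiveNotImproving)
    else if (s.gradNorm <= math.max(tolerance * math.abs(s.adjustedValue), 1e-8))
      Some(GradientConverged)
    else if (s.searchFailed) Some(SearchFailed)
    else None
}
```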

4.4.3. Returning the result

Check whether training really converged, then return the weights and the loss history:

var state = states.next()
while (states.hasNext) {
  lossHistory += state.value
  state = states.next()
}
lossHistory += state.value

// actuallyConverged: the reason is FunctionValuesConverged or GradientConverged
if (!state.actuallyConverged) {
  logWarning("LBFGS training finished but the result " +
    s"is not converged because: ${state.convergedReason.get.reason}")
}

val weights = Vectors.fromBreeze(state.x)
val lossHistoryArray = lossHistory.result()

logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
  lossHistoryArray.takeRight(10).mkString(", ")))

(weights, lossHistoryArray)

4.5. Wrapping up

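Depending on whether an intercept was fitted, the intercept and the actual weights are extracted from the result. If the input was scaled, training computes $w_i f_i / \mathrm{std}_i$ for each feature, i.e. the scaling is applied to the features during training. When the model is returned (and used for prediction), the features are used without any transformation, so the returned weight must be $w_i / \mathrm{std}_i$, giving $(w_i / \mathrm{std}_i) f_i$.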
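A quick numeric check of that equivalence, assuming (as the text describes) that features are only divided by their standard deviation, without centering. Rescale, toOriginalSpace and margin are hypothetical names, and mapping a zero standard deviation to a weight of 0 is an assumption of this sketch.

```scala
object Rescale {
  // Weights learned on scaled features f_i / std_i, converted for use on raw features f_i
  def toOriginalSpace(scaledWeights: Array[Double], std: Array[Double]): Array[Double] =
    scaledWeights.indices.map { i =>
      if (std(i) != 0.0) scaledWeights(i) / std(i) else 0.0 // assumption: constant feature -> 0
    }.toArray

  def margin(w: Array[Double], f: Array[Double], intercept: Double): Double =
    w.indices.map(i => w(i) * f(i)).sum + intercept
}
```

With weights 0.8 and -1.5 learned on features scaled by standard deviations 2.0 and 0.5, the weights to apply to the raw features are 0.4 and -3.0, and both forms give the same margin.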

Please cite the original address when reposting: https://www.6miu.com/read-17172.html
