贝叶斯分类算法

xiaoxiao2021-02-28  110

/**
 * Trains a naive Bayes model from labeled feature vectors.
 *
 * Feature vectors are summed and counted per label, collected to the driver and sorted by
 * label. From those aggregates the method derives, with additive smoothing `lambda`:
 *  - `pi(i)`       = log(n_i + lambda) - log(numDocuments + numLabels * lambda)
 *  - `theta(i)(j)` = log(sum_ij + lambda) - thetaLogDenom, where the denominator depends on
 *    `modelType` (Multinomial: total term count of the label + numFeatures * lambda;
 *    Bernoulli: n_i + 2 * lambda).
 *
 * Every feature vector is validated against the configured `modelType`: Bernoulli requires
 * all stored values to be 0 or 1, any other model type requires them to be nonnegative.
 *
 * @param data training examples (label plus feature vector)
 * @return a [[NaiveBayesModel]] built from the sorted labels, `pi`, `theta` and `modelType`
 * @throws SparkException if a feature vector fails the validation for `modelType`
 * @throws UnknownError if `modelType` is neither Multinomial nor Bernoulli
 */
def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
  // Explicitly stored values of a vector (for sparse vectors: only the stored entries).
  val storedValues: Vector => Array[Double] = {
    case sv: SparseVector => sv.values
    case dv: DenseVector => dv.values
  }

  val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
    if (!storedValues(v).forall(_ >= 0.0)) {
      throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
    }
  }

  val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
    if (!storedValues(v).forall(x => x == 0.0 || x == 1.0)) {
      throw new SparkException(
        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
    }
  }

  // Pick the validator once so that createCombiner and mergeValue apply the SAME check.
  // Previously mergeValue always used the nonnegative check, so in Bernoulli mode only the
  // first vector per label per partition was verified to be 0/1.
  val requireValidValues: Vector => Unit =
    if (modelType == Bernoulli) requireZeroOneBernoulliValues else requireNonnegativeValues

  // Aggregates term frequencies per label.
  // TODO: Calling combineByKey and collect creates two stages, we can implement something
  // TODO: similar to reduceByKeyLocally to save one stage.
  val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
    createCombiner = (v: Vector) => {
      requireValidValues(v)
      // Copy before densifying: the accumulator is mutated in place by BLAS.axpy below.
      (1L, v.copy.toDense)
    },
    mergeValue = (c: (Long, DenseVector), v: Vector) => {
      requireValidValues(v)
      BLAS.axpy(1.0, v, c._2)
      (c._1 + 1L, c._2)
    },
    mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
      BLAS.axpy(1.0, c2._2, c1._2)
      (c1._1 + c2._1, c1._2)
    }
  ).collect().sortBy(_._1)

  val numLabels = aggregated.length
  val numDocuments = aggregated.map { case (_, (n, _)) => n }.sum
  // NOTE: `head` throws NoSuchElementException when `data` is empty (no labels aggregated).
  val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }

  val labels = new Array[Double](numLabels)
  val pi = new Array[Double](numLabels)
  val theta = Array.fill(numLabels)(new Array[Double](numFeatures))

  val piLogDenom = math.log(numDocuments + numLabels * lambda)
  aggregated.zipWithIndex.foreach { case ((label, (n, sumTermFreqs)), i) =>
    labels(i) = label
    pi(i) = math.log(n + lambda) - piLogDenom
    val thetaLogDenom = modelType match {
      case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
      case Bernoulli => math.log(n + 2.0 * lambda)
      case _ =>
        // This should never happen.
        throw new UnknownError(s"Invalid modelType: $modelType.")
    }
    // Plain while loop: fills one row of theta without allocating per element.
    var j = 0
    while (j < numFeatures) {
      theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
      j += 1
    }
  }

  new NaiveBayesModel(labels, pi, theta, modelType)
}
// Closes the enclosing definition whose header is not part of this excerpt.
}
转载请注明原文地址: https://www.6miu.com/read-52052.html

最新回复(0)