/**
 * Trains a naive Bayes model from labeled feature vectors.
 *
 * Feature vectors are summed per label (together with a per-label document count), the
 * aggregated result is collected and sorted by label, and the smoothed log prior `pi` and
 * log conditional probabilities `theta` are computed from it using `lambda`.
 *
 * Every feature vector is validated against the configured `modelType`: Bernoulli models
 * accept only 0/1 values, all other model types require nonnegative values.
 *
 * @param data labeled points to train on
 * @return a [[NaiveBayesModel]] built from the sorted labels, `pi`, `theta` and `modelType`
 * @throws SparkException if a feature vector contains a value that is invalid for `modelType`
 *                        (raised inside the aggregation job)
 * @throws NoSuchElementException if `data` yields no records (no label to read the feature
 *                                count from)
 */
def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
  val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
    val values = v match {
      case sv: SparseVector => sv.values
      case dv: DenseVector => dv.values
    }
    if (!values.forall(_ >= 0.0)) {
      throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
    }
  }

  val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
    val values = v match {
      case sv: SparseVector => sv.values
      case dv: DenseVector => dv.values
    }
    if (!values.forall(x => x == 0.0 || x == 1.0)) {
      throw new SparkException(
        s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
    }
  }

  // Choose the validator once so that createCombiner and mergeValue always agree.
  // Previously mergeValue only checked for nonnegative values, so a Bernoulli model
  // validated just the first vector seen per label per partition.
  val requireValues: Vector => Unit =
    if (modelType == Bernoulli) requireZeroOneBernoulliValues else requireNonnegativeValues

  // Aggregates term frequencies per label.
  // TODO: Calling combineByKey and collect creates two stages, we can implement something
  // TODO: similar to reduceByKeyLocally to save one stage.
  val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
    createCombiner = (v: Vector) => {
      requireValues(v)
      // Copy before densifying so the accumulator never aliases the input vector.
      (1L, v.copy.toDense)
    },
    mergeValue = (c: (Long, DenseVector), v: Vector) => {
      requireValues(v)
      // Accumulate in place: c._2 += v
      BLAS.axpy(1.0, v, c._2)
      (c._1 + 1L, c._2)
    },
    mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
      // Accumulate in place: c1._2 += c2._2
      BLAS.axpy(1.0, c2._2, c1._2)
      (c1._1 + c2._1, c1._2)
    }
  ).collect().sortBy(_._1)

  val numLabels = aggregated.length
  val numDocuments = aggregated.map { case (_, (n, _)) => n }.sum
  // Throws NoSuchElementException when no records were aggregated.
  val numFeatures = aggregated.head match { case (_, (_, v)) => v.size }

  val labels = new Array[Double](numLabels)
  val pi = new Array[Double](numLabels)
  val theta = Array.fill(numLabels)(new Array[Double](numFeatures))

  val piLogDenom = math.log(numDocuments + numLabels * lambda)
  aggregated.zipWithIndex.foreach { case ((label, (n, sumTermFreqs)), i) =>
    labels(i) = label
    pi(i) = math.log(n + lambda) - piLogDenom
    val thetaLogDenom = modelType match {
      case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
      case Bernoulli => math.log(n + 2.0 * lambda)
      case _ =>
        // This should never happen.
        throw new UnknownError(s"Invalid modelType: $modelType.")
    }
    var j = 0
    while (j < numFeatures) {
      theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
      j += 1
    }
  }

  new NaiveBayesModel(labels, pi, theta, modelType)
}
}