diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 6709bd79bc820..df088255f1bae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.regression
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.mllib.feature.StandardScaler
 import org.apache.spark.{Logging, SparkException}
 import org.apache.spark.rdd.RDD
@@ -39,6 +40,11 @@ import org.apache.spark.storage.StorageLevel
 abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double)
   extends Serializable {
 
+  // Broadcast handle for `weights`, created by the first predict(RDD) call and reused afterwards.
+  // NOTE(review): the handle comes from the SparkContext of that first RDD; confirm callers
+  // never pass RDDs from a different context to the same model instance.
+  private var bcWeights: Option[Broadcast[Vector]] = None
+
   /**
    * Predict the result given a data point and the weights learned.
    *
@@ -57,11 +63,15 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double
   def predict(testData: RDD[Vector]): RDD[Double] = {
     // A small optimization to avoid serializing the entire model. Only the weightsMatrix
     // and intercept is needed.
-    val localWeights = weights
-    val bcWeights = testData.context.broadcast(localWeights)
+    // Resolve the cached handle on the driver so the task closure captures the Broadcast itself.
+    val cachedWeights = bcWeights.getOrElse {
+      val created = testData.context.broadcast(weights)
+      bcWeights = Some(created)
+      created
+    }
     val localIntercept = intercept
     testData.mapPartitions { iter =>
-      val w = bcWeights.value
+      val w = cachedWeights.value
       iter.map(v => predictPoint(v, w, localIntercept))
     }
   }