diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 3073a2a61ce83..1bd86fdb28f49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -21,6 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
@@ -175,14 +176,20 @@ final class GBTClassificationModel(
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
     s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  private var bcastModel: Option[Broadcast[GBTClassificationModel]] = None
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    bcastModel match {
+      case None => bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
+      case _ =>
+    }
+    val lclBcastModel = bcastModel
     val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+      lclBcastModel.get.value.predict(features.asInstanceOf[Vector])
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 11a6d72468333..3e3d8378bfe7c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.tree.impl.RandomForest
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel}
@@ -132,6 +133,8 @@ final class RandomForestClassificationModel private[ml] (
 
   require(numTrees > 0, "RandomForestClassificationModel requires at least 1 tree.")
 
+  private var bcastModel: Option[Broadcast[RandomForestClassificationModel]] = None
+
   /**
    * Construct a random forest classification model, with all trees weighted equally.
    * @param trees Component trees
@@ -150,9 +153,13 @@ final class RandomForestClassificationModel private[ml] (
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    bcastModel match {
+      case None => bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
+      case _ =>
+    }
+    val lclBcastModel = bcastModel
     val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+      lclBcastModel.get.value.predict(features.asInstanceOf[Vector])
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index b66e61f37dd5e..948dc9b39519c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -21,6 +21,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.tree.{DecisionTreeModel, GBTParams, TreeEnsembleModel, TreeRegressorParams}
@@ -165,14 +166,20 @@ final class GBTRegressionModel(
   require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
     s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  private var bcastModel: Option[Broadcast[GBTRegressionModel]] = None
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    bcastModel match {
+      case None => bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
+      case _ =>
+    }
+    val lclBcastModel = bcastModel
     val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+      lclBcastModel.get.value.predict(features.asInstanceOf[Vector])
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 2f36da371f577..fb6235526db7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeEnsembleModel, TreeRegressorParams}
@@ -121,6 +122,8 @@ final class RandomForestRegressionModel private[ml] (
 
   require(numTrees > 0, "RandomForestRegressionModel requires at least 1 tree.")
 
+  private var bcastModel: Option[Broadcast[RandomForestRegressionModel]] = None
+
   /**
    * Construct a random forest regression model, with all trees weighted equally.
    * @param trees Component trees
@@ -136,9 +139,13 @@ final class RandomForestRegressionModel private[ml] (
   override def treeWeights: Array[Double] = _treeWeights
 
   override protected def transformImpl(dataset: DataFrame): DataFrame = {
-    val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
+    bcastModel match {
+      case None => bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
+      case _ =>
+    }
+    val lclBcastModel = bcastModel
     val predictUDF = udf { (features: Any) =>
-      bcastModel.value.predict(features.asInstanceOf[Vector])
+      lclBcastModel.get.value.predict(features.asInstanceOf[Vector])
     }
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
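
Each of the four transformImpl changes above applies the same pattern: broadcast the model at most once per driver-side instance, cache the result in an Option[Broadcast[...]], and copy the field into a local val so the UDF closure captures only the small broadcast handle rather than `this` (the full model). The standalone sketch below illustrates that pattern outside MLlib against the Spark 1.x SQLContext/DataFrame API; SimpleModel, LazyBroadcastExample, the doubling predict logic, and the "feature"/"prediction" column names are hypothetical and not part of this patch.

// Minimal sketch of the lazy-broadcast pattern used in the patch, assuming the
// Spark 1.x SQLContext/DataFrame API. SimpleModel and the column names are
// illustrative only.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions.{col, udf}

class SimpleModel(val weight: Double) extends Serializable {

  // Broadcast of this model, created on the first transform() call and then
  // reused, so repeated transforms do not re-broadcast the model.
  private var bcastModel: Option[Broadcast[SimpleModel]] = None

  def predict(feature: Double): Double = feature * weight

  def transform(dataset: DataFrame): DataFrame = {
    // Broadcast lazily, at most once per driver-side model instance.
    if (bcastModel.isEmpty) {
      bcastModel = Some(dataset.sqlContext.sparkContext.broadcast(this))
    }
    // Copy the field into a local val so the UDF closure captures only the
    // small Option[Broadcast[_]] instead of `this` (the whole model).
    val lclBcastModel = bcastModel
    val predictUDF = udf { (feature: Double) =>
      lclBcastModel.get.value.predict(feature)
    }
    dataset.withColumn("prediction", predictUDF(col("feature")))
  }
}

object LazyBroadcastExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("lazy-broadcast").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val df = Seq(Tuple1(1.0), Tuple1(2.0), Tuple1(3.0)).toDF("feature")
    val model = new SimpleModel(0.5)
    model.transform(df).show()  // first call broadcasts the model
    model.transform(df).show()  // second call reuses the cached broadcast
    sc.stop()
  }
}

The local-val copy is the detail that makes the caching safe: referencing the bcastModel field directly inside the udf body would pull the enclosing model into the closure and serialize it with every task, which is exactly what broadcasting is meant to avoid.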