diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala index 7fd515369b19b..70e7495ac616c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Model.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer { * The parent estimator that produced this model. * Note: For ensembles' component Models, this value can be null. */ - var parent: Estimator[M] = _ + @transient var parent: Estimator[M] = _ /** * Sets the parent of this model (Java API). @@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer { this.asInstanceOf[M] } + /** Indicates whether this [[Model]] has a corresponding parent. */ + def hasParent: Boolean = parent != null + override def copy(extra: ParamMap): M = { // The default implementation of Params.copy doesn't work for models. throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 43765241a20b6..97f9749cb4a9a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { assert(model.getRawPredictionCol === "rawPrediction") assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) + assert(model.hasParent) } test("logistic regression doesn't fit intercept when fitIntercept is off") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 08f86fa45bc1d..cdbbacab8e0e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -162,5 +162,7 @@ private object RandomForestClassifierSuite { val oldModelAsNew = RandomForestClassificationModel.fromOld( oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures) TreeTests.checkEqual(oldModelAsNew, newModel) + assert(newModel.hasParent) + assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent) } }