Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ abstract class Model[M <: Model[M]] extends Transformer {
* The parent estimator that produced this model.
* Note: For ensembles' component Models, this value can be null.
*/
var parent: Estimator[M] = _
@transient var parent: Estimator[M] = _

/**
* Sets the parent of this model (Java API).
Expand All @@ -42,6 +42,9 @@ abstract class Model[M <: Model[M]] extends Transformer {
this.asInstanceOf[M]
}

/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null

override def copy(extra: ParamMap): M = {
// The default implementation of Params.copy doesn't work for models.
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
assert(model.getRawPredictionCol === "rawPrediction")
assert(model.getProbabilityCol === "probability")
assert(model.intercept !== 0.0)
assert(model.hasParent)
}

test("logistic regression doesn't fit intercept when fitIntercept is off") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,5 +162,7 @@ private object RandomForestClassifierSuite {
val oldModelAsNew = RandomForestClassificationModel.fromOld(
oldModel, newModel.parent.asInstanceOf[RandomForestClassifier], categoricalFeatures)
TreeTests.checkEqual(oldModelAsNew, newModel)
assert(newModel.hasParent)
assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
}
}