From 9d2a64be0676718259fedd8d9090717fb2432457 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 4 Nov 2016 13:40:58 -0700 Subject: [PATCH 1/3] add tests for copy summary --- .../apache/spark/ml/clustering/BisectingKMeans.scala | 5 +++-- .../apache/spark/ml/clustering/GaussianMixture.scala | 12 ++++++++++-- .../org/apache/spark/ml/clustering/KMeans.scala | 5 +++-- .../ml/regression/GeneralizedLinearRegression.scala | 6 ++++-- .../ml/classification/LogisticRegressionSuite.scala | 6 +++++- .../spark/ml/clustering/BisectingKMeansSuite.scala | 9 ++++++++- .../spark/ml/clustering/GaussianMixtureSuite.scala | 12 ++++++++++-- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 9 ++++++++- .../GeneralizedLinearRegressionSuite.scala | 5 ++++- .../spark/ml/regression/LinearRegressionSuite.scala | 6 +++++- 10 files changed, 60 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 2718dd93dcb5..f8a606d60b2a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { - val copied = new BisectingKMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 8fac63fefbb5..417f78dc4128 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { - val copied = new GaussianMixtureModel(uid, weights, gaussians) - copyValues(copied, extra).setParent(this.parent) + val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") @@ -169,6 +170,13 @@ class GaussianMixtureModel private[ml] ( throw new RuntimeException( s"No training summary available for the ${this.getClass.getSimpleName}") } + +// @Since("2.1.0") +// override def copy(extra: ParamMap): GaussianMixtureModel = { +// val newModel = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) +// if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) +// newModel.setParent(parent) +// } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 85bb8c93b3fa..a0d481b294ac 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -108,8 +108,9 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { - val copied = new KMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new KMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 8656ecf609ea..1938e8ecc513 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { - copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - .setParent(parent) + val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), + extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(parent) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8771fd2e9d2b..fa5cf2cec5ee 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -141,6 +141,10 @@ class LogisticRegressionSuite assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) assert(model.hasParent) + + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("empty probabilityCol") { diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f2368a9f8dad..e3e8774fffed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -41,6 +42,12 @@ class BisectingKMeansSuite assert(bkm.getPredictionCol === "prediction") assert(bkm.getMaxIter === 20) assert(bkm.getMinDivisibleClusterSize === 1.0) + + val model = bkm.setMaxIter(1).fit(dataset) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + MLTestingUtils.checkCopy(model) } test("setter/getter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 003fa6abf659..4933fb74f2cc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -18,8 +18,10 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest} +import org.apache.spark.mllib.util.{MLUtils, MLlibTestSparkContext} import org.apache.spark.sql.Dataset @@ -43,6 +45,12 @@ class 
GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getPredictionCol === "prediction") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) + + val model = gm.setMaxIter(1).fit(dataset) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + MLTestingUtils.checkCopy(model) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca392653557c..fb11ebe3184a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -47,6 +48,12 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + + val model = kmeans.setMaxIter(1).fit(dataset) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + MLTestingUtils.checkCopy(model) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ac1ef5feb95b..111bc974642d 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ @@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite // copied model must have the same parent. MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c0e8afbf5e34..b3ded8784c6b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} @@ -140,9 +140,13 @@ class LinearRegressionSuite 
assert(lir.getStandardization) assert(lir.getSolver == "auto") val model = lir.fit(datasetWithDenseFeature) + assert(model.hasSummary) // copied model must have the same parent. MLTestingUtils.checkCopy(model) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) + model.transform(datasetWithDenseFeature) .select("label", "prediction") From 6bb2871be0d866252ec90a802a6e3043a503990a Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 4 Nov 2016 13:53:32 -0700 Subject: [PATCH 2/3] cleanups --- .../apache/spark/ml/clustering/GaussianMixture.scala | 7 ------- .../ml/classification/LogisticRegressionSuite.scala | 5 ++--- .../spark/ml/clustering/BisectingKMeansSuite.scala | 7 ++++--- .../spark/ml/clustering/GaussianMixtureSuite.scala | 10 +++++----- .../org/apache/spark/ml/clustering/KMeansSuite.scala | 7 ++++--- .../spark/ml/regression/LinearRegressionSuite.scala | 3 +-- 6 files changed, 16 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 417f78dc4128..a0bd66e731a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -170,13 +170,6 @@ class GaussianMixtureModel private[ml] ( throw new RuntimeException( s"No training summary available for the ${this.getClass.getSimpleName}") } - -// @Since("2.1.0") -// override def copy(extra: ParamMap): GaussianMixtureModel = { -// val newModel = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) -// if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) -// newModel.setParent(parent) -// } } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 
fa5cf2cec5ee..2877285eb4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -142,6 +142,8 @@ class LogisticRegressionSuite assert(model.intercept !== 0.0) assert(model.hasParent) + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) @@ -255,9 +257,6 @@ class LogisticRegressionSuite mlr.setFitIntercept(false) val mlrModel = mlr.fit(smallMultinomialDataset) assert(mlrModel.interceptVector === Vectors.sparse(3, Seq())) - - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index e3e8774fffed..49797d938d75 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -42,12 +42,13 @@ class BisectingKMeansSuite assert(bkm.getPredictionCol === "prediction") assert(bkm.getMaxIter === 20) assert(bkm.getMinDivisibleClusterSize === 1.0) - val model = bkm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) - 
MLTestingUtils.checkCopy(model) } test("setter/getter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 4933fb74f2cc..7165b63ed3b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{MLTestingUtils, DefaultReadWriteTest} -import org.apache.spark.mllib.util.{MLUtils, MLlibTestSparkContext} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -45,12 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getPredictionCol === "prediction") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) - val model = gm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) - MLTestingUtils.checkCopy(model) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index fb11ebe3184a..73972557d263 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.util.{MLTestingUtils, 
DefaultReadWriteTest} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -48,12 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) - val model = kmeans.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) - MLTestingUtils.checkCopy(model) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index b3ded8784c6b..df97d0b2ae7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -140,14 +140,13 @@ class LinearRegressionSuite assert(lir.getStandardization) assert(lir.getSolver == "auto") val model = lir.fit(datasetWithDenseFeature) - assert(model.hasSummary) // copied model must have the same parent. 
MLTestingUtils.checkCopy(model) + assert(model.hasSummary) val copiedModel = model.copy(ParamMap.empty) assert(copiedModel.hasSummary) - model.transform(datasetWithDenseFeature) .select("label", "prediction") .collect() From c4da8115cc5fff0722b9649217d343cfe8cec9e9 Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 4 Nov 2016 14:10:36 -0700 Subject: [PATCH 3/3] set parent for train validation split --- .../org/apache/spark/ml/tuning/TrainValidationSplit.scala | 2 +- .../spark/ml/tuning/TrainValidationSplitSuite.scala | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb8814..5d1a39f7c16d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], validationMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 87100ae2e342..4463a9b6e543 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} +import 
org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -78,6 +78,10 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setSeed(42L) val cvModel = cv.fit(dataset) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10)