diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index e60a14f976a5..6c00b569aa06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -141,8 +141,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) Some(Array.fill($(numFolds))(Array.fill[Model[_]](epm.length)(null))) } else None + val inputRDD = dataset.toDF.rdd + inputRDD.persist() // Compute metrics for each model over each split - val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed)) + val splits = MLUtils.kFold(inputRDD, $(numFolds), $(seed)) val metrics = splits.zipWithIndex.map { case ((training, validation), splitIndex) => val trainingDataset = sparkSession.createDataFrame(training, schema).cache() val validationDataset = sparkSession.createDataFrame(validation, schema).cache()