diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6ea52ef7f025..0e48e5b8ee6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -91,6 +91,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setGroupCol(value: String): this.type = set(groupKFoldCol, value)
+  setDefault(groupKFoldCol -> "")
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
@@ -101,7 +106,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
-    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
+
+    val splits = if ($(groupKFoldCol).nonEmpty) {
+      val groupKFoldColIdx = schema.fieldNames.indexOf($(groupKFoldCol))
+      val pairValue = dataset.toDF.rdd.map(row => (row(groupKFoldColIdx), row))
+      val labeledSplits = MLUtils.groupKFold(pairValue, $(numFolds))
+      labeledSplits.map { case (training, validation) => (training.values, validation.values) }
+    } else {
+      MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
+    }
+
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
       val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 26fd73814d70..9f759f2bf06d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -56,6 +56,17 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
   /** @group getParam */
   def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)
 
+  /**
+   * param for groupKFold column name
+   * default: empty
+   * @group param
+   */
+  val groupKFoldCol: Param[String] = new Param[String](this, "groupKFoldCol",
+    "groupKFold column name")
+
+  /** @group getParam */
+  def getGroupCol: String = $(groupKFoldCol)
+
   /**
    * param for the evaluator used to select hyper-parameters that maximize the validated metric
    *
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index e96c2bc6edfc..bd178d438109 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.util
 
 import scala.annotation.varargs
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import org.apache.spark.SparkContext
@@ -227,6 +228,30 @@ object MLUtils extends Logging {
     }.toArray
   }
 
+  /**
+   * Version of [[kFold()]] taking a PairRDD with group labels as keys.
+   */
+  def groupKFold[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], numFolds: Int):
+      Array[(RDD[(K, V)], RDD[(K, V)])] = {
+    val groupArray = rdd.mapValues(_ => 1L)
+      .reduceByKey(_ + _)
+      .sortBy(_._2, ascending = false)
+      .collect()
+    require(groupArray.length >= numFolds, "There cannot be more folds than groups.")
+    val samplePerFold = new Array[Long](numFolds)
+    val group2fold = Array.fill(numFolds)(new ArrayBuffer[K])
+    for ((k, v) <- groupArray) {
+      val index = samplePerFold.zipWithIndex.min._2
+      samplePerFold(index) += v
+      group2fold(index) += k
+    }
+    (0 until numFolds).map { fold =>
+      val validation = rdd.filter(sample => group2fold(fold).contains(sample._1))
+      val training = rdd.filter(sample => !group2fold(fold).contains(sample._1))
+      (training, validation)
+    }.toArray
+  }
+
   /**
    * Returns a new vector with `1.0` (bias) appended to the input vector.
    */
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 6aa93c907600..2f498cbb898f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -210,6 +210,45 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("groupKFold") {
+    val data = sc.parallelize(
+      Seq((1, 'a'), (1, 'b'), (2, 'c'), (2, 'd'), (2, 'e'), (3, 'f')))
+    val collectedData = data.collect().sorted
+    val twoFoldedRdd = groupKFold(data, 2)
+    assert(twoFoldedRdd(0)._1.collect().sorted.sameElements(twoFoldedRdd(1)._2.collect().sorted))
+    assert(twoFoldedRdd(0)._2.collect().sorted.sameElements(twoFoldedRdd(1)._1.collect().sorted))
+
+    withClue("Invalid setting of numFolds was not caught!") {
+      intercept[IllegalArgumentException] {
+        val testRdd = groupKFold(data, 10)
+      }
+    }
+
+    for (folds <- 2 to 3) {
+      val foldedRdds = groupKFold(data, folds)
+      assert(foldedRdds.length === folds)
+      foldedRdds.foreach { case (training, validation) =>
+        val result = validation.union(training).collect().sorted
+        val validationSize = validation.collect().size.toFloat
+        assert(validationSize > 0, "empty validation data")
+        assert(training.collect().size > 0, "empty training data")
+        assert(result === collectedData,
+          "Each training+validation set combined should contain all of the data.")
+        val validationLength = validation.map(_._1).collect().distinct.length
+        val trainingLength = training.map(_._1).collect().distinct.length
+        assert(validationLength + trainingLength === validation.map(_._1)
+          .collect()
+          .distinct
+          .union(training.map(_._1).collect().distinct)
+          .length,
+          "same group should not appear in both training set and validation set")
+      }
+      // group K fold should only have each element in the validation set exactly once
+      assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted ===
+        data.collect().sorted)
+    }
+  }
+
   test("loadVectors") {
     val vectors = sc.parallelize(Seq(
       Vectors.dense(1.0, 2.0),
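
Note (not part of the patch): a minimal usage sketch of the MLUtils.groupKFold helper added above. It assumes the patch is applied and that a SparkContext named `sc` is in scope; the group keys ("user1" etc.) and values are illustrative only.

    import org.apache.spark.mllib.util.MLUtils

    // Pair RDD keyed by group label: all records for "user1" should stay on the
    // same side of every train/validation split.
    val samples = sc.parallelize(Seq(
      ("user1", 0.5), ("user1", 0.7),
      ("user2", 0.1), ("user2", 0.3),
      ("user3", 0.9)))

    // Three folds over three groups: each group appears in exactly one validation fold.
    val folds = MLUtils.groupKFold(samples, 3)
    folds.zipWithIndex.foreach { case ((training, validation), i) =>
      val validationGroups = validation.keys.distinct().collect().toSet
      val trainingGroups = training.keys.distinct().collect().toSet
      // No group is shared between the training and validation sides.
      assert(validationGroups.intersect(trainingGroups).isEmpty)
      println(s"fold $i: validation groups = $validationGroups")
    }

The same behaviour is reachable through the Pipeline API by setting the new column parameter, e.g. `new CrossValidator().setGroupCol("userId")`, where "userId" is a hypothetical group column in the input DataFrame.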