mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -91,6 +91,11 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("2.0.0")
def setSeed(value: Long): this.type = set(seed, value)

/** @group setParam */
@Since("2.1.0")
def setGroupCol(value: String): this.type = set(groupKFoldCol, value)
setDefault(groupKFoldCol -> "")

@Since("2.0.0")
override def fit(dataset: Dataset[_]): CrossValidatorModel = {
val schema = dataset.schema
@@ -101,7 +106,16 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
    val epm = $(estimatorParamMaps)
    val numModels = epm.length
    val metrics = new Array[Double](epm.length)

    val splits = if ($(groupKFoldCol).nonEmpty) {
      // Group-aware splitting: key each row by its group value so that all
      // rows sharing a group land in the same fold.
      val groupKFoldColIdx = schema.fieldNames.indexOf($(groupKFoldCol))
      require(groupKFoldColIdx >= 0,
        s"groupKFoldCol '${$(groupKFoldCol)}' does not exist in the input schema.")
      val pairValue = dataset.toDF.rdd.map(row => (row(groupKFoldColIdx), row))
      val labeledSplits = MLUtils.groupKFold(pairValue, $(numFolds))
      labeledSplits.map { case (training, validation) => (training.values, validation.values) }
    } else {
      MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
    }

    splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
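For reviewers, a minimal usage sketch of the new parameter. This is not part of the diff; the estimator, evaluator, and the `userId` grouping column are illustrative placeholders:

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()
val grid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .build()

// When setGroupCol is set, rows that share a value in "userId" are kept in
// the same fold; leaving it unset (the default) falls back to random k-fold.
val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(grid)
  .setNumFolds(3)
  .setGroupCol("userId")

// val cvModel = cv.fit(df)  // df: a DataFrame containing a "userId" column
```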
mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -56,6 +56,17 @@ private[ml] trait ValidatorParams extends HasSeed with Params {
  /** @group getParam */
  def getEstimatorParamMaps: Array[ParamMap] = $(estimatorParamMaps)

  /**
   * Param for the name of the column holding group labels for group k-fold
   * cross-validation. When empty (the default), plain random k-fold is used.
   * @group param
   */
  val groupKFoldCol: Param[String] = new Param[String](this, "groupKFoldCol",
    "groupKFold column name")

  /** @group getParam */
  def getGroupCol: String = $(groupKFoldCol)

  /**
   * param for the evaluator used to select hyper-parameters that maximize the validated metric
   *
25 changes: 25 additions & 0 deletions mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.util

import scala.annotation.varargs
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkContext
@@ -227,6 +228,30 @@ object MLUtils extends Logging {
    }.toArray
  }

  /**
   * Returns an array of k pairs of (training, validation) RDDs, where samples that share a
   * key (group label) are always assigned to the same fold, so no group appears in both the
   * training and validation set of any split. Groups are assigned greedily, largest first,
   * to the fold with the fewest samples, keeping fold sizes roughly balanced.
   */
  def groupKFold[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)], numFolds: Int):
      Array[(RDD[(K, V)], RDD[(K, V)])] = {
    // Count the samples per group and sort the groups by size, largest first.
    val groupArray = rdd.mapValues(_ => 1L)
      .reduceByKey(_ + _)
      .sortBy(_._2, ascending = false)
      .collect()
    require(groupArray.length >= numFolds,
      s"The number of folds ($numFolds) cannot exceed the number of distinct groups " +
        s"(${groupArray.length}).")
    val samplePerFold = new Array[Long](numFolds)
    val group2fold = Array.fill(numFolds)(new ArrayBuffer[K])
    // Greedy bin packing: assign each group to the currently smallest fold.
    for ((k, v) <- groupArray) {
      val index = samplePerFold.zipWithIndex.min._2
      samplePerFold(index) += v
      group2fold(index) += k
    }
    (0 until numFolds).map { fold =>
      val validationGroups = group2fold(fold).toSet
      val validation = rdd.filter(sample => validationGroups.contains(sample._1))
      val training = rdd.filter(sample => !validationGroups.contains(sample._1))
      (training, validation)
    }.toArray
  }

  /**
   * Returns a new vector with `1.0` (bias) appended to the input vector.
   */
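As a quick illustration of the splitting behavior (a sketch, not part of the diff; the data and group labels are made up), in spark-shell:

```scala
import org.apache.spark.mllib.util.MLUtils

// Six samples in three groups of sizes 3, 2, and 1; keys are group labels.
val data = sc.parallelize(Seq(
  ("g1", 1.0), ("g1", 2.0), ("g1", 3.0),
  ("g2", 4.0), ("g2", 5.0),
  ("g3", 6.0)))

val folds = MLUtils.groupKFold(data, 2)
folds.zipWithIndex.foreach { case ((_, validation), i) =>
  // Greedy assignment, largest group first: g1 -> fold 0, then g2 and g3 ->
  // fold 1, so the validation group sets should be {g1} and {g2, g3}.
  println(s"fold $i validation groups: " +
    validation.keys.distinct().collect().sorted.mkString(", "))
}
```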
mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -210,6 +210,45 @@ class MLUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
    }
  }

test("groupKFold") {
val data = sc.parallelize(
Seq((1, 'a'), (1, 'b'), (2, 'c'), (2, 'd'), (2, 'e'), (3, 'f')))
val collectedData = data.collect().sorted
val twoFoldedRdd = groupKFold(data, 2)
assert(twoFoldedRdd(0)._1.collect().sorted.sameElements(twoFoldedRdd(1)._2.collect().sorted))
assert(twoFoldedRdd(0)._2.collect().sorted.sameElements(twoFoldedRdd(1)._1.collect().sorted))

withClue("Invalid setting of numFolds was not caught !") {
intercept[IllegalArgumentException] {
val testRdd = groupKFold(data, 10)
}
}

for (folds <- 2 to 3) {
val foldedRdds = groupKFold(data, folds)
assert(foldedRdds.length === folds)
foldedRdds.foreach { case (training, validation) =>
val result = validation.union(training).collect().sorted
val validationSize = validation.collect().size.toFloat
assert(validationSize > 0, "empty validation data")
assert(training.collect().size > 0, "empty training data")
assert(result === collectedData,
"Each training+validation set combined should contain all of the data.")
val validationLength = validation.map(_._1).collect().distinct.length
val traingingLength = training.map(_._1).collect().distinct.length
assert(validationLength + traingingLength === validation.map(_._1)
.collect()
.distinct
.union(training.map(_._1).collect().distinct)
.length,
"same group should not appear in both training set and validation set")
}
// group K fold should only have each element in the validation set exactly once
assert(foldedRdds.map(_._2).reduce((x, y) => x.union(y)).collect().sorted ===
data.collect().sorted)
}
}
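One more property worth asserting, sketched here rather than added to the diff: the greedy assignment keeps validation fold sizes balanced. For the test data above (group sizes 3, 2, and 1), two folds should come out as 3 and 3:

```scala
val sizes = groupKFold(data, 2).map(_._2.count())
assert(sizes.sorted.sameElements(Array(3L, 3L)), "validation folds should be balanced")
```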

test("loadVectors") {
val vectors = sc.parallelize(Seq(
Vectors.dense(1.0, 2.0),