From b14fbab7487a8464ba2a53bb9804e00fd14d3785 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Mon, 10 Oct 2016 10:33:09 +0800 Subject: [PATCH 1/4] [SPARK-17219][ML] enchance NaN value handling in Bucketizer This PR is an enhancement of PR with commit ID:57dc326bd00cf0a49da971e9c573c48ae28acaa2. We provided user when dealing NaN value in the dataset with 3 options, to either reserve an extra bucket for NaN values, or remove the NaN values, or report an error, by setting "keep", "skip", or "error"(default) to handleInvalid. '''Before: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) '''After: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) .setHandleNaN("skip") Signed-off-by: VinceShieh --- docs/ml-features.md | 8 ++- .../apache/spark/ml/feature/Bucketizer.scala | 63 ++++++++++++++++--- .../ml/feature/QuantileDiscretizer.scala | 43 +++++++++++-- .../spark/ml/feature/BucketizerSuite.scala | 8 +-- .../ml/feature/QuantileDiscretizerSuite.scala | 26 +++++--- python/pyspark/ml/feature.py | 8 ++- .../apache/spark/sql/DataFrameStatSuite.scala | 4 ++ 7 files changed, 127 insertions(+), 33 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index a7f710fa52e6..948d8f29a193 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1104,9 +1104,11 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that NaN values are -handled specially and placed into their own bucket. For example, if 4 buckets are used, then -non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. +distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer +will raise an error when it finds NaN value in the dataset, but user can also choose to either +keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep +NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets +are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index ec0ea05f9e1b..d04ab5cd7572 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -27,6 +27,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql._ +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} @@ -73,15 +74,52 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with + * invalid values), or error (which will throw an error), or keep (which will keep the invalid + * values in certain way). Default behaviour is to report an error for invalid entries. + * + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (which will filter out rows with invalid values), or" + + "error (which will throw an error), or keep (which will keep the invalid values" + + " in certain way). Default behaviour is to report an error for invalid entries.", + ParamValidators.inArray(Array("skip", "error", "keep"))) + + /** @group getParam */ + @Since("2.1.0") + def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { + case "keep" => Some(true) + case "skip" => Some(false) + case _ => None + } + + /** @group setParam */ + @Since("2.1.0") + def sethandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val bucketizer = udf { feature: Double => - Bucketizer.binarySearchForBuckets($(splits), feature) + val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(dataset($(inputCol))) - val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol, newField.metadata) + val filteredDataset = { + if (!keepInvalid) { + // "skip" NaN option is set, will filter out NaN values in the dataset + dataset.na.drop.toDF() + } else { + dataset.toDF() + } + } + val newCol = bucketizer(filteredDataset($(inputCol))) + val newField = prepOutputField(filteredDataset.schema) + filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { @@ -126,10 +164,21 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { /** * Binary searching in several buckets to place each data point. + * @param splits array of split points + * @param feature data point + * @param keepInvalid NaN flag. + * Set "true" to make an extra bucket for NaN values; + * Set "false" to report an error for NaN values + * @return bucket for each data point * @throws SparkException if a feature is < splits.head or > splits.last */ - private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { - if (feature.isNaN) { + + private[feature] def binarySearchForBuckets( + splits: Array[Double], + feature: Double, + keepInvalid: Boolean): Double = { + if (feature.isNaN && keepInvalid) { + // NaN data point found plus "keep" NaN option is set splits.length - 1 } else if (feature == splits.last) { splits.length - 2 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 05e034d90f6a..31bf18f3741c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -66,11 +66,13 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there - * are too few distinct values of the input to create enough distinct quantiles. Note also that - * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets - * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special - * bucket(4). + * possible that the number of buckets used will be less than this value, for example, if there are + * too few distinct values of the input to create enough distinct quantiles. Note also that + * QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can + * also choose to either keep or remove NaN values within the dataset by setting handleInvalid. + * If user chooses to keep NaN values, they will be handled specially and placed into their own + * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], + * but NaNs will be counted in a special bucket[4]. * The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the @@ -100,6 +102,33 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** + * Param for how to handle invalid entries. Options are skip (which will filter out rows with + * invalid values), or error (which will throw an error), or keep (which will keep the invalid + * values in certain way). Default behaviour is to report an error for invalid entries. + * + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (which will filter out rows with invalid values), or" + + "error (which will throw an error), or keep (which will keep the invalid values" + + " in certain way). Default behaviour is to report an error for invalid entries.", + ParamValidators.inArray(Array("skip", "error", "keep"))) + + /** @group getParam */ + @Since("2.1.0") + def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { + case "keep" => Some(true) + case "skip" => Some(false) + case _ => None + } + + /** @group setParam */ + @Since("2.1.0") + def sethandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, "error") + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { SchemaUtils.checkNumericType(schema, $(inputCol)) @@ -124,7 +153,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" + s" buckets as a result.") } - val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted) + val bucketizer = new Bucketizer(uid) + .setSplits(distinctSplits.sorted) + .sethandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 87cdceb26738..5066238d06ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -98,6 +98,7 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("feature") .setOutputCol("result") .setSplits(splits) + .sethandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => @@ -111,8 +112,6 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa withClue("Invalid NaN split was not caught as an invalid split!") { intercept[IllegalArgumentException] { val bucketizer: Bucketizer = new Bucketizer() - .setInputCol("feature") - .setOutputCol("result") .setSplits(splits) } } @@ -138,7 +137,8 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val data = Array.fill(100)(Random.nextDouble()) val splits: Array[Double] = Double.NegativeInfinity +: Array.fill(10)(Random.nextDouble()).sorted :+ Double.PositiveInfinity - val bsResult = Vectors.dense(data.map(x => Bucketizer.binarySearchForBuckets(splits, x))) + val bsResult = Vectors.dense(data.map(x => + Bucketizer.binarySearchForBuckets(splits, x, false))) val lsResult = Vectors.dense(data.map(x => BucketizerSuite.linearSearchForBuckets(splits, x))) assert(bsResult ~== lsResult absTol 1e-5) } @@ -169,7 +169,7 @@ private object BucketizerSuite extends SparkFunSuite { /** Check all values in splits, plus values between all splits. */ def checkBinarySearch(splits: Array[Double]): Unit = { def testFeature(feature: Double, expectedBucket: Double): Unit = { - assert(Bucketizer.binarySearchForBuckets(splits, feature) === expectedBucket, + assert(Bucketizer.binarySearchForBuckets(splits, feature, false) === expectedBucket, s"Expected feature value $feature to be in bucket $expectedBucket with splits:" + s" ${splits.mkString(", ")}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 6822594044a5..7464bde5b3e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql._ import org.apache.spark.sql.functions.udf class QuantileDiscretizerSuite @@ -76,20 +76,26 @@ class QuantileDiscretizerSuite import spark.implicits._ val numBuckets = 3 - val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) - .map(Tuple1.apply).toDF("input") + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedKeep = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0) + val expectedSkip = Array(0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 2.0) + val discretizer = new QuantileDiscretizer() .setInputCol("input") .setOutputCol("result") .setNumBuckets(numBuckets) - // Reserve extra one bucket for NaN - val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1 - val result = discretizer.fit(df).transform(df) - val observedNumBuckets = result.select("result").distinct.count - assert(observedNumBuckets == expectedNumBuckets, - s"Observed number of buckets are not correct." + - s" Expected $expectedNumBuckets but found $observedNumBuckets") + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ + case(u, v) => + discretizer.sethandleInvalid(u) + val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") + val result = discretizer.fit(dataFrame).transform(dataFrame) + result.select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } } test("Test transform method on unseen data") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616e..469c96377276 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1157,9 +1157,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. It is possible that the number of buckets used will be less than this value, for example, if there are too few distinct values of the input to create enough distinct quantiles. Note also - that NaN values are handled specially and placed into their own bucket. For example, if 4 - buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in - a special bucket(4). + that QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user + can also choose to either keep or remove NaN values within the dataset by setting + handleInvalid. If user chooses to keep NaN values, they will be handled specially and placed + into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into + buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 73026c749db4..726773ed9365 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -150,6 +150,10 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d1 - 2 * q1 * n) < error_double) assert(math.abs(d2 - 2 * q2 * n) < error_double) } + // test approxQuantile on NaN values + val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0)) + assert(resNaN.count(_.isNaN) == 0) } test("crosstab") { From 5274d4a3703193a59607635f80eb9e3ebe61552c Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Tue, 25 Oct 2016 14:13:56 +0800 Subject: [PATCH 2/4] [SPARK-17219][ML] enchance NaN value handling in Bucketizer This PR is an enhancement of PR with commit ID:57dc326bd00cf0a49da971e9c573c48ae28acaa2. We provided user when dealing NaN value in the dataset with 3 options, to either reserve an extra bucket for NaN values, or remove the NaN values, or report an error, by setting "keep", "skip", or "error"(default) to handleInvalid. '''Before: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) '''After: val bucketizer: Bucketizer = new Bucketizer() .setInputCol("feature") .setOutputCol("result") .setSplits(splits) .setHandleNaN("skip") Signed-off-by: VinceShieh --- .../apache/spark/ml/feature/Bucketizer.scala | 28 ++++++++----------- .../ml/feature/QuantileDiscretizer.scala | 8 ++---- 2 files changed, 14 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d04ab5cd7572..e568dc0556a5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -77,8 +77,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** * Param for how to handle invalid entries. Options are skip (which will filter out rows with * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way). Default behaviour is to report an error for invalid entries. - * + * values in certain way). + * Default: "error" * @group param */ @Since("2.1.0") @@ -90,11 +90,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { - case "keep" => Some(true) - case "skip" => Some(false) - case _ => None - } + def gethandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") @@ -104,19 +100,19 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - val keepInvalid = gethandleInvalid.isDefined && gethandleInvalid.get - - val bucketizer: UserDefinedFunction = udf { (feature: Double) => - Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - } - val filteredDataset = { - if (!keepInvalid) { + val (filteredDataset, keepInvalid) = { + if ("skip" == gethandleInvalid) { // "skip" NaN option is set, will filter out NaN values in the dataset - dataset.na.drop.toDF() + (dataset.na.drop.toDF(), false) } else { - dataset.toDF() + (dataset.toDF(), "keep" == gethandleInvalid) } } + + val bucketizer: UserDefinedFunction = udf { (feature: Double) => + Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) + } + val newCol = bucketizer(filteredDataset($(inputCol))) val newField = prepOutputField(filteredDataset.schema) filteredDataset.withColumn($(outputCol), newCol, newField.metadata) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 31bf18f3741c..5a90abba242f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -106,7 +106,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui * Param for how to handle invalid entries. Options are skip (which will filter out rows with * invalid values), or error (which will throw an error), or keep (which will keep the invalid * values in certain way). Default behaviour is to report an error for invalid entries. - * + * Default: "error" * @group param */ @Since("2.1.0") @@ -118,11 +118,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: Option[Boolean] = $(handleInvalid) match { - case "keep" => Some(true) - case "skip" => Some(false) - case _ => None - } + def gethandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") From 2f98d31118413e61e1aa0431da402c41aa1ca5a6 Mon Sep 17 00:00:00 2001 From: VinceShieh Date: Wed, 26 Oct 2016 11:12:26 +0800 Subject: [PATCH 3/4] revert changes in feature.py Signed-off-by: VinceShieh --- python/pyspark/ml/feature.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 469c96377276..ee86207e7744 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,13 +1155,6 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. - It is possible that the number of buckets used will be less than this value, for example, if - there are too few distinct values of the input to create enough distinct quantiles. Note also - that QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user - can also choose to either keep or remove NaN values within the dataset by setting - handleInvalid. If user chooses to keep NaN values, they will be handled specially and placed - into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into - buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the From 2644235f111bbbf43fd1f30d24d318735553e034 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 26 Oct 2016 13:01:27 -0700 Subject: [PATCH 4/4] Cleanups: docs cleanups, slightly improved unit test coverage, fixed naming of set/get for handleInvalid --- docs/ml-features.md | 13 ++-- .../apache/spark/ml/feature/Bucketizer.scala | 44 +++++++++----- .../ml/feature/QuantileDiscretizer.scala | 60 ++++++++++--------- .../spark/ml/feature/BucketizerSuite.scala | 20 +++++-- .../ml/feature/QuantileDiscretizerSuite.scala | 11 +++- .../apache/spark/sql/DataFrameStatSuite.scala | 6 +- 6 files changed, 97 insertions(+), 57 deletions(-) diff --git a/docs/ml-features.md b/docs/ml-features.md index 948d8f29a193..64c6a160239c 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1103,13 +1103,16 @@ for more details on the API. `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins is set by the `numBuckets` parameter. It is possible -that the number of buckets used will be less than this value, for example, if there are too few -distinct values of the input to create enough distinct quantiles. Note also that QuantileDiscretizer -will raise an error when it finds NaN value in the dataset, but user can also choose to either -keep or remove NaN values within the dataset by setting handleInvalid. If user chooses to keep +that the number of buckets used will be smaller than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. + +NaN values: Note also that QuantileDiscretizer +will raise an error when it finds NaN values in the dataset, but the user can also choose to either +keep or remove NaN values within the dataset by setting `handleInvalid`. If the user chooses to keep NaN values, they will be handled specially and placed into their own bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. -The bin ranges are chosen using an approximate algorithm (see the documentation for + +Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the `relativeError` parameter. When set to zero, exact quantiles are calculated diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index e568dc0556a5..1143f0f565eb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -47,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String * also includes y. Splits should be of length >= 3 and strictly increasing. * Values at -inf, inf must be explicitly provided to cover all Double values; * otherwise, values outside the splits specified will be treated as errors. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * @group param */ @Since("1.4.0") @@ -75,37 +78,36 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String def setOutputCol(value: String): this.type = set(outputCol, value) /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with - * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way). + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). * Default: "error" * @group param */ @Since("2.1.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + - "invalid entries. Options are skip (which will filter out rows with invalid values), or" + - "error (which will throw an error), or keep (which will keep the invalid values" + - " in certain way). Default behaviour is to report an error for invalid entries.", - ParamValidators.inArray(Array("skip", "error", "keep"))) + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) /** @group getParam */ @Since("2.1.0") - def gethandleInvalid: String = $(handleInvalid) + def getHandleInvalid: String = $(handleInvalid) /** @group setParam */ @Since("2.1.0") - def sethandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) val (filteredDataset, keepInvalid) = { - if ("skip" == gethandleInvalid) { + if (getHandleInvalid == Bucketizer.SKIP_INVALID) { // "skip" NaN option is set, will filter out NaN values in the dataset - (dataset.na.drop.toDF(), false) + (dataset.na.drop().toDF(), false) } else { - (dataset.toDF(), "keep" == gethandleInvalid) + (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID) } } @@ -140,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val supportedHandleInvalid: Array[String] = + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) + /** * We require splits to be of length >= 3 and to be in strictly increasing order. * No NaN split should be accepted. @@ -173,9 +181,13 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { splits: Array[Double], feature: Double, keepInvalid: Boolean): Double = { - if (feature.isNaN && keepInvalid) { - // NaN data point found plus "keep" NaN option is set - splits.length - 1 + if (feature.isNaN) { + if (keepInvalid) { + splits.length - 1 + } else { + throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," + + " try setting Bucketizer.handleInvalid.") + } } else if (feature == splits.last) { splits.length - 2 } else { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 5a90abba242f..b9e01dde70d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. + * + * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values. + * * default: 2 * @group param */ @@ -61,19 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getRelativeError: Double = getOrDefault(relativeError) + + /** + * Param for how to handle invalid entries. Options are skip (filter out rows with + * invalid values), error (throw an error), or keep (keep invalid values in a special additional + * bucket). + * Default: "error" + * @group param + */ + @Since("2.1.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + + "invalid entries. Options are skip (filter out rows with invalid values), " + + "error (throw an error), or keep (keep invalid values in a special additional bucket).", + ParamValidators.inArray(Bucketizer.supportedHandleInvalid)) + setDefault(handleInvalid, Bucketizer.ERROR_INVALID) + + /** @group getParam */ + @Since("2.1.0") + def getHandleInvalid: String = $(handleInvalid) + } /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The number of bins can be set using the `numBuckets` parameter. It is - * possible that the number of buckets used will be less than this value, for example, if there are - * too few distinct values of the input to create enough distinct quantiles. Note also that - * QuantileDiscretizer will raise an error when it finds NaN value in the dataset, but user can - * also choose to either keep or remove NaN values within the dataset by setting handleInvalid. - * If user chooses to keep NaN values, they will be handled specially and placed into their own + * possible that the number of buckets used will be smaller than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. + * + * NaN handling: Note also that + * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can + * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`. + * If the user chooses to keep NaN values, they will be handled specially and placed into their own * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3], * but NaNs will be counted in a special bucket[4]. - * The bin ranges are chosen using an approximate algorithm (see the documentation for + * + * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`, @@ -102,28 +127,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with - * invalid values), or error (which will throw an error), or keep (which will keep the invalid - * values in certain way). Default behaviour is to report an error for invalid entries. - * Default: "error" - * @group param - */ - @Since("2.1.0") - val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle" + - "invalid entries. Options are skip (which will filter out rows with invalid values), or" + - "error (which will throw an error), or keep (which will keep the invalid values" + - " in certain way). Default behaviour is to report an error for invalid entries.", - ParamValidators.inArray(Array("skip", "error", "keep"))) - - /** @group getParam */ - @Since("2.1.0") - def gethandleInvalid: String = $(handleInvalid) - /** @group setParam */ @Since("2.1.0") - def sethandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { @@ -151,7 +157,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui } val bucketizer = new Bucketizer(uid) .setSplits(distinctSplits.sorted) - .sethandleInvalid($(handleInvalid)) + .setHandleInvalid($(handleInvalid)) copyValues(bucketizer.setParent(this)) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index 5066238d06ce..aac29137d791 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -98,21 +98,33 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setInputCol("feature") .setOutputCol("result") .setSplits(splits) - .sethandleInvalid("keep") + bucketizer.setHandleInvalid("keep") bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { case Row(x: Double, y: Double) => assert(x === y, s"The feature value is not correct after bucketing. Expected $y but found $x") } + + bucketizer.setHandleInvalid("skip") + val skipResults: Array[Double] = bucketizer.transform(dataFrame) + .select("result").as[Double].collect() + assert(skipResults.length === 7) + assert(skipResults.forall(_ !== 4.0)) + + bucketizer.setHandleInvalid("error") + withClue("Bucketizer should throw error when setHandleInvalid=error and given NaN values") { + intercept[SparkException] { + bucketizer.transform(dataFrame).collect() + } + } } test("Bucket continuous features, with NaN splits") { val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) - withClue("Invalid NaN split was not caught as an invalid split!") { + withClue("Invalid NaN split was not caught during Bucketizer initialization") { intercept[IllegalArgumentException] { - val bucketizer: Bucketizer = new Bucketizer() - .setSplits(splits) + new Bucketizer().setSplits(splits) } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 7464bde5b3e5..f219f775b218 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql._ @@ -85,9 +85,16 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(numBuckets) + withClue("QuantileDiscretizer with handleInvalid=error should throw exception for NaN values") { + val dataFrame: DataFrame = validData.toSeq.toDF("input") + intercept[SparkException] { + discretizer.fit(dataFrame).transform(dataFrame).collect() + } + } + List(("keep", expectedKeep), ("skip", expectedSkip)).foreach{ case(u, v) => - discretizer.sethandleInvalid(u) + discretizer.setHandleInvalid(u) val dataFrame: DataFrame = validData.zip(v).toSeq.toDF("input", "expected") val result = discretizer.fit(dataFrame).transform(dataFrame) result.select("result", "expected").collect().foreach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 726773ed9365..1383208874a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -151,9 +151,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(math.abs(d2 - 2 * q2 * n) < error_double) } // test approxQuantile on NaN values - val dfNaN = Array(Double.NaN, 1.0, Double.NaN, Double.NaN).toSeq.toDF("input") - val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons(0)) - assert(resNaN.count(_.isNaN) == 0) + val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input") + val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head) + assert(resNaN.count(_.isNaN) === 0) } test("crosstab") {