From 5715092892218ae9b99ac173adf2c98452b58bc7 Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Fri, 4 Oct 2024 09:16:58 +0200 Subject: [PATCH 1/9] [SPARK-37178][ML] Add Target Encoding to ml.feature --- docs/ml-features.md | 40 ++ .../examples/ml/JavaTargetEncoderExample.java | 90 ++++ .../main/python/ml/target_encoder_example.py | 65 +++ .../examples/ml/TargetEncoderExample.scala | 71 +++ .../spark/ml/feature/TargetEncoder.scala | 409 +++++++++++++++++ .../ml/feature/JavaTargetEncoderSuite.java | 110 +++++ .../spark/ml/feature/TargetEncoderSuite.scala | 410 ++++++++++++++++++ python/docs/source/reference/pyspark.ml.rst | 2 + python/pyspark/ml/feature.py | 301 +++++++++++++ python/pyspark/ml/tests/test_feature.py | 75 ++++ 10 files changed, 1573 insertions(+) create mode 100644 examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java create mode 100644 examples/src/main/python/ml/target_encoder_example.py create mode 100644 examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala diff --git a/docs/ml-features.md b/docs/ml-features.md index 3dbb960dea03e..e34137e7e7628 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -855,6 +855,46 @@ for more details on the API. +## TargetEncoder + +Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first. + +`TargetEncoder` can transform multiple columns, returning a target-encoded output column for each input column. + +`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error). + +`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated. Available options include 'binary' (bin-counting) and 'continuous' (mean-encoding). + +`TargetEncoder` supports the `smoothing` parameter to tune how in-category stats and overall stats are weighted. + +**Examples** + +
+ +
+ +Refer to the [TargetEncoder Python docs](api/python/reference/api/pyspark.ml.feature.TargetEncoder.html) for more details on the API. + +{% include_example python/ml/target_encoder_example.py %} +
+ +
+ +Refer to the [TargetEncoder Scala docs](api/scala/org/apache/spark/ml/feature/TargetEncoder.html) for more details on the API. + +{% include_example scala/org/apache/spark/examples/ml/TargetEncoderExample.scala %} +
+ +
+ +Refer to the [TargetEncoder Java docs](api/java/org/apache/spark/ml/feature/TargetEncoder.html) +for more details on the API. + +{% include_example java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java %} +
+ +
+ ## VectorIndexer `VectorIndexer` helps index categorical features in datasets of `Vector`s. diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java new file mode 100644 index 0000000000000..da391bd469192 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import org.apache.spark.sql.SparkSession; + +// $example on$ +import org.apache.spark.ml.feature.TargetEncoder; +import org.apache.spark.ml.feature.TargetEncoderModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +import java.util.Arrays; +import java.util.List; +// $example off$ + +public class JavaTargetEncoderExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaTargetEncoderExample") + .getOrCreate(); + + // Note: categorical features are usually first encoded with StringIndexer + // $example on$ + List data = Arrays.asList( + RowFactory.create(0.0, 1.0, 0, 10.0), + RowFactory.create(1.0, 0.0, 1, 20.0), + RowFactory.create(2.0, 1.0, 0, 30.0), + RowFactory.create(0.0, 2.0, 1, 40.0), + RowFactory.create(0.0, 1.0, 0, 50.0), + RowFactory.create(2.0, 0.0, 1, 60.0) + ); + + StructType schema = new StructType(new StructField[]{ + new StructField("categoryIndex1", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("categoryIndex2", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("binaryLabel", DataTypes.DoubleType, false, Metadata.empty()), + new StructField("continuousLabel", DataTypes.DoubleType, false, Metadata.empty()) + }); + + Dataset df = spark.createDataFrame(data, schema); + + // binary target + TargetEncoder bin_encoder = new TargetEncoder() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) + .setLabelCol("binaryLabel") + .setTargetType("binary"); + + TargetEncoderModel bin_model = bin_encoder.fit(df); + Dataset bin_encoded = bin_model.transform(df); + bin_encoded.show(); + + // continuous target + TargetEncoder cont_encoder = new TargetEncoder() + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) + .setLabelCol("continuousLabel") + .setTargetType("continuous"); + + 
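+    // With targetType "continuous", each category in the input columns is mapped to
+    // the mean of continuousLabel observed for that category (mean encoding).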
TargetEncoderModel cont_model = cont_encoder.fit(df); + Dataset cont_encoded = cont_model.transform(df); + cont_encoded.show(); + // $example off$ + + spark.stop(); + } +} + diff --git a/examples/src/main/python/ml/target_encoder_example.py b/examples/src/main/python/ml/target_encoder_example.py new file mode 100644 index 0000000000000..f6c1010de71f3 --- /dev/null +++ b/examples/src/main/python/ml/target_encoder_example.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.feature import TargetEncoder + +# $example off$ +from pyspark.sql import SparkSession + +if __name__ == "__main__": + spark = SparkSession.builder.appName("TargetEncoderExample").getOrCreate() + + # Note: categorical features are usually first encoded with StringIndexer + # $example on$ + df = spark.createDataFrame( + [ + (0.0, 1.0, 0, 10.0), + (1.0, 0.0, 1, 20.0), + (2.0, 1.0, 0, 30.0), + (0.0, 2.0, 1, 40.0), + (0.0, 1.0, 0, 50.0), + (2.0, 0.0, 1, 60.0), + ], + ["categoryIndex1", "categoryIndex2", "binaryLabel", "continuousLabel"], + ) + + # binary target + encoder = TargetEncoder( + inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryIndex1Target", "categoryIndex2Target"], + labelCol="binaryLabel", + targetType="binary" + ) + model = encoder.fit(df) + encoded = model.transform(df) + encoded.show() + + # continuous target + encoder = TargetEncoder( + inputCols=["categoryIndex1", "categoryIndex2"], + outputCols=["categoryIndex1Target", "categoryIndex2Target"], + labelCol="continuousLabel", + targetType="continuous" + ) + + model = encoder.fit(df) + encoded = model.transform(df) + encoded.show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala new file mode 100644 index 0000000000000..a03f903c86d06 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TargetEncoderExample.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.feature.TargetEncoder +// $example off$ +import org.apache.spark.sql.SparkSession + +object TargetEncoderExample { + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder() + .appName("TargetEncoderExample") + .getOrCreate() + + // Note: categorical features are usually first encoded with StringIndexer + // $example on$ + val df = spark.createDataFrame(Seq( + (0.0, 1.0, 0, 10.0), + (1.0, 0.0, 1, 20.0), + (2.0, 1.0, 0, 30.0), + (0.0, 2.0, 1, 40.0), + (0.0, 1.0, 0, 50.0), + (2.0, 0.0, 1, 60.0) + )).toDF("categoryIndex1", "categoryIndex2", + "binaryLabel", "continuousLabel") + + // binary target + val bin_encoder = new TargetEncoder() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target")) + .setLabelCol("binaryLabel") + .setTargetType("binary"); + + val bin_model = bin_encoder.fit(df) + val bin_encoded = bin_model.transform(df) + bin_encoded.show() + + // continuous target + val cont_encoder = new TargetEncoder() + .setInputCols(Array("categoryIndex1", "categoryIndex2")) + .setOutputCols(Array("categoryIndex1Target", "categoryIndex2Target")) + .setLabelCol("continuousLabel") + .setTargetType("continuous"); + + val cont_model = cont_encoder.fit(df) + val cont_encoded = cont_model.transform(df) + cont_encoded.show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala new file mode 100644 index 0000000000000..36b071fd14cae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -0,0 +1,409 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.collection.immutable.ArraySeq
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.annotation.Since
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.attribute.NominalAttribute
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+
+/** Private trait for params and common methods for TargetEncoder and TargetEncoderModel */
+private[ml] trait TargetEncoderBase extends Params with HasLabelCol
+  with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols with HasHandleInvalid {
+
+  /**
+   * Param for how to handle invalid data during transform().
+   * Options are 'keep' (invalid data presented as an extra categorical feature) or
+   * 'error' (throw an error).
+   * Note that this Param is only used during transform; during fitting, invalid data
+   * will result in an error.
+   * Default: "error"
+   * @group param
+   */
+  @Since("4.0.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data during transform(). " +
+    "Options are 'keep' (invalid data presented as an extra categorical feature) " +
+    "or error (throw an error). Note that this Param is only used during transform; " +
+    "during fitting, invalid data will result in an error.",
+    ParamValidators.inArray(TargetEncoder.supportedHandleInvalids))
+
+  setDefault(handleInvalid -> TargetEncoder.ERROR_INVALID)
+
+  @Since("4.0.0")
+  val targetType: Param[String] = new Param[String](this, "targetType",
+    "Type of label used to fit encodings. " +
+    "Options are 'binary' (categories encoded as the conditional probability of the " +
+    "label given that category, i.e. bin-counting) or 'continuous' (categories encoded " +
+    "as the average of the label given that category, i.e. mean-encoding).",
+    ParamValidators.inArray(TargetEncoder.supportedTargetTypes))
+
+  setDefault(targetType -> TargetEncoder.TARGET_BINARY)
+
+  final def getTargetType: String = $(targetType)
+
+  @Since("4.0.0")
+  val smoothing: DoubleParam = new DoubleParam(this, "smoothing",
+    "smoothing factor for blending in-category stats with overall stats; must be non-negative",
+    ParamValidators.gtEq(0.0))
+
+  setDefault(smoothing -> 0.0)
+
+  final def getSmoothing: Double = $(smoothing)
+
+  private[feature] lazy val inputFeatures = if (isSet(inputCol)) Array($(inputCol))
+    else if (isSet(inputCols)) $(inputCols)
+    else Array.empty[String]
+
+  private[feature] lazy val outputFeatures = if (isSet(outputCol)) Array($(outputCol))
+    else if (isSet(outputCols)) $(outputCols)
+    else inputFeatures.map{field: String => s"${field}_indexed"}
+
+  private[feature] def validateSchema(schema: StructType,
+                                      fitting: Boolean): StructType = {
+
+    require(inputFeatures.length > 0,
+      s"At least one input column must be specified.")
+
+    require(inputFeatures.length == outputFeatures.length,
+      s"The number of input columns ${inputFeatures.length} must be the same as the number of " +
+      s"output columns ${outputFeatures.length}.")
+
+    val features = if (fitting) inputFeatures :+ $(labelCol)
+      else inputFeatures
+
+    features.foreach {
+      feature => {
+        try {
+          val field = schema(feature)
+          if (field.dataType != DoubleType) {
+            throw new SparkException(s"Data type for column ${feature} is ${field.dataType}" +
+              s", but ${DoubleType.typeName} is required.")
+          }
+        } catch {
+          case e: IllegalArgumentException =>
+            throw new SparkException(s"No column named ${feature} found on dataset.")
+        }
+      }
+    }
+    schema
+  }
+
+}
+
+/**
+ * Target Encoding maps a column of categorical indices into a numerical feature derived
+ * from the target.
+ *
+ * When `handleInvalid` is configured to 'keep', previously unseen values of a feature
+ * are mapped to the dataset overall statistics.
+ *
+ * When 'targetType' is configured to 'binary', categories are encoded as the conditional
+ * probability of the target given that category (bin counting).
+ * When 'targetType' is configured to 'continuous', categories are encoded as the average
+ * of the target given that category (mean encoding).
+ *
+ * Parameter 'smoothing' controls how in-category stats and overall stats are weighted.
+ *
+ * @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
+ *       come in pairs, specified by the order in the arrays, and each pair is treated independently.
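+ *
+ * @note Statistics for previously unseen categories are stored in the fitted encodings
+ *       under index -1.0; these are the values applied when `handleInvalid` is 'keep'.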
+ * + * @see `StringIndexer` for converting categorical values into category indices + */ +@Since("4.0.0") +class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) + extends Estimator[TargetEncoderModel] with TargetEncoderBase with DefaultParamsWritable { + + @Since("4.0.0") + def this() = this(Identifiable.randomUID("TargetEncoder")) + + /** @group setParam */ + @Since("4.0.0") + def setLabelCol(value: String): this.type = set(labelCol, value) + + /** @group setParam */ + @Since("4.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("4.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + @Since("4.0.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("4.0.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("4.0.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** @group setParam */ + @Since("4.0.0") + def setTargetType(value: String): this.type = set(targetType, value) + + /** @group setParam */ + @Since("4.0.0") + def setSmoothing(value: Double): this.type = set(smoothing, value) + + @Since("4.0.0") + override def transformSchema(schema: StructType): StructType = { + validateSchema(schema, fitting = true) + } + + @Since("4.0.0") + override def fit(dataset: Dataset[_]): TargetEncoderModel = { + validateSchema(dataset.schema, fitting = true) + + val stats = dataset + .select(ArraySeq.unsafeWrapArray( + (inputFeatures :+ $(labelCol)).map(col)): _*) + .rdd + .treeAggregate( + Array.fill(inputFeatures.length) { + Map.empty[Double, (Double, Double)] + })( + (agg, row: Row) => { + val label = row.getDouble(inputFeatures.length) + Range(0, inputFeatures.length).map { + feature => try { + val value = row.getDouble(feature) + if (value < 0.0 || value != value.toInt) throw new SparkException( + s"Values from column ${inputFeatures(feature)} must be indices, but got $value.") + val counter = agg(feature).getOrElse(value, (0.0, 0.0)) + val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0)) + $(targetType) match { + case TargetEncoder.TARGET_BINARY => + if (label == 1.0) agg(feature) + + (value -> (1 + counter._1, 1 + counter._2)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, 1 + globalCounter._2)) + else if (label == 0.0) agg(feature) + + (value -> (1 + counter._1, counter._2)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, globalCounter._2)) + else throw new SparkException( + s"Values from column ${getLabelCol} must be binary (0,1) but got $label.") + case TargetEncoder.TARGET_CONTINUOUS => agg(feature) + + (value -> (1 + counter._1, + counter._2 + ((label - counter._2) / (1 + counter._1)))) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, + globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1)))) + } + } catch { + case e: SparkRuntimeException => + if (e.getErrorClass == "ROW_VALUE_IS_NULL") { + throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." 
+                      s" See Imputer estimator for completing missing values.")
+                  } else throw e
+              }
+            }.toArray
+          },
+          (agg1, agg2) => Range(0, inputFeatures.length)
+            .map {
+              feature => {
+                val values = agg1(feature).keySet ++ agg2(feature).keySet
+                values.map(value =>
+                  value -> {
+                    val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0))
+                    val stat2 = agg2(feature).getOrElse(value, (0.0, 0.0))
+                    $(targetType) match {
+                      case TargetEncoder.TARGET_BINARY => (stat1._1 + stat2._1, stat1._2 + stat2._2)
+                      case TargetEncoder.TARGET_CONTINUOUS => (stat1._1 + stat2._1,
+                        ((stat1._1 * stat1._2) + (stat2._1 * stat2._2)) / (stat1._1 + stat2._1))
+                    }
+                  }).toMap
+              }
+            }.toArray)
+
+    val encodings: Map[String, Map[Double, Double]] = stats.zipWithIndex.map {
+      case (stat, idx) =>
+        val global = stat.get(TargetEncoder.UNSEEN_CATEGORY).get
+        inputFeatures(idx) -> stat.map {
+          case (feature, value) => feature -> {
+            val weight = value._1 / (value._1 + $(smoothing))
+            $(targetType) match {
+              case TargetEncoder.TARGET_BINARY =>
+                weight * (value._2 / value._1) + (1 - weight) * (global._2 / global._1)
+              case TargetEncoder.TARGET_CONTINUOUS =>
+                weight * value._2 + (1 - weight) * global._2
+            }
+          }
+        }
+    }.toMap
+
+    val model = new TargetEncoderModel(uid, encodings).setParent(this)
+    copyValues(model)
+  }
+
+  @Since("4.0.0")
+  override def copy(extra: ParamMap): TargetEncoder = defaultCopy(extra)
+}
+
+@Since("4.0.0")
+object TargetEncoder extends DefaultParamsReadable[TargetEncoder] {
+
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
+
+  private[feature] val TARGET_BINARY: String = "binary"
+  private[feature] val TARGET_CONTINUOUS: String = "continuous"
+  private[feature] val supportedTargetTypes: Array[String] = Array(TARGET_BINARY, TARGET_CONTINUOUS)
+
+  private[feature] val UNSEEN_CATEGORY: Double = -1
+
+  @Since("4.0.0")
+  override def load(path: String): TargetEncoder = super.load(path)
+}
+
+/**
+ * @param encodings  Fitted encodings, keyed by input column name. Each value maps a category
+ *                   index of that column to its encoded value; the overall statistic used for
+ *                   unseen categories is stored under index -1.0.
+ */
+@Since("4.0.0")
+class TargetEncoderModel private[ml] (
+    @Since("4.0.0") override val uid: String,
+    @Since("4.0.0") val encodings: Map[String, Map[Double, Double]])
+  extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable {
+
+  @Since("4.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    inputFeatures.zip(outputFeatures)
+      .foldLeft(validateSchema(schema, fitting = false)) {
+        case (newSchema, fieldName) =>
+          val field = schema(fieldName._1)
+          newSchema.add(StructField(fieldName._2, field.dataType, field.nullable))
+      }
+  }
+
+  @Since("4.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    validateSchema(dataset.schema, fitting = false)
+
+    val apply_encodings: Map[Double, Double] => (Column => Column) =
+      (mappings: Map[Double, Double]) => {
+        (col: Column) => {
+          val first :: rest = mappings.toList.sortWith{
+            (a, b) => (b._1 == TargetEncoder.UNSEEN_CATEGORY) ||
+              ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1))
+          }
+          rest
+            .foldLeft(when(col === first._1, first._2))(
+              (new_col: Column, encoding) =>
+                if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) {
+                  new_col.when(col === encoding._1, encoding._2)
+                } else {
+                  new_col.otherwise(
+                    if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) encoding._2
+                    else raise_error(concat(
+                      lit("Unseen value "), col,
+                      lit(s" in feature ${col.toString}. To handle unseen values, " +
+                        s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}."))))
+                })
+        }
+      }
+
+    dataset.withColumns(
+      inputFeatures.zip(outputFeatures).map {
+        feature =>
+          feature._2 -> (encodings.get(feature._1) match {
+            case Some(dict: Map[Double, Double]) =>
+              apply_encodings(dict)(col(feature._1))
+                .as(feature._2, NominalAttribute.defaultAttr
+                  .withName(feature._2)
+                  .withNumValues(dict.size)
+                  .withValues(dict.values.toSet.toArray.map(_.toString)).toMetadata())
+            case None =>
+              throw new SparkException(s"No encodings found for ${feature._1}.")
+          })
+      }.toMap)
+  }
+
+  @Since("4.0.0")
+  override def copy(extra: ParamMap): TargetEncoderModel = {
+    val copied = new TargetEncoderModel(uid, encodings)
+    copyValues(copied, extra).setParent(parent)
+  }
+
+  @Since("4.0.0")
+  override def write: MLWriter = new TargetEncoderModel.TargetEncoderModelWriter(this)
+
+  @Since("4.0.0")
+  override def toString: String = {
+    s"TargetEncoderModel: uid=$uid, " +
+      s"handleInvalid=${$(handleInvalid)}, targetType=${$(targetType)}, " +
+      s"numInputCols=${inputFeatures.length}, numOutputCols=${outputFeatures.length}, " +
+      s"smoothing=${$(smoothing)}"
+  }
+
+}
+
+@Since("4.0.0")
+object TargetEncoderModel extends MLReadable[TargetEncoderModel] {
+
+  private[TargetEncoderModel]
+  class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter {
+
+    private case class Data(encodings: Map[String, Map[Double, Double]])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sparkSession)
+      val data = Data(instance.encodings)
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath)
+    }
+  }
+
+  private class TargetEncoderModelReader extends MLReader[TargetEncoderModel] {
+
+    private val className = classOf[TargetEncoderModel].getName
+
+    override def load(path: String): TargetEncoderModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath)
+        
.select("encodings") + .head() + val encodings = data.getAs[Map[String, Map[Double, Double]]](0) + val model = new TargetEncoderModel(metadata.uid, encodings) + metadata.getAndSetParams(model) + model + } + } + + @Since("4.0.0") + override def read: MLReader[TargetEncoderModel] = new TargetEncoderModelReader + + @Since("4.0.0") + override def load(path: String): TargetEncoderModel = super.load(path) +} + diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java new file mode 100644 index 0000000000000..998bdb18016d7 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature; + +import org.apache.spark.SharedSparkSession; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.apache.spark.sql.types.DataTypes.*; + +public class JavaTargetEncoderSuite extends SharedSparkSession { + + @Test + public void testTargetEncoderBinary() { + + List data = Arrays.asList( + RowFactory.create(0.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), + RowFactory.create(1.0, 4.0, 5.0, 1.0, 2.0/3, 1.0, 1.0/3), + RowFactory.create(2.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), + RowFactory.create(0.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), + RowFactory.create(1.0, 3.0, 6.0, 0.0, 2.0/3, 0.0, 2.0/3), + RowFactory.create(2.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), + RowFactory.create(0.0, 3.0, 7.0, 0.0, 1.0/3, 0.0, 0.0), + RowFactory.create(1.0, 4.0, 8.0, 1.0, 2.0/3, 1.0, 1.0), + RowFactory.create(2.0, 3.0, 9.0, 0.0, 1.0/3, 0.0, 0.0)); + StructType schema = createStructType(new StructField[]{ + createStructField("input1", DoubleType, false), + createStructField("input2", DoubleType, false), + createStructField("input3", DoubleType, false), + createStructField("label", DoubleType, false), + createStructField("expected1", DoubleType, false), + createStructField("expected2", DoubleType, false), + createStructField("expected3", DoubleType, false) + }); + Dataset dataset = spark.createDataFrame(data, schema); + + TargetEncoder encoder = new TargetEncoder() + .setInputCols(new String[]{"input1", "input2", "input3"}) + .setOutputCols(new String[]{"output1", "output2", "output3"}) + .setTargetType("binary"); + TargetEncoderModel model = encoder.fit(dataset); + Dataset output = model.transform(dataset); + + Assertions.assertEquals( + output.select("output1", 
"output2", "output3").collectAsList(), + output.select("expected1", "expected2", "expected3").collectAsList()); + + } + + @Test + public void testTargetEncoderContinuous() { + + List data = Arrays.asList( + RowFactory.create(0.0, 3.0, 5.0, 10.0, 40.0, 50.0, 20.0), + RowFactory.create(1.0, 4.0, 5.0, 20.0, 50.0, 50.0, 20.0), + RowFactory.create(2.0, 3.0, 5.0, 30.0, 60.0, 50.0, 20.0), + RowFactory.create(0.0, 4.0, 6.0, 40.0, 40.0, 50.0, 50.0), + RowFactory.create(1.0, 3.0, 6.0, 50.0, 50.0, 50.0, 50.0), + RowFactory.create(2.0, 4.0, 6.0, 60.0, 60.0, 50.0, 50.0), + RowFactory.create(0.0, 3.0, 7.0, 70.0, 40.0, 50.0, 70.0), + RowFactory.create(1.0, 4.0, 8.0, 80.0, 50.0, 50.0, 80.0), + RowFactory.create(2.0, 3.0, 9.0, 90.0, 60.0, 50.0, 90.0)); + StructType schema = createStructType(new StructField[]{ + createStructField("input1", DoubleType, false), + createStructField("input2", DoubleType, false), + createStructField("input3", DoubleType, false), + createStructField("label", DoubleType, false), + createStructField("expected1", DoubleType, false), + createStructField("expected2", DoubleType, false), + createStructField("expected3", DoubleType, false) + }); + Dataset dataset = spark.createDataFrame(data, schema); + + TargetEncoder encoder = new TargetEncoder() + .setInputCols(new String[]{"input1", "input2", "input3"}) + .setOutputCols(new String[]{"output1", "output2", "output3"}) + .setTargetType("continuous"); + TargetEncoderModel model = encoder.fit(dataset); + Dataset output = model.transform(dataset); + + Assertions.assertEquals( + output.select("output1", "output2", "output3").collectAsList(), + output.select("expected1", "expected2", "expected3").collectAsList()); + + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala new file mode 100644 index 0000000000000..2b434268db3b6 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -0,0 +1,410 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.collection.immutable.HashMap + +import org.apache.spark.{SparkException, SparkRuntimeException} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} +import org.apache.spark.sql.Row +import org.apache.spark.sql.types._ + +class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { + + import testImplicits._ + + @transient var data: Seq[Row] = _ + @transient var schema: StructType = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + // scalastyle:off + data = Seq( + Row(0.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + Row(1.0, 4.0, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + Row(2.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + Row(0.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + Row(1.0, 3.0, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + Row(2.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + Row(0.0, 3.0, 7.0, 0.0, 1.0/3, 0.0, 0.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + Row(1.0, 4.0, 8.0, 1.0, 2.0/3, 1.0, 1.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + Row(2.0, 3.0, 9.0, 0.0, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) + // scalastyle:on + + schema = StructType(Array( + StructField("input1", DoubleType), + StructField("input2", DoubleType), + StructField("input3", DoubleType), + StructField("binaryLabel", DoubleType), + StructField("binaryExpected1", DoubleType), + StructField("binaryExpected2", DoubleType), + StructField("binaryExpected3", DoubleType), + StructField("continuousLabel", DoubleType), + StructField("continuousExpected1", DoubleType), + StructField("continuousExpected2", DoubleType), + StructField("continuousExpected3", DoubleType), + StructField("smoothingExpected1", DoubleType), + StructField("smoothingExpected2", DoubleType), + StructField("smoothingExpected3", DoubleType))) + } + + test("params") { + ParamsSuite.checkParams(new TargetEncoder) + } + + test("TargetEncoder - binary target") { + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> Map(0.0 -> 1.0/3, 1.0 -> 2.0/3, 2.0 -> 1.0/3, -1.0 -> 4.0/9), + "input2" -> Map(3.0 -> 0.0, 4.0 -> 1.0, -1.0 -> 4.0/9), + "input3" -> + HashMap(5.0 -> 1.0/3, 6.0 -> 2.0/3, 7.0 -> 0.0, 8.0 -> 1.0, 9.0 -> 0.0, -1.0 -> 4.0/9)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "binaryExpected1", "binaryExpected2", "binaryExpected3"), + model, + "output1", "binaryExpected1", + "output2", "binaryExpected2", + "output3", "binaryExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - continuous target") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + 
.setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> Map(0.0 -> 40.0, 1.0 -> 50.0, 2.0 -> 60.0, -1.0 -> 50.0), + "input2" -> Map(3.0 -> 50.0, 4.0 -> 50.0, -1.0 -> 50.0), + "input3" -> + HashMap(5.0 -> 20.0, 6.0 -> 50.0, 7.0 -> 70.0, 8.0 -> 80.0, 9.0 -> 90.0, -1.0 -> 50.0)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "continuousExpected1", "continuousExpected2", "continuousExpected3"), + model, + "output1", "continuousExpected1", + "output2", "continuousExpected2", + "output3", "continuousExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - smoothing") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + .setSmoothing(1) + + val model = encoder.fit(df) + + val expected_encodings = Map( + "input1" -> Map(0.0 -> 42.5, 1.0 -> 50.0, 2.0 -> 57.5, -1.0 -> 50.0), + "input2" -> Map(3.0 -> 50.0, 4.0 -> 50.0, -1.0 -> 50.0), + "input3" -> + HashMap(5.0 -> 27.5, 6.0 -> 50.0, 7.0 -> 60.0, 8.0 -> 65.0, 9.0 -> 70.0, -1.0 -> 50.0)) + + assert(model.encodings.equals(expected_encodings)) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "smoothingExpected1", "smoothingExpected2", "smoothingExpected3"), + model, + "output1", "smoothingExpected1", + "output2", "smoothingExpected2", + "output3", "smoothingExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } + + test("TargetEncoder - unseen value - keep") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setHandleInvalid(TargetEncoder.KEEP_INVALID) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val data_unseen = Row(0.0, 3.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) + + val df_unseen = spark + .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df_unseen.select("input1", "input2", "input3", + "continuousExpected1", "continuousExpected2", "continuousExpected3"), + model, + "output1", "continuousExpected1", + "output2", "continuousExpected2", + "output3", "continuousExpected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + } + + test("TargetEncoder - unseen value - error") { + + val df = spark + 
.createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setHandleInvalid(TargetEncoder.ERROR_INVALID) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val data_unseen = Row(0.0, 3.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) + + val df_unseen = spark + .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + + val ex = intercept[SparkRuntimeException] { + val out = model.transform(df_unseen) + out.show(false) + } + + assert(ex.isInstanceOf[SparkRuntimeException]) + assert(ex.getMessage.contains("Unseen value 10.0 in feature input3")) + + } + + test("TargetEncoder - missing feature") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + .drop("continuousLabel") + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setInputCols(Array("input1", "input2", "input3")) + .setTargetType(TargetEncoder.TARGET_BINARY) + .setOutputCols(Array("output1", "output2", "output3")) + + val ex = intercept[SparkException] { + val model = encoder.fit(df.drop("input3")) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains("No column named input3 found on dataset")) + } + + test("TargetEncoder - wrong data type") { + + val wrong_schema = new StructType( + schema.map{ + field: StructField => if (field.name != "input3") field + else new StructField(field.name, StringType, field.nullable, field.metadata) + }.toArray) + + val df = spark + .createDataFrame(sc.parallelize(data), wrong_schema) + .drop("continuousLabel") + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setInputCols(Array("input1", "input2", "input3")) + .setTargetType(TargetEncoder.TARGET_BINARY) + .setOutputCols(Array("output1", "output2", "output3")) + + val ex = intercept[SparkException] { + val model = encoder.fit(df) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains("Data type for column input3 is StringType")) + } + + test("TargetEncoder - null value") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setHandleInvalid(TargetEncoder.ERROR_INVALID) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + + val data_null = Row(0.0, 3.0, null.asInstanceOf[Integer], + 0.0, 0.0, 0.0, 0.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) + + val df_null = spark + .createDataFrame(sc.parallelize(data :+ data_null), schema) + + val ex = intercept[SparkException] { + val model = encoder.fit(df_null) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains("Null value found in feature input3")) + + } + + test("TargetEncoder - non-indexed categories") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val data_noindex = Row(0.0, 3.0, 5.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + + val df_noindex = spark + .createDataFrame(sc.parallelize(data :+ data_noindex), 
schema) + + val ex = intercept[SparkException] { + val model = encoder.fit(df_noindex) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains( + "Values from column input3 must be indices, but got 5.1")) + + } + + test("TargetEncoder - non-binary labels") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val data_non_binary = Row(0.0, 3.0, 5.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + + val df_non_binary = spark + .createDataFrame(sc.parallelize(data :+ data_non_binary), schema) + + val ex = intercept[SparkException] { + val model = encoder.fit(df_non_binary) + print(model.encodings) + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains( + "Values from column binaryLabel must be binary (0,1) but got 0.1")) + + } + + test("TargetEncoder - R/W single-column") { + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCol("input1") + .setOutputCol("output1") + .setHandleInvalid(TargetEncoder.ERROR_INVALID) + .setSmoothing(2) + + testDefaultReadWrite(encoder) + + } + + test("TargetEncoder - R/W multi-column") { + + val encoder = new TargetEncoder() + .setLabelCol("binaryLabel") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + .setHandleInvalid(TargetEncoder.KEEP_INVALID) + .setSmoothing(1) + + testDefaultReadWrite(encoder) + + } + +} \ No newline at end of file diff --git a/python/docs/source/reference/pyspark.ml.rst b/python/docs/source/reference/pyspark.ml.rst index 965cbe7eb5a57..f81498d3b5eae 100644 --- a/python/docs/source/reference/pyspark.ml.rst +++ b/python/docs/source/reference/pyspark.ml.rst @@ -104,6 +104,8 @@ Feature StopWordsRemover StringIndexer StringIndexerModel + TargetEncoder + TargetEncoderModel Tokenizer UnivariateFeatureSelector UnivariateFeatureSelectorModel diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 9a392c9dd420f..98fc6dc690880 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -104,6 +104,8 @@ "StopWordsRemover", "StringIndexer", "StringIndexerModel", + "TargetEncoder", + "TargetEncoderModel", "Tokenizer", "UnivariateFeatureSelector", "UnivariateFeatureSelectorModel", @@ -5200,6 +5202,305 @@ def loadDefaultStopWords(language: str) -> List[str]: return list(stopWordsObj.loadDefaultStopWords(language)) +class _TargetEncoderParams( + HasLabelCol, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols, HasHandleInvalid +): + """ + Params for :py:class:`TargetEncoder` and :py:class:`TargetEncoderModel`. + + .. versionadded:: 4.0.0 + """ + + handleInvalid: Param[str] = Param( + Params._dummy(), + "handleInvalid", + "How to handle invalid data during transform(). 
" + + "Options are 'keep' (invalid data presented as an extra " + + "categorical feature) or error (throw an error).", + typeConverter=TypeConverters.toString, + ) + + targetType: Param[str] = Param( + Params._dummy(), + "targetType", + "whether the label is 'binary' or 'continuous'", + typeConverter=TypeConverters.toString, + ) + + smoothing: Param[float] = Param( + Params._dummy(), + "smoothing", + "value to smooth in-category averages with overall averages.", + typeConverter=TypeConverters.toFloat, + ) + + def __init__(self, *args: Any): + super(_TargetEncoderParams, self).__init__(*args) + self._setDefault(handleInvalid="error", targetType="binary", smoothing=0.0) + + @since("4.0.0") + def getTargetType(self) -> str: + """ + Gets the value of targetType or its default value. + """ + return self.getOrDefault(self.targetType) + + @since("4.0.0") + def getSmoothing(self) -> float: + """ + Gets the value of smoothing or its default value. + """ + return self.getOrDefault(self.smoothing) + + +@inherit_doc +class TargetEncoder( + JavaEstimator["TargetEncoderModel"], + _TargetEncoderParams, + JavaMLReadable["TargetEncoder"], + JavaMLWritable, +): + """ + Target Encoding maps a column of categorical indices into a numerical feature derived + from the target. + + When :py:attr:`handleInvalid` is configured to 'keep', previously unseen values of + a feature are mapped to the dataset overall statistics. + + When :py:attr:'targetType' is configured to 'binary', categories are encoded as the + conditional probability of the target given that category (bin counting). + When :py:attr:'targetType' is configured to 'continuous', categories are encoded as + the average of the target given that category (mean encoding) + + Parameter :py:attr:'smoothing' controls how in-category stats and overall stats are + weighted to build the encodings + + @note When encoding multi-column by using `inputCols` and `outputCols` params, + input/output cols come in pairs, specified by the order in the arrays, and each pair + is treated independently. + + .. versionadded:: 4.0.0 + """ + + _input_kwargs: Dict[str, Any] + + @overload + def __init__( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + ): + ... + + @overload + def __init__( + self, + *, + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + inputCol: Optional[str] = ..., + outputCol: Optional[str] = ..., + ): + ... + + @keyword_only + def __init__( + self, + *, + inputCols: Optional[List[str]] = None, + outputCols: Optional[List[str]] = None, + labelCol: str = "label", + handleInvalid: str = "error", + targetType: str = "binary", + smoothing: float = 0.0, + inputCol: Optional[str] = None, + outputCol: Optional[str] = None, + ): + """ + __init__(self, \\*, inputCols=None, outputCols=None, handleInvalid="error", dropLast=True, \ + targetType="binary", smoothing=0.0, inputCol=None, outputCol=None) + """ + super(TargetEncoder, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.TargetEncoder", self.uid) + kwargs = self._input_kwargs + self.setParams(**kwargs) + + @overload + def setParams( + self, + *, + inputCols: Optional[List[str]] = ..., + outputCols: Optional[List[str]] = ..., + labelCol: str = ..., + handleInvalid: str = ..., + targetType: str = ..., + smoothing: float = ..., + ) -> "TargetEncoder": + ... 
+
+    @overload
+    def setParams(
+        self,
+        *,
+        labelCol: str = ...,
+        handleInvalid: str = ...,
+        targetType: str = ...,
+        smoothing: float = ...,
+        inputCol: Optional[str] = ...,
+        outputCol: Optional[str] = ...,
+    ) -> "TargetEncoder":
+        ...
+
+    @keyword_only
+    @since("4.0.0")
+    def setParams(
+        self,
+        *,
+        inputCols: Optional[List[str]] = None,
+        outputCols: Optional[List[str]] = None,
+        labelCol: str = "label",
+        handleInvalid: str = "error",
+        targetType: str = "binary",
+        smoothing: float = 0.0,
+        inputCol: Optional[str] = None,
+        outputCol: Optional[str] = None,
+    ) -> "TargetEncoder":
+        """
+        setParams(self, \\*, inputCols=None, outputCols=None, labelCol="label", \
+                  handleInvalid="error", targetType="binary", smoothing=0.0, \
+                  inputCol=None, outputCol=None)
+        Sets params for this TargetEncoder.
+        """
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    @since("4.0.0")
+    def setInputCols(self, value: List[str]) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        return self._set(inputCols=value)
+
+    @since("4.0.0")
+    def setOutputCols(self, value: List[str]) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
+    @since("4.0.0")
+    def setInputCol(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`inputCol`.
+        """
+        return self._set(inputCol=value)
+
+    @since("4.0.0")
+    def setOutputCol(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        return self._set(outputCol=value)
+
+    @since("4.0.0")
+    def setHandleInvalid(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("4.0.0")
+    def setTargetType(self, value: str) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`targetType`.
+        """
+        return self._set(targetType=value)
+
+    @since("4.0.0")
+    def setSmoothing(self, value: float) -> "TargetEncoder":
+        """
+        Sets the value of :py:attr:`smoothing`.
+        """
+        return self._set(smoothing=value)
+
+    def _create_model(self, java_model: "JavaObject") -> "TargetEncoderModel":
+        return TargetEncoderModel(java_model)
+
+
+class TargetEncoderModel(
+    JavaModel, _TargetEncoderParams, JavaMLReadable["TargetEncoderModel"], JavaMLWritable
+):
+    """
+    Model fitted by :py:class:`TargetEncoder`.
+
+    .. versionadded:: 4.0.0
+    """
+
+    @since("4.0.0")
+    def setInputCols(self, value: List[str]) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`inputCols`.
+        """
+        return self._set(inputCols=value)
+
+    @since("4.0.0")
+    def setOutputCols(self, value: List[str]) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`outputCols`.
+        """
+        return self._set(outputCols=value)
+
+    @since("4.0.0")
+    def setInputCol(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`inputCol`.
+        """
+        return self._set(inputCol=value)
+
+    @since("4.0.0")
+    def setOutputCol(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`outputCol`.
+        """
+        return self._set(outputCol=value)
+
+    @since("4.0.0")
+    def setHandleInvalid(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("4.0.0")
+    def setTargetType(self, value: str) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`targetType`.
+        """
+        return self._set(targetType=value)
+
+    @since("4.0.0")
+    def setSmoothing(self, value: float) -> "TargetEncoderModel":
+        """
+        Sets the value of :py:attr:`smoothing`.
+        """
+        return self._set(smoothing=value)
+
+    @property
+    @since("4.0.0")
+    def encodings(self) -> dict[str, dict[float, float]]:
+        """
+        Fitted mappings for each feature being encoded.
+        The dictionary contains one category-to-encoding dictionary per input column.
+        """
+        return self._call_java("encodings")
+
+
 @inherit_doc
 class Tokenizer(
     JavaTransformer,
diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py
index 4bf6641723da6..6dbb7f06b33ff 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -29,6 +29,7 @@
     StopWordsRemover,
     StringIndexer,
     StringIndexerModel,
+    TargetEncoder,
     VectorSizeHint,
 )
 from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
@@ -346,6 +347,80 @@ def test_string_indexer_from_labels(self):
         )
         self.assertEqual(len(transformed_list), 5)
 
+    def test_target_encoder_binary(self):
+        df = self.spark.createDataFrame(
+            [
+                (0.0, 3.0, 5.0, 0.0),
+                (1.0, 4.0, 5.0, 1.0),
+                (2.0, 3.0, 5.0, 0.0),
+                (0.0, 4.0, 6.0, 1.0),
+                (1.0, 3.0, 6.0, 0.0),
+                (2.0, 4.0, 6.0, 1.0),
+                (0.0, 3.0, 7.0, 0.0),
+                (1.0, 4.0, 8.0, 1.0),
+                (2.0, 3.0, 9.0, 0.0),
+            ],
+            ["input1", "input2", "input3", "label"],
+        )
+        encoder = TargetEncoder(
+            inputCols=["input1", "input2", "input3"],
+            outputCols=["output1", "output2", "output3"],
+            labelCol="label",
+            targetType="binary",
+        )
+        model = encoder.fit(df)
+        te = model.transform(df)
+        actual = te.drop("label").collect()
+        expected = [
+            Row(input1=0.0, input2=3.0, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+            Row(input1=1.0, input2=4.0, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3),
+            Row(input1=2.0, input2=3.0, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3),
+            Row(input1=0.0, input2=4.0, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+            Row(input1=1.0, input2=3.0, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3),
+            Row(input1=2.0, input2=4.0, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3),
+            Row(input1=0.0, input2=3.0, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+            Row(input1=1.0, input2=4.0, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0),
+            Row(input1=2.0, input2=3.0, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0),
+        ]
+        self.assertEqual(actual, expected)
+
+    def test_target_encoder_continuous(self):
+        df = self.spark.createDataFrame(
+            [
+                (0.0, 3.0, 5.0, 10.0),
+                (1.0, 4.0, 5.0, 20.0),
+                (2.0, 3.0, 5.0, 30.0),
+                (0.0, 4.0, 6.0, 40.0),
+                (1.0, 3.0, 6.0, 50.0),
+                (2.0, 4.0, 6.0, 60.0),
+                (0.0, 3.0, 7.0, 70.0),
+                (1.0, 4.0, 8.0, 80.0),
+                (2.0, 3.0, 9.0, 90.0),
+            ],
+            ["input1", "input2", "input3", "label"],
+        )
+        encoder = TargetEncoder(
+            inputCols=["input1", "input2", "input3"],
+            outputCols=["output1", "output2", "output3"],
+            labelCol="label",
+            targetType="continuous",
+        )
+        model = encoder.fit(df)
+        te = model.transform(df)
+        actual = te.drop("label").collect()
+        expected = [
+            Row(input1=0.0, input2=3.0, input3=5.0, output1=40.0, output2=50.0, output3=20.0),
+            Row(input1=1.0, input2=4.0, input3=5.0, output1=50.0, output2=50.0, output3=20.0),
+            Row(input1=2.0, input2=3.0, input3=5.0, output1=60.0, output2=50.0, output3=20.0),
+            Row(input1=0.0, input2=4.0, input3=6.0, output1=40.0, output2=50.0, output3=50.0),
+            Row(input1=1.0, input2=3.0, input3=6.0, output1=50.0, output2=50.0, output3=50.0),
+            Row(input1=2.0, input2=4.0, input3=6.0, output1=60.0, output2=50.0, output3=50.0),
+            Row(input1=0.0, input2=3.0, input3=7.0, output1=40.0, output2=50.0, output3=70.0),
+            Row(input1=1.0, input2=4.0, input3=8.0, output1=50.0, output2=50.0, output3=80.0),
+            Row(input1=2.0, input2=3.0, input3=9.0, output1=60.0, output2=50.0, output3=90.0),
+        ]
+        self.assertEqual(actual, expected)
+
     def test_vector_size_hint(self):
         df = self.spark.createDataFrame(
             [

From 02642646ed2b9211f3ec615da176bf2021ef7aab Mon Sep 17 00:00:00 2001
From: Enrique Rebollo
Date: Tue, 8 Oct 2024 23:53:25 +0200
Subject: [PATCH 2/9] [SPARK-37178][ML] handle null category, support all
 numeric types, improved doc

---
 docs/ml-features.md                           |  20 +-
 .../spark/ml/feature/TargetEncoder.scala      | 227 ++++++++++--------
 .../ml/feature/JavaTargetEncoderSuite.java    |  48 ++--
 .../spark/ml/feature/TargetEncoderSuite.scala | 171 +++++++------
 python/pyspark/ml/tests/test_feature.py       |  76 +++---
 5 files changed, 309 insertions(+), 233 deletions(-)

diff --git a/docs/ml-features.md b/docs/ml-features.md
index e34137e7e7628..1fb0e03f8f435 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -857,15 +857,27 @@ for more details on the API.
 
 ## TargetEncoder
 
-Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
+Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
+Leveraging the relationship between categorical features and the target variable, target encoding usually performs better than one-hot encoding (while avoiding the need to add extra columns).
 
-`TargetEncoder` can transform multiple columns, returning a target-encoded output column for each input column.
+`TargetEncoder` can transform multiple columns, returning a single target-encoded output column for each input column.
+Users can specify input and output column names by setting `inputCol` and `outputCol` for single-column use cases, or `inputCols` and `outputCols` for multi-column use cases (both arrays are required to have the same size).
 
-`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input during transforming data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error).
+Users can specify the target column name by setting `labelCol`. This column contains the ground-truth labels from which encodings will be derived.
 
-`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated. Available options include 'binary' (bin-counting) and 'continuous' (mean-encoding).
+`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input when transforming data.
+Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error).
+
+`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
+Available options include 'binary' and 'continuous'.
+When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a. bin-counting).
+When set to 'continuous', encodings will be fitted according to the target mean (a.k.a. mean-encoding).
 
 `TargetEncoder` supports the `smoothing` parameter to tune how in-category stats and overall stats are weighted.
+When calculating encodings according only to in-class statistics, rarely seen categories are very likely to cause overfitting when used in learning.
+Smoothing prevents this behaviour by weighting in-class stats and overall stats according to the weight of each class on the whole dataset.
+
+For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
 
 **Examples**
 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
index 36b071fd14cae..0c3880cf14b24 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala
@@ -17,11 +17,9 @@
 
 package org.apache.spark.ml.feature
 
-import scala.collection.immutable.ArraySeq
-
 import org.apache.hadoop.fs.Path
 
-import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.attribute.NominalAttribute
@@ -48,9 +46,9 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol
   @Since("4.0.0")
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data during transform(). " +
-    "Options are 'keep' (invalid data presented as an extra categorical feature) " +
-    "or error (throw an error). Note that this Param is only used during transform; " +
-    "during fitting, invalid data will result in an error.",
+      "Options are 'keep' (invalid data presented as an extra categorical feature) " +
+      "or error (throw an error). Note that this Param is only used during transform; " +
+      "during fitting, invalid data will result in an error.",
     ParamValidators.inArray(TargetEncoder.supportedHandleInvalids))
 
   setDefault(handleInvalid -> TargetEncoder.ERROR_INVALID)
@@ -81,8 +79,8 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol
     else Array.empty[String]
 
   private[feature] lazy val outputFeatures = if (isSet(outputCol)) Array($(outputCol))
-    else if (isSet(outputCols)) $(outputCols)
-    else inputFeatures.map{field: String => s"${field}_indexed"}
+  else if (isSet(outputCols)) $(outputCols)
+  else inputFeatures.map{field: String => s"${field}_indexed"}
 
   private[feature] def validateSchema(schema: StructType,
     fitting: Boolean): StructType = {
@@ -95,15 +93,15 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol
       s"output columns ${outputFeatures.length}.")
 
     val features = if (fitting) inputFeatures :+ $(labelCol)
-      else inputFeatures
+    else inputFeatures
 
     features.foreach {
       feature => {
        try {
          val field = schema(feature)
-          if (field.dataType != DoubleType) {
+          if (!field.dataType.isInstanceOf[NumericType]) {
            throw new SparkException(s"Data type for column ${feature} is ${field.dataType}" +
-              s", but ${DoubleType.typeName} is required.")
+              s", but a subclass of ${NumericType} is required.")
          }
        } catch {
          case e: IllegalArgumentException =>
@@ -137,7 +135,7 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol
  */
 @Since("4.0.0")
 class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String)
-  extends Estimator[TargetEncoderModel] with TargetEncoderBase with DefaultParamsWritable {
+    extends Estimator[TargetEncoderModel] with TargetEncoderBase with DefaultParamsWritable {
 
   @Since("4.0.0")
   def this() = this(Identifiable.randomUID("TargetEncoder"))
 
@@ -183,80 
+181,101 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) override def fit(dataset: Dataset[_]): TargetEncoderModel = { validateSchema(dataset.schema, fitting = true) + val feature_types = inputFeatures.map{ + feature => dataset.schema(feature).dataType + } + val label_type = dataset.schema($(labelCol)).dataType + val stats = dataset - .select(ArraySeq.unsafeWrapArray( - (inputFeatures :+ $(labelCol)).map(col)): _*) - .rdd - .treeAggregate( + .select((inputFeatures :+ $(labelCol)).map(col).toIndexedSeq: _*) + .rdd.treeAggregate( Array.fill(inputFeatures.length) { - Map.empty[Double, (Double, Double)] + Map.empty[Option[Double], (Double, Double)] })( (agg, row: Row) => { - val label = row.getDouble(inputFeatures.length) - Range(0, inputFeatures.length).map { - feature => try { - val value = row.getDouble(feature) - if (value < 0.0 || value != value.toInt) throw new SparkException( - s"Values from column ${inputFeatures(feature)} must be indices, but got $value.") - val counter = agg(feature).getOrElse(value, (0.0, 0.0)) - val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0)) + val label = label_type match { + case ByteType => row.getByte(inputFeatures.length).toDouble + case ShortType => row.getShort(inputFeatures.length).toDouble + case IntegerType => row.getInt(inputFeatures.length).toDouble + case LongType => row.getLong(inputFeatures.length).toDouble + case DoubleType => row.getDouble(inputFeatures.length) + } + inputFeatures.indices.map { + feature => { + val category: Option[Double] = { + if (row.isNullAt(feature)) None // null category + else { + val value: Double = feature_types(feature) match { + case ByteType => row.getByte(feature).toDouble + case ShortType => row.getShort(feature).toDouble + case IntegerType => row.getInt(feature).toDouble + case LongType => row.getLong(feature).toDouble + case DoubleType => row.getDouble(feature) + } + if (value < 0.0 || value != value.toInt) throw new SparkException( + s"Values from column ${inputFeatures(feature)} must be indices, " + + s"but got $value.") + else Some(value) + } + } + val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0)) + val (global_count, global_stat) = + agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0)) $(targetType) match { - case TargetEncoder.TARGET_BINARY => - if (label == 1.0) agg(feature) + - (value -> (1 + counter._1, 1 + counter._2)) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, 1 + globalCounter._2)) - else if (label == 0.0) agg(feature) + - (value -> (1 + counter._1, counter._2)) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, globalCounter._2)) - else throw new SparkException( + case TargetEncoder.TARGET_BINARY => // counting + if (label == 1.0) { + agg(feature) + + (category -> (1 + class_count, 1 + class_stat)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat)) + } else if (label == 0.0) { + agg(feature) + + (category -> (1 + class_count, class_stat)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat)) + } else throw new SparkException( s"Values from column ${getLabelCol} must be binary (0,1) but got $label.") - case TargetEncoder.TARGET_CONTINUOUS => agg(feature) + - (value -> (1 + counter._1, - counter._2 + ((label - counter._2) / (1 + counter._1)))) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + globalCounter._1, - globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1)))) + case TargetEncoder.TARGET_CONTINUOUS => // 
incremental mean + agg(feature) + + (category -> (1 + class_count, + class_stat + ((label - class_stat) / (1 + class_count)))) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, + global_stat + ((label - global_stat) / (1 + global_count)))) } - } catch { - case e: SparkRuntimeException => - if (e.getErrorClass == "ROW_VALUE_IS_NULL") { - throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." + - s" See Imputer estimator for completing missing values.") - } else throw e } }.toArray }, - (agg1, agg2) => Range(0, inputFeatures.length) - .map { - feature => { - val values = agg1(feature).keySet ++ agg2(feature).keySet - values.map(value => - value -> { - val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0)) - val stat2 = agg2(feature).getOrElse(value, (0.0, 0.0)) - $(targetType) match { - case TargetEncoder.TARGET_BINARY => (stat1._1 + stat2._1, stat1._2 + stat2._2) - case TargetEncoder.TARGET_CONTINUOUS => (stat1._1 + stat2._1, - ((stat1._1 * stat1._2) + (stat2._1 * stat2._2)) / (stat1._1 + stat2._1)) - } - }).toMap - } - }.toArray) - - val encodings: Map[String, Map[Double, Double]] = stats.zipWithIndex.map { - case (stat, idx) => - val global = stat.get(TargetEncoder.UNSEEN_CATEGORY).get - inputFeatures(idx) -> stat.map { - case (feature, value) => feature -> { - val weight = value._1 / (value._1 + $(smoothing)) - $(targetType) match { - case TargetEncoder.TARGET_BINARY => - weight * (value._2 / value._1) + (1 - weight) * (global._2 / global._1) - case TargetEncoder.TARGET_CONTINUOUS => - weight * value._2 + (1 - weight) * global._2 + (agg1, agg2) => inputFeatures.indices.map { + feature => { + val categories = agg1(feature).keySet ++ agg2(feature).keySet + categories.map(category => + category -> { + val (counter1, stat1) = agg1(feature).getOrElse(category, (0.0, 0.0)) + val (counter2, stat2) = agg2(feature).getOrElse(category, (0.0, 0.0)) + $(targetType) match { + case TargetEncoder.TARGET_BINARY => (counter1 + counter2, stat1 + stat2) + case TargetEncoder.TARGET_CONTINUOUS => (counter1 + counter2, + ((counter1 * stat1) + (counter2 * stat2)) / (counter1 + counter2)) + } + }).toMap + } + }.toArray) + + // encodings: Map[feature, Map[Some(category), encoding]] + val encodings: Map[String, Map[Option[Double], Double]] = + inputFeatures.zip(stats).map { + case (feature, stat) => + val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get + feature -> stat.map { + case (cat, (class_count, class_stat)) => cat -> { + val weight = class_count / (class_count + $(smoothing)) + $(targetType) match { + case TargetEncoder.TARGET_BINARY => + weight * (class_stat/ class_count) + (1 - weight) * (global_stat / global_count) + case TargetEncoder.TARGET_CONTINUOUS => + weight * class_stat + (1 - weight) * global_stat + } } } - } - }.toMap + }.toMap val model = new TargetEncoderModel(uid, encodings).setParent(this) copyValues(model) @@ -269,15 +288,17 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) @Since("4.0.0") object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { + // handleInvalid parameter values private[feature] val KEEP_INVALID: String = "keep" private[feature] val ERROR_INVALID: String = "error" private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) + // targetType parameter values private[feature] val TARGET_BINARY: String = "binary" private[feature] val TARGET_CONTINUOUS: String = "continuous" private[feature] val supportedTargetTypes: Array[String] = 
Array(TARGET_BINARY, TARGET_CONTINUOUS) - private[feature] val UNSEEN_CATEGORY: Double = -1 + private[feature] val UNSEEN_CATEGORY: Option[Double] = Some(-1) @Since("4.0.0") override def load(path: String): TargetEncoder = super.load(path) @@ -289,8 +310,8 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { */ @Since("4.0.0") class TargetEncoderModel private[ml] ( - @Since("4.0.0") override val uid: String, - @Since("4.0.0") val encodings: Map[String, Map[Double, Double]]) + @Since("4.0.0") override val uid: String, + @Since("4.0.0") val encodings: Map[String, Map[Option[Double], Double]]) extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { @Since("4.0.0") @@ -307,26 +328,38 @@ class TargetEncoderModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { validateSchema(dataset.schema, fitting = false) - val apply_encodings: Map[Double, Double] => (Column => Column) = - (mappings: Map[Double, Double]) => { + // builds a column-to-column function from a map of encodings + val apply_encodings: Map[Option[Double], Double] => (Column => Column) = + (mappings: Map[Option[Double], Double]) => { (col: Column) => { - val first :: rest = mappings.toList.sortWith{ + val nullWhen = when(col.isNull, + mappings.get(None) match { + case Some(code) => lit(code) + case None => if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) { + lit(mappings.get(TargetEncoder.UNSEEN_CATEGORY).get) + } else raise_error(lit( + s"Unseen null value in feature ${col.toString}. To handle unseen values, " + + s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}.")) + }) + val ordered_mappings = (mappings - None).toList.sortWith { (a, b) => (b._1 == TargetEncoder.UNSEEN_CATEGORY) || - ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1)) + ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1.get < b._1.get)) } - rest - .foldLeft(when(col === first._1, first._2))( - (new_col: Column, encoding) => - if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) { - new_col.when(col === encoding._1, encoding._2) - } else { + ordered_mappings + .foldLeft(nullWhen)( + (new_col: Column, mapping) => { + val (Some(original), encoded) = mapping + if (original != TargetEncoder.UNSEEN_CATEGORY.get) { + new_col.when(col === original, lit(encoded)) + } else { // unseen category new_col.otherwise( - if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) encoding._2 + if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) lit(encoded) else raise_error(concat( lit("Unseen value "), col, lit(s" in feature ${col.toString}. 
To handle unseen values, " + s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}.")))) - }) + } + }) } } @@ -334,17 +367,17 @@ class TargetEncoderModel private[ml] ( inputFeatures.zip(outputFeatures).map { feature => feature._2 -> (encodings.get(feature._1) match { - case Some(dict: Map[Double, Double]) => + case Some(dict) => apply_encodings(dict)(col(feature._1)) - .as(feature._2, NominalAttribute.defaultAttr - .withName(feature._2) - .withNumValues(dict.size) - .withValues(dict.values.toSet.toArray.map(_.toString)).toMetadata()) + .as(feature._2, NominalAttribute.defaultAttr + .withName(feature._2) + .withNumValues(dict.size) + .withValues(dict.values.toSet.toArray.map(_.toString)).toMetadata()) case None => throw new SparkException(s"No encodings found for ${feature._1}.") col(feature._1) }) - }.toMap) + }.toMap) } @@ -373,7 +406,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(encodings: Map[String, Map[Double, Double]]) + private case class Data(encodings: Map[String, Map[Option[Double], Double]]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -393,7 +426,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { val data = sparkSession.read.parquet(dataPath) .select("encodings") .head() - val encodings = data.getAs[Map[String, Map[Double, Double]]](0) + val encodings = data.getAs[Map[String, Map[Option[Double], Double]]](0) val model = new TargetEncoderModel(metadata.uid, encodings) metadata.getAndSetParams(model) model diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java index 998bdb18016d7..44e38543c515e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java @@ -37,19 +37,19 @@ public class JavaTargetEncoderSuite extends SharedSparkSession { public void testTargetEncoderBinary() { List data = Arrays.asList( - RowFactory.create(0.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), - RowFactory.create(1.0, 4.0, 5.0, 1.0, 2.0/3, 1.0, 1.0/3), - RowFactory.create(2.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), - RowFactory.create(0.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), - RowFactory.create(1.0, 3.0, 6.0, 0.0, 2.0/3, 0.0, 2.0/3), - RowFactory.create(2.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), - RowFactory.create(0.0, 3.0, 7.0, 0.0, 1.0/3, 0.0, 0.0), - RowFactory.create(1.0, 4.0, 8.0, 1.0, 2.0/3, 1.0, 1.0), - RowFactory.create(2.0, 3.0, 9.0, 0.0, 1.0/3, 0.0, 0.0)); + RowFactory.create((short)0, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), + RowFactory.create((short)1, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3), + RowFactory.create((short)2, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), + RowFactory.create((short)0, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), + RowFactory.create((short)1, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3), + RowFactory.create((short)2, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), + RowFactory.create((short)0, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0), + RowFactory.create((short)1, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0), + RowFactory.create((short)2, 3, null, 0.0, 1.0/3, 0.0, 0.0)); StructType schema = createStructType(new StructField[]{ - createStructField("input1", DoubleType, false), - createStructField("input2", DoubleType, false), - createStructField("input3", DoubleType, false), + 
createStructField("input1", ShortType, true), + createStructField("input2", IntegerType, true), + createStructField("input3", DoubleType, true), createStructField("label", DoubleType, false), createStructField("expected1", DoubleType, false), createStructField("expected2", DoubleType, false), @@ -74,19 +74,19 @@ public void testTargetEncoderBinary() { public void testTargetEncoderContinuous() { List data = Arrays.asList( - RowFactory.create(0.0, 3.0, 5.0, 10.0, 40.0, 50.0, 20.0), - RowFactory.create(1.0, 4.0, 5.0, 20.0, 50.0, 50.0, 20.0), - RowFactory.create(2.0, 3.0, 5.0, 30.0, 60.0, 50.0, 20.0), - RowFactory.create(0.0, 4.0, 6.0, 40.0, 40.0, 50.0, 50.0), - RowFactory.create(1.0, 3.0, 6.0, 50.0, 50.0, 50.0, 50.0), - RowFactory.create(2.0, 4.0, 6.0, 60.0, 60.0, 50.0, 50.0), - RowFactory.create(0.0, 3.0, 7.0, 70.0, 40.0, 50.0, 70.0), - RowFactory.create(1.0, 4.0, 8.0, 80.0, 50.0, 50.0, 80.0), - RowFactory.create(2.0, 3.0, 9.0, 90.0, 60.0, 50.0, 90.0)); + RowFactory.create((short)0, 3, 5.0, 10.0, 40.0, 50.0, 20.0), + RowFactory.create((short)1, 4, 5.0, 20.0, 50.0, 50.0, 20.0), + RowFactory.create((short)2, 3, 5.0, 30.0, 60.0, 50.0, 20.0), + RowFactory.create((short)0, 4, 6.0, 40.0, 40.0, 50.0, 50.0), + RowFactory.create((short)1, 3, 6.0, 50.0, 50.0, 50.0, 50.0), + RowFactory.create((short)2, 4, 6.0, 60.0, 60.0, 50.0, 50.0), + RowFactory.create((short)0, 3, 7.0, 70.0, 40.0, 50.0, 70.0), + RowFactory.create((short)1, 4, 8.0, 80.0, 50.0, 50.0, 80.0), + RowFactory.create((short)2, 3, null, 90.0, 60.0, 50.0, 90.0)); StructType schema = createStructType(new StructField[]{ - createStructField("input1", DoubleType, false), - createStructField("input2", DoubleType, false), - createStructField("input3", DoubleType, false), + createStructField("input1", ShortType, true), + createStructField("input2", IntegerType, true), + createStructField("input3", DoubleType, true), createStructField("label", DoubleType, false), createStructField("expected1", DoubleType, false), createStructField("expected2", DoubleType, false), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index 2b434268db3b6..b2cb30eec4661 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.sql.Row +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { @@ -37,22 +38,22 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { // scalastyle:off data = Seq( - Row(0.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), - Row(1.0, 4.0, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), - Row(2.0, 3.0, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), - Row(0.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), - Row(1.0, 3.0, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), - Row(2.0, 4.0, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), - Row(0.0, 3.0, 7.0, 0.0, 1.0/3, 0.0, 0.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), - Row(1.0, 4.0, 8.0, 1.0, 2.0/3, 1.0, 1.0, 80.0, 
50.0, 50.0, 80.0, 50.0, 50.0, 65.0), - Row(2.0, 3.0, 9.0, 0.0, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) + Row(0.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + Row(1.toShort, 4, 5.0, 1.toByte, 2.0/3, 1.0, 1.0/3, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + Row(2.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + Row(0.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + Row(1.toShort, 3, 6.0, 0.toByte, 2.0/3, 0.0, 2.0/3, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + Row(2.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + Row(0.toShort, 3, 7.0, 0.toByte, 1.0/3, 0.0, 0.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + Row(1.toShort, 4, 8.0, 1.toByte, 2.0/3, 1.0, 1.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + Row(2.toShort, 3, 9.0, 0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) // scalastyle:on schema = StructType(Array( - StructField("input1", DoubleType), - StructField("input2", DoubleType), - StructField("input3", DoubleType), - StructField("binaryLabel", DoubleType), + StructField("input1", ShortType, nullable = true), + StructField("input2", IntegerType, nullable = true), + StructField("input3", DoubleType, nullable = true), + StructField("binaryLabel", ByteType), StructField("binaryExpected1", DoubleType), StructField("binaryExpected2", DoubleType), StructField("binaryExpected3", DoubleType), @@ -82,10 +83,11 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) val expected_encodings = Map( - "input1" -> Map(0.0 -> 1.0/3, 1.0 -> 2.0/3, 2.0 -> 1.0/3, -1.0 -> 4.0/9), - "input2" -> Map(3.0 -> 0.0, 4.0 -> 1.0, -1.0 -> 4.0/9), - "input3" -> - HashMap(5.0 -> 1.0/3, 6.0 -> 2.0/3, 7.0 -> 0.0, 8.0 -> 1.0, 9.0 -> 0.0, -1.0 -> 4.0/9)) + "input1" -> + Map(Some(0.0) -> 1.0/3, Some(1.0) -> 2.0/3, Some(2.0) -> 1.0/3, Some(-1.0) -> 4.0/9), + "input2" -> Map(Some(3.0) -> 0.0, Some(4.0) -> 1.0, Some(-1.0) -> 4.0/9), + "input3" -> HashMap(Some(5.0) -> 1.0/3, Some(6.0) -> 2.0/3, Some(7.0) -> 0.0, + Some(8.0) -> 1.0, Some(9.0) -> 0.0, Some(-1.0) -> 4.0/9)) assert(model.encodings.equals(expected_encodings)) @@ -96,12 +98,12 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { "output1", "binaryExpected1", "output2", "binaryExpected2", "output3", "binaryExpected3") { - case Row(output1: Double, expected1: Double, - output2: Double, expected2: Double, - output3: Double, expected3: Double) => - assert(output1 === expected1) - assert(output2 === expected2) - assert(output3 === expected3) + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) } } @@ -120,10 +122,10 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) val expected_encodings = Map( - "input1" -> Map(0.0 -> 40.0, 1.0 -> 50.0, 2.0 -> 60.0, -1.0 -> 50.0), - "input2" -> Map(3.0 -> 50.0, 4.0 -> 50.0, -1.0 -> 50.0), - "input3" -> - HashMap(5.0 -> 20.0, 6.0 -> 50.0, 7.0 -> 70.0, 8.0 -> 80.0, 9.0 -> 90.0, -1.0 -> 50.0)) + "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + Some(8.0) -> 
80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) assert(model.encodings.equals(expected_encodings)) @@ -134,13 +136,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { "output1", "continuousExpected1", "output2", "continuousExpected2", "output3", "continuousExpected3") { - case Row(output1: Double, expected1: Double, - output2: Double, expected2: Double, - output3: Double, expected3: Double) => - assert(output1 === expected1) - assert(output2 === expected2) - assert(output3 === expected3) - } + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } } @@ -159,10 +161,10 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) val expected_encodings = Map( - "input1" -> Map(0.0 -> 42.5, 1.0 -> 50.0, 2.0 -> 57.5, -1.0 -> 50.0), - "input2" -> Map(3.0 -> 50.0, 4.0 -> 50.0, -1.0 -> 50.0), - "input3" -> - HashMap(5.0 -> 27.5, 6.0 -> 50.0, 7.0 -> 60.0, 8.0 -> 65.0, 9.0 -> 70.0, -1.0 -> 50.0)) + "input1" -> Map(Some(0.0) -> 42.5, Some(1.0) -> 50.0, Some(2.0) -> 57.5, Some(-1.0) -> 50.0), + "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + "input3" -> HashMap(Some(5.0) -> 27.5, Some(6.0) -> 50.0, Some(7.0) -> 60.0, + Some(8.0) -> 65.0, Some(9.0) -> 70.0, Some(-1.0) -> 50.0)) assert(model.encodings.equals(expected_encodings)) @@ -197,7 +199,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val data_unseen = Row(0.0, 3.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) + val data_unseen = Row(0.toShort, 3, 10.0, + 0.toByte, 0.0, 0.0, 0.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) val df_unseen = spark .createDataFrame(sc.parallelize(data :+ data_unseen), schema) @@ -209,13 +212,13 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { "output1", "continuousExpected1", "output2", "continuousExpected2", "output3", "continuousExpected3") { - case Row(output1: Double, expected1: Double, - output2: Double, expected2: Double, - output3: Double, expected3: Double) => - assert(output1 === expected1) - assert(output2 === expected2) - assert(output3 === expected3) - } + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } } test("TargetEncoder - unseen value - error") { @@ -232,8 +235,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val data_unseen = Row(0.0, 3.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) + val data_unseen = Row(0.toShort, 3, 10.0, + 0.toByte, 0.0, 0.0, 0.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) val df_unseen = spark .createDataFrame(sc.parallelize(data :+ data_unseen), schema) @@ -252,7 +255,6 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val df = spark .createDataFrame(sc.parallelize(data), schema) - .drop("continuousLabel") val encoder = new TargetEncoder() .setLabelCol("binaryLabel") @@ -274,7 +276,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val wrong_schema = new StructType( schema.map{ field: StructField => if (field.name != "input3") field - else new StructField(field.name, StringType, field.nullable, field.metadata) + else new StructField(field.name, 
StringType, field.nullable, field.metadata)
       }.toArray)
 
     val df = spark
@@ -296,7 +298,40 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     assert(ex.getMessage.contains("Data type for column input3 is StringType"))
   }
 
-  test("TargetEncoder - null value") {
+  test("TargetEncoder - seen null category") {
+
+    val data_null = Row(2.toShort, 3, null,
+      0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)
+
+    val df_null = spark
+      .createDataFrame(sc.parallelize(data.dropRight(1) :+ data_null), schema)
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("continuousLabel")
+      .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val model = encoder.fit(df_null)
+
+    val expected_encodings = Map(
+      "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0),
+      "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0),
+      "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0,
+        Some(8.0) -> 80.0, None -> 90.0, Some(-1.0) -> 50.0))
+
+    assert(model.encodings.equals(expected_encodings))
+
+    val output = model.transform(df_null)
+
+    output.select(assert_true(
+      output("output1") === output("continuousExpected1") &&
+        output("output2") === output("continuousExpected2") &&
+        output("output3") === output("continuousExpected3"))).collect()
+
+  }
+
+  test("TargetEncoder - unseen null category") {
 
     val df = spark
       .createDataFrame(sc.parallelize(data), schema)
@@ -304,39 +339,37 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
     val encoder = new TargetEncoder()
       .setLabelCol("continuousLabel")
       .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
-      .setHandleInvalid(TargetEncoder.ERROR_INVALID)
+      .setHandleInvalid(TargetEncoder.KEEP_INVALID)
       .setInputCols(Array("input1", "input2", "input3"))
       .setOutputCols(Array("output1", "output2", "output3"))
-
-    val data_null = Row(0.0, 3.0, null.asInstanceOf[Integer],
-      0.0, 0.0, 0.0, 0.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0)
+    val data_null = Row(null, null, null,
+      0.toByte, 1.0/3, 0.0, 0.0, 90.0, 50.0, 50.0, 50.0, 57.5, 50.0, 70.0)
 
     val df_null = spark
       .createDataFrame(sc.parallelize(data :+ data_null), schema)
 
-    val ex = intercept[SparkException] {
-      val model = encoder.fit(df_null)
-      print(model.encodings)
-    }
+    val model = encoder.fit(df)
 
-    assert(ex.isInstanceOf[SparkException])
-    assert(ex.getMessage.contains("Null value found in feature input3"))
+    val output = model.transform(df_null)
+
+    output.select(assert_true(
+      output("output1") === output("continuousExpected1") &&
+        output("output2") === output("continuousExpected2") &&
+        output("output3") === output("continuousExpected3"))).collect()
 
   }
 
   test("TargetEncoder - non-indexed categories") {
 
-    val df = spark
-      .createDataFrame(sc.parallelize(data), schema)
-
     val encoder = new TargetEncoder()
       .setLabelCol("binaryLabel")
       .setTargetType(TargetEncoder.TARGET_BINARY)
      .setInputCols(Array("input1", "input2", "input3"))
      .setOutputCols(Array("output1", "output2", "output3"))
 
-    val data_noindex = Row(0.0, 3.0, 5.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
+    val data_noindex = Row(
+      0.toShort, 3, 5.1, 0.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)
 
     val df_noindex = spark
       .createDataFrame(sc.parallelize(data :+ data_noindex), schema)
@@ -354,16 +387,14 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
 
   test("TargetEncoder - non-binary labels") {
 
-    val df = spark
-
.createDataFrame(sc.parallelize(data), schema) - val encoder = new TargetEncoder() .setLabelCol("binaryLabel") .setTargetType(TargetEncoder.TARGET_BINARY) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) - val data_non_binary = Row(0.0, 3.0, 5.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + val data_non_binary = Row( + 0.toShort, 3, 5.0, 2.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) val df_non_binary = spark .createDataFrame(sc.parallelize(data :+ data_non_binary), schema) @@ -375,7 +406,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { assert(ex.isInstanceOf[SparkException]) assert(ex.getMessage.contains( - "Values from column binaryLabel must be binary (0,1) but got 0.1")) + "Values from column binaryLabel must be binary (0,1) but got 2.0")) } diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 6dbb7f06b33ff..666ed1c4269e1 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -350,17 +350,17 @@ def test_string_indexer_from_labels(self): def test_target_encoder_binary(self): df = self.spark.createDataFrame( [ - (0.0, 3.0, 5.0, 0.0), - (1.0, 4.0, 5.0, 1.0), - (2.0, 3.0, 5.0, 0.0), - (0.0, 4.0, 6.0, 1.0), - (1.0, 3.0, 6.0, 0.0), - (2.0, 4.0, 6.0, 1.0), - (0.0, 3.0, 7.0, 0.0), - (1.0, 4.0, 8.0, 1.0), - (2.0, 3.0, 9.0, 0.0), + (0, 3, 5.0, 0.0), + (1, 4, 5.0, 1.0), + (2, 3, 5.0, 0.0), + (0, 4, 6.0, 1.0), + (1, 3, 6.0, 0.0), + (2, 4, 6.0, 1.0), + (0, 3, 7.0, 0.0), + (1, 4, 8.0, 1.0), + (2, 3, 9.0, 0.0), ], - ["input1", "input2", "input3", "label"], + schema="input1 short, input2 int, input3 double, label double", ) encoder = TargetEncoder( inputCols=["input1", "input2", "input3"], @@ -372,32 +372,32 @@ def test_target_encoder_binary(self): te = model.transform(df) actual = te.drop("label").collect() expected = [ - Row(input1=0.0, input2=3.0, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), - Row(input1=1.0, input2=4.0, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3), - Row(input1=2.0, input2=3.0, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), - Row(input1=0.0, input2=4.0, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), - Row(input1=1.0, input2=3.0, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3), - Row(input1=2.0, input2=4.0, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), - Row(input1=0.0, input2=3.0, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0), - Row(input1=1.0, input2=4.0, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0), - Row(input1=2.0, input2=3.0, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0), + Row(input1=0, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), + Row(input1=1, input2=4, input3=5.0, output1=2.0 / 3, output2=1.0, output3=1.0 / 3), + Row(input1=2, input2=3, input3=5.0, output1=1.0 / 3, output2=0.0, output3=1.0 / 3), + Row(input1=0, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), + Row(input1=1, input2=3, input3=6.0, output1=2.0 / 3, output2=0.0, output3=2.0 / 3), + Row(input1=2, input2=4, input3=6.0, output1=1.0 / 3, output2=1.0, output3=2.0 / 3), + Row(input1=0, input2=3, input3=7.0, output1=1.0 / 3, output2=0.0, output3=0.0), + Row(input1=1, input2=4, input3=8.0, output1=2.0 / 3, output2=1.0, output3=1.0), + Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0), ] self.assertEqual(actual, expected) def 
test_target_encoder_continuous(self): df = self.spark.createDataFrame( [ - (0.0, 3.0, 5.0, 10.0), - (1.0, 4.0, 5.0, 20.0), - (2.0, 3.0, 5.0, 30.0), - (0.0, 4.0, 6.0, 40.0), - (1.0, 3.0, 6.0, 50.0), - (2.0, 4.0, 6.0, 60.0), - (0.0, 3.0, 7.0, 70.0), - (1.0, 4.0, 8.0, 80.0), - (2.0, 3.0, 9.0, 90.0), + (0, 3, 5.0, 10.0), + (1, 4, 5.0, 20.0), + (2, 3, 5.0, 30.0), + (0, 4, 6.0, 40.0), + (1, 3, 6.0, 50.0), + (2, 4, 6.0, 60.0), + (0, 3, 7.0, 70.0), + (1, 4, 8.0, 80.0), + (2, 3, 9.0, 90.0), ], - ["input1", "input2", "input3", "label"], + schema="input1 short, input2 int, input3 double, label double", ) encoder = TargetEncoder( inputCols=["input1", "input2", "input3"], @@ -409,15 +409,15 @@ def test_target_encoder_continuous(self): te = model.transform(df) actual = te.drop("label").collect() expected = [ - Row(input1=0.0, input2=3.0, input3=5.0, output1=40.0, output2=50.0, output3=20.0), - Row(input1=1.0, input2=4.0, input3=5.0, output1=50.0, output2=50.0, output3=20.0), - Row(input1=2.0, input2=3.0, input3=5.0, output1=60.0, output2=50.0, output3=20.0), - Row(input1=0.0, input2=4.0, input3=6.0, output1=40.0, output2=50.0, output3=50.0), - Row(input1=1.0, input2=3.0, input3=6.0, output1=50.0, output2=50.0, output3=50.0), - Row(input1=2.0, input2=4.0, input3=6.0, output1=60.0, output2=50.0, output3=50.0), - Row(input1=0.0, input2=3.0, input3=7.0, output1=40.0, output2=50.0, output3=70.0), - Row(input1=1.0, input2=4.0, input3=8.0, output1=50.0, output2=50.0, output3=80.0), - Row(input1=2.0, input2=3.0, input3=9.0, output1=60.0, output2=50.0, output3=90.0), + Row(input1=0, input2=3, input3=5.0, output1=40.0, output2=50.0, output3=20.0), + Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=20.0), + Row(input1=2, input2=3, input3=5.0, output1=60.0, output2=50.0, output3=20.0), + Row(input1=0, input2=4, input3=6.0, output1=40.0, output2=50.0, output3=50.0), + Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), + Row(input1=2, input2=4, input3=6.0, output1=60.0, output2=50.0, output3=50.0), + Row(input1=0, input2=3, input3=7.0, output1=40.0, output2=50.0, output3=70.0), + Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=80.0), + Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0), ] self.assertEqual(actual, expected) From 02219337cb27304638e4026856a153d0113cf761 Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Wed, 9 Oct 2024 19:27:16 +0200 Subject: [PATCH 3/9] [SPARK-37178][ML] ignore null label observations --- .../spark/ml/feature/TargetEncoder.scala | 4 +-- .../spark/ml/feature/TargetEncoderSuite.scala | 28 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 0c3880cf14b24..db40fa052b84f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -192,7 +192,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) Array.fill(inputFeatures.length) { Map.empty[Option[Double], (Double, Double)] })( - (agg, row: Row) => { + (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) { val label = label_type match { case ByteType => row.getByte(inputFeatures.length).toDouble case ShortType => row.getShort(inputFeatures.length).toDouble @@ -242,7 +242,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") 
override val uid: String)
             }
           }
         }.toArray
-      },
+      } else agg, // ignore null-labeled observations
       (agg1, agg2) => inputFeatures.indices.map {
         feature => {
           val categories = agg1(feature).keySet ++ agg2(feature).keySet
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
index b2cb30eec4661..4d3f4f3f7213b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala
@@ -385,6 +385,32 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest {
 
   }
 
+  test("TargetEncoder - null label") {
+
+    val data_nolabel = Row(2.toShort, 3, 5.0,
+      null, 1.0/3, 0.0, 0.0, null, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)
+
+    val df_nolabel = spark
+      .createDataFrame(sc.parallelize(data :+ data_nolabel), schema)
+
+    val encoder = new TargetEncoder()
+      .setLabelCol("continuousLabel")
+      .setTargetType(TargetEncoder.TARGET_CONTINUOUS)
+      .setInputCols(Array("input1", "input2", "input3"))
+      .setOutputCols(Array("output1", "output2", "output3"))
+
+    val model = encoder.fit(df_nolabel)
+
+    val expected_encodings = Map(
+      "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0),
+      "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0),
+      "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0,
+        Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0))
+
+    assert(model.encodings.equals(expected_encodings))
+
+  }
+
   test("TargetEncoder - non-binary labels") {
 
     val encoder = new TargetEncoder()

From 3f1f86ddf62ca5810a1b3e27c3428aaccd47a163 Mon Sep 17 00:00:00 2001
From: Enrique Rebollo
Date: Wed, 16 Oct 2024 20:35:33 +0200
Subject: [PATCH 4/9] [SPARK-37178][ML] improve doc & comments

---
 docs/ml-features.md                           | 88 +++++++++++++++----
 .../spark/ml/feature/TargetEncoder.scala      |  9 +-
 2 files changed, 80 insertions(+), 17 deletions(-)

diff --git a/docs/ml-features.md b/docs/ml-features.md
index 1fb0e03f8f435..418e94ad1ea19 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -857,30 +857,88 @@ for more details on the API.
 
 ## TargetEncoder
 
-Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
-Leveraging the relationship between categorical features and the target variable, target encoding usually performs better than one-hot encoding (while avoiding the need to add extra columns).
+[Target Encoding](https://www.researchgate.net/publication/220520258_A_Preprocessing_Scheme_for_High-Cardinality_Categorical_Attributes_in_Classification_and_Prediction_Problems) is a data-preprocessing technique that transforms high-cardinality categorical features into quasi-continuous scalar attributes suited for use in regression-type models. This paradigm maps individual values of an independent feature to a scalar, representing some estimate of the dependent attribute (meaning categorical values that exhibit similar statistics with respect to the target will have a similar representation).
 
-`TargetEncoder` can transform multiple columns, returning a single target-encoded output column for each input column.
-Users can specify input and output column names by setting `inputCol` and `outputCol` for single-column use cases, or `inputCols` and `outputCols` for multi-column use cases (both arrays are required to have the same size).
+By leveraging the relationship between categorical features and the target variable, Target Encoding usually performs better than One-Hot Encoding and does not require a final binary vector encoding, decreasing the overall dimensionality of the dataset.
 
-Users can specify the target column name by setting `labelCol`. This column contains the ground-truth labels from which encodings will be derived.
+Users can specify input and output column names by setting `inputCol` and `outputCol` for single-column use cases, or `inputCols` and `outputCols` for multi-column use cases (both arrays are required to have the same size). These columns are expected to contain categorical indices (positive integers), with missing values (null) treated as a separate category. The data type must be a subclass of `NumericType`. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
 
-`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input when transforming data.
-Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an error).
+Users can specify the target column name by setting `labelCol`. This column is expected to contain the ground-truth labels from which encodings will be derived. Observations with a missing label (null) are not considered when calculating estimates. The data type must be a subclass of `NumericType`.
 
-`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
-Available options include 'binary' and 'continuous'.
-When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a. bin-counting).
-When set to 'continuous', encodings will be fitted according to the target mean (a.k.a. mean-encoding).
+`TargetEncoder` supports the `handleInvalid` parameter to choose how to handle invalid input, meaning categories not seen at training, when encoding new data. Available options include 'keep' (any invalid inputs are assigned to an extra categorical index) and 'error' (throw an exception).
+
+`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how estimates are calculated. Available options include 'binary' and 'continuous'.
+
+When set to 'binary', the target attribute $Y$ is expected to be binary, $Y\in\{ 0,1 \}$. The transformation maps individual values $X_{i}$ to the conditional probability of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=P(Y\mid X=X_{i})$. This approach is also known as bin-counting.
+
+When set to 'continuous', the target attribute $Y$ is expected to be continuous, $Y\in\mathbb{R}$. The transformation maps individual values $X_{i}$ to the average of $Y$ given that $X=X_{i}\;$: $\;\; S_{i}=E[Y\mid X=X_{i}]$. This approach is also known as mean-encoding.
+
+`TargetEncoder` supports the `smoothing` parameter to tune how in-category stats and overall stats are blended. High-cardinality categorical features are usually unevenly distributed across all possible values of $X$.
+Therefore, calculating encodings $S_{i}$ according only to in-class statistics makes these estimates very unreliable, and rarely seen categories will very likely cause overfitting in learning.
+
+Smoothing prevents this behaviour by weighting in-class estimates with overall estimates according to the relative size of the particular class on the whole dataset.
+
+$\;\;\; S_{i}=\lambda(n_{i})\, P(Y\mid X=X_{i})+(1-\lambda(n_{i}))\, P(Y)$ for the binary case
+
+$\;\;\; S_{i}=\lambda(n_{i})\, E[Y\mid X=X_{i}]+(1-\lambda(n_{i}))\, E[Y]$ for the continuous case
+
+where $\lambda(n_{i})$ is a monotonically increasing function of $n_{i}$, bounded between 0 and 1.
+
+Usually $\lambda(n_{i})$ is implemented as the parametric function $\lambda(n_{i})=\frac{n_{i}}{n_{i}+m}$, where $m$ is the smoothing factor, represented by the `smoothing` parameter in `TargetEncoder`.
 
 **Examples**
 
+To illustrate how `TargetEncoder` works, let's assume we have the following
+DataFrame with columns `feature` and `target` (binary & continuous):
+
+~~~~
+ feature | target | target
+         | (bin)  | (cont)
+ --------|--------|--------
+    1    |    0   |   1.3
+    1    |    1   |   2.5
+    1    |    0   |   1.6
+    2    |    1   |   1.8
+    2    |    0   |   2.4
+    3    |    1   |   3.2
+~~~~
+
+Applying `TargetEncoder` with 'binary' target type,
+`feature` as the input column, `target (bin)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+         | (bin)  |
+ --------|--------|--------
+    1    |    0   |  0.333
+    1    |    1   |  0.333
+    1    |    0   |  0.333
+    2    |    1   |  0.5
+    2    |    0   |  0.5
+    3    |    1   |  1.0
+~~~~
+
+Applying `TargetEncoder` with 'continuous' target type,
+`feature` as the input column, `target (cont)` as the label column
+and `encoded` as the output column, we are able to fit a model
+on the data to learn encodings and transform the data according
+to these mappings:
+
+~~~~
+ feature | target | encoded
+         | (cont) |
+ --------|--------|--------
+    1    |   1.3  |   1.8
+    1    |   2.5  |   1.8
+    1    |   1.6  |   1.8
+    2    |   1.8  |   2.1
+    2    |   2.4  |   2.1
+    3    |   3.2  |   3.2
+~~~~
+
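To make the smoothing formula concrete, here is a minimal PySpark sketch (an illustrative example only, not part of this patch) that fits the binary example above with a non-zero smoothing factor; the DataFrame contents and the column names `feature`, `target` and `encoded` are the ones assumed in the tables above:

```python
from pyspark.sql import SparkSession
from pyspark.ml.feature import TargetEncoder

spark = SparkSession.builder.appName("TargetEncoderSmoothingExample").getOrCreate()

# The six-row binary example from the tables above.
df = spark.createDataFrame(
    [(1.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 1.0), (2.0, 0.0), (3.0, 1.0)],
    ["feature", "target"],
)

# With smoothing m = 1.0 the encoding for each category blends the
# in-category estimate with the overall prior P(Y) = 3/6 = 0.5:
#   category 1: lambda = 3/(3+1) = 0.75 -> 0.75 * (1/3) + 0.25 * 0.5 = 0.375
#   category 3: lambda = 1/(1+1) = 0.50 -> 0.50 * 1.0  + 0.50 * 0.5 = 0.75
# so the rarely seen category 3 is pulled toward the prior instead of
# keeping its unreliable in-category estimate of 1.0.
encoder = TargetEncoder(
    inputCols=["feature"],
    outputCols=["encoded"],
    labelCol="target",
    targetType="binary",
    smoothing=1.0,
)
encoder.fit(df).transform(df).show()
```

With `smoothing=0.0` (the default) the output reduces to the unsmoothed encodings shown in the binary table above.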
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index db40fa052b84f..2be3529d00a35 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -215,7 +215,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) if (value < 0.0 || value != value.toInt) throw new SparkException( s"Values from column ${inputFeatures(feature)} must be indices, " + s"but got $value.") - else Some(value) + else Some(value) // non-null category } } val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0)) @@ -224,16 +224,19 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) $(targetType) match { case TargetEncoder.TARGET_BINARY => // counting if (label == 1.0) { + // positive => increment both counters for current & unseen categories agg(feature) + (category -> (1 + class_count, 1 + class_stat)) + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat)) } else if (label == 0.0) { + // negative => increment only global counter for current & unseen categories agg(feature) + (category -> (1 + class_count, class_stat)) + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat)) } else throw new SparkException( s"Values from column ${getLabelCol} must be binary (0,1) but got $label.") case TargetEncoder.TARGET_CONTINUOUS => // incremental mean + // increment counter and iterate on mean for current & unseen categories agg(feature) + (category -> (1 + class_count, class_stat + ((label - class_stat) / (1 + class_count)))) + @@ -266,11 +269,13 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get feature -> stat.map { case (cat, (class_count, class_stat)) => cat -> { - val weight = class_count / (class_count + $(smoothing)) + val weight = class_count / (class_count + $(smoothing)) // smoothing weight $(targetType) match { case TargetEncoder.TARGET_BINARY => + // calculate conditional probabilities and blend weight * (class_stat/ class_count) + (1 - weight) * (global_stat / global_count) case TargetEncoder.TARGET_CONTINUOUS => + // blend means weight * class_stat + (1 - weight) * global_stat } } From 229e5ed73fccf03a09db7da49c7a51e4735460e4 Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Sun, 20 Oct 2024 22:39:10 +0200 Subject: [PATCH 5/9] [SPARK-37178][ML] allow different feature names in model --- .../spark/ml/feature/TargetEncoder.scala | 125 ++++++++++-------- .../spark/ml/feature/TargetEncoderSuite.scala | 97 +++++++++----- python/pyspark/ml/feature.py | 2 +- 3 files changed, 132 insertions(+), 92 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 2be3529d00a35..dd76bfa286031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -47,7 +47,7 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error). 
Note that this Param is only used during transform; " + + "or 'error' (throw an error). Note that this Param is only used during transform; " + "during fitting, invalid data will result in an error.", ParamValidators.inArray(TargetEncoder.supportedHandleInvalids)) @@ -55,10 +55,11 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol @Since("4.0.0") val targetType: Param[String] = new Param[String](this, "targetType", - "How to handle invalid data during transform(). " + - "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error). Note that this Param is only used during transform; " + - "during fitting, invalid data will result in an error.", + "Type of label considered during fit(). " + + "Options are 'binary' and 'continuous'. When 'binary', estimates are calculated as " + + "conditional probability of the target given each category. When 'continuous', " + + "estimates are calculated as the average of the target given each category" + + "Note that this Param is only used during fitting.", ParamValidators.inArray(TargetEncoder.supportedTargetTypes)) setDefault(targetType -> TargetEncoder.TARGET_BINARY) @@ -67,7 +68,9 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol @Since("4.0.0") val smoothing: DoubleParam = new DoubleParam(this, "smoothing", - "lower bound of the output feature range", + "Smoothing factor for encodings. Smoothing blends in-class estimates with overall estimates " + + "according to the relative size of the particular class on the whole dataset, reducing the " + + "risk of overfitting due to unreliable estimates", ParamValidators.gtEq(0.0)) setDefault(smoothing -> 0.0) @@ -82,7 +85,8 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol else if (isSet(outputCols)) $(outputCols) else inputFeatures.map{field: String => s"${field}_indexed"} - private[feature] def validateSchema(schema: StructType, + private[feature] def validateSchema( + schema: StructType, fitting: Boolean): StructType = { require(inputFeatures.length > 0, @@ -181,37 +185,20 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) override def fit(dataset: Dataset[_]): TargetEncoderModel = { validateSchema(dataset.schema, fitting = true) - val feature_types = inputFeatures.map{ - feature => dataset.schema(feature).dataType - } - val label_type = dataset.schema($(labelCol)).dataType - val stats = dataset - .select((inputFeatures :+ $(labelCol)).map(col).toIndexedSeq: _*) + .select((inputFeatures :+ $(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*) .rdd.treeAggregate( Array.fill(inputFeatures.length) { Map.empty[Option[Double], (Double, Double)] })( (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) { - val label = label_type match { - case ByteType => row.getByte(inputFeatures.length).toDouble - case ShortType => row.getShort(inputFeatures.length).toDouble - case IntegerType => row.getInt(inputFeatures.length).toDouble - case LongType => row.getLong(inputFeatures.length).toDouble - case DoubleType => row.getDouble(inputFeatures.length) - } + val label = row.getDouble(inputFeatures.length) inputFeatures.indices.map { feature => { val category: Option[Double] = { if (row.isNullAt(feature)) None // null category else { - val value: Double = feature_types(feature) match { - case ByteType => row.getByte(feature).toDouble - case ShortType => row.getShort(feature).toDouble - case IntegerType => row.getInt(feature).toDouble - case LongType => 
row.getLong(feature).toDouble - case DoubleType => row.getDouble(feature) - } + val value = row.getDouble(feature) if (value < 0.0 || value != value.toInt) throw new SparkException( s"Values from column ${inputFeatures(feature)} must be indices, " + s"but got $value.") @@ -262,12 +249,12 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) } }.toArray) - // encodings: Map[feature, Map[Some(category), encoding]] - val encodings: Map[String, Map[Option[Double], Double]] = - inputFeatures.zip(stats).map { - case (feature, stat) => + // encodings: Array[Map[Some(category), encoding]] + val encodings: Array[Map[Option[Double], Double]] = + stats.map { + stat => val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get - feature -> stat.map { + stat.map { case (cat, (class_count, class_stat)) => cat -> { val weight = class_count / (class_count + $(smoothing)) // smoothing weight $(targetType) match { @@ -280,7 +267,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) } } } - }.toMap + } val model = new TargetEncoderModel(uid, encodings).setParent(this) copyValues(model) @@ -316,17 +303,40 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { @Since("4.0.0") class TargetEncoderModel private[ml] ( @Since("4.0.0") override val uid: String, - @Since("4.0.0") val encodings: Map[String, Map[Option[Double], Double]]) + @Since("4.0.0") val encodings: Array[Map[Option[Double], Double]]) extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { + /** @group setParam */ + @Since("4.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("4.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + @Since("4.0.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("4.0.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("4.0.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + @Since("4.0.0") override def transformSchema(schema: StructType): StructType = { - inputFeatures.zip(outputFeatures) - .foldLeft(validateSchema(schema, fitting = false)) { - case (newSchema, fieldName) => - val field = schema(fieldName._1) - newSchema.add(StructField(fieldName._2, field.dataType, field.nullable)) - } + if (outputFeatures.length == encodings.length) { + outputFeatures.filter(_ != null) + .foldLeft(validateSchema(schema, fitting = false)) { + case (newSchema, outputField) => + newSchema.add(StructField(outputField, DoubleType, nullable = false)) + } + } else throw new SparkException("The number of features does not match the number of " + + s"encodings in the model (${encodings.length}). 
" + + s"Found ${outputFeatures.length} features)") } @Since("4.0.0") @@ -369,21 +379,20 @@ class TargetEncoderModel private[ml] ( } dataset.withColumns( - inputFeatures.zip(outputFeatures).map { - feature => - feature._2 -> (encodings.get(feature._1) match { - case Some(dict) => - apply_encodings(dict)(col(feature._1)) - .as(feature._2, NominalAttribute.defaultAttr - .withName(feature._2) - .withNumValues(dict.size) - .withValues(dict.values.toSet.toArray.map(_.toString)).toMetadata()) - case None => - throw new SparkException(s"No encodings found for ${feature._1}.") - col(feature._1) - }) - }.toMap) - } + inputFeatures.zip(outputFeatures).zip(encodings) + .filter{ + case ((featureIn, featureOut), _) => (featureIn != null) && (featureOut != null) + }.map { + case ((featureIn, featureOut), mapping) => + featureOut -> + apply_encodings(mapping)(col(featureIn)) + .as(featureOut, NominalAttribute.defaultAttr + .withName(featureOut) + .withNumValues(mapping.values.toSet.size) + .withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata()) + }.toMap) + + } @Since("4.0.0") @@ -398,7 +407,7 @@ class TargetEncoderModel private[ml] ( @Since("4.0.0") override def toString: String = { s"TargetEncoderModel: uid=$uid, " + - s" handleInvalid=${$(handleInvalid)}, targetType=${$(targetType)}, " + + s"handleInvalid=${$(handleInvalid)}, targetType=${$(targetType)}, " + s"numInputCols=${inputFeatures.length}, numOutputCols=${outputFeatures.length}, " + s"smoothing=${$(smoothing)}" } @@ -411,7 +420,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(encodings: Map[String, Map[Option[Double], Double]]) + private case class Data(encodings: Array[Map[Option[Double], Double]]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -431,7 +440,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { val data = sparkSession.read.parquet(dataPath) .select("encodings") .head() - val encodings = data.getAs[Map[String, Map[Option[Double], Double]]](0) + val encodings = data.getAs[Array[Map[Option[Double], Double]]](0) val model = new TargetEncoderModel(metadata.uid, encodings) metadata.getAndSetParams(model) model diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index 4d3f4f3f7213b..aa4155ae4d658 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -82,14 +82,15 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val expected_encodings = Map( - "input1" -> + val expected_encodings = Array( Map(Some(0.0) -> 1.0/3, Some(1.0) -> 2.0/3, Some(2.0) -> 1.0/3, Some(-1.0) -> 4.0/9), - "input2" -> Map(Some(3.0) -> 0.0, Some(4.0) -> 1.0, Some(-1.0) -> 4.0/9), - "input3" -> HashMap(Some(5.0) -> 1.0/3, Some(6.0) -> 2.0/3, Some(7.0) -> 0.0, - Some(8.0) -> 1.0, Some(9.0) -> 0.0, Some(-1.0) -> 4.0/9)) + Map(Some(3.0) -> 0.0, Some(4.0) -> 1.0, Some(-1.0) -> 4.0/9), + HashMap(Some(5.0) -> 1.0/3, Some(6.0) -> 2.0/3, Some(7.0) -> 0.0, + Some(8.0) -> 1.0, Some(9.0) -> 0.0, Some(-1.0) -> 4.0/9)) - assert(model.encodings.equals(expected_encodings)) + model.encodings.zip(expected_encodings).foreach{ + case (actual, expected) 
=> actual.equals(expected) + } testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", @@ -121,13 +122,15 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val expected_encodings = Map( - "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + val expected_encodings = Array( + Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) - assert(model.encodings.equals(expected_encodings)) + model.encodings.zip(expected_encodings).foreach{ + case (actual, expected) => actual.equals(expected) + } testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", @@ -160,13 +163,15 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val expected_encodings = Map( - "input1" -> Map(Some(0.0) -> 42.5, Some(1.0) -> 50.0, Some(2.0) -> 57.5, Some(-1.0) -> 50.0), - "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - "input3" -> HashMap(Some(5.0) -> 27.5, Some(6.0) -> 50.0, Some(7.0) -> 60.0, + val expected_encodings = Array( + Map(Some(0.0) -> 42.5, Some(1.0) -> 50.0, Some(2.0) -> 57.5, Some(-1.0) -> 50.0), + Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + HashMap(Some(5.0) -> 27.5, Some(6.0) -> 50.0, Some(7.0) -> 60.0, Some(8.0) -> 65.0, Some(9.0) -> 70.0, Some(-1.0) -> 50.0)) - assert(model.encodings.equals(expected_encodings)) + model.encodings.zip(expected_encodings).foreach{ + case (actual, expected) => actual.equals(expected) + } testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", @@ -314,20 +319,22 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df_null) - val expected_encodings = Map( - "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + val expected_encodings = Array( + Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, Some(8.0) -> 80.0, None -> 90.0, Some(-1.0) -> 50.0)) - assert(model.encodings.equals(expected_encodings)) + model.encodings.zip(expected_encodings).foreach{ + case (actual, expected) => actual.equals(expected) + } val output = model.transform(df_null) assert_true( output("output1") === output("continuousExpected1") && - output("output1") === output("continuousExpected1") && - output("output1") === output("continuousExpected1")) + output("output2") === output("continuousExpected2") && + output("output3") === output("continuousExpected3")) } @@ -355,8 +362,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { assert_true( output("output1") === output("continuousExpected1") && - output("output1") === output("continuousExpected1") && - output("output1") === 
output("continuousExpected1")) + output("output2") === output("continuousExpected2") && + output("output3") === output("continuousExpected3")) } @@ -401,15 +408,15 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df_nolabel) - val expected_encodings = Map( - "input1" -> Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - "input2" -> Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - "input3" -> HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, + val expected_encodings = Array( + Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), + Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), + HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) - print(model.encodings) - - assert(model.encodings.equals(expected_encodings)) + model.encodings.zip(expected_encodings).foreach{ + case (actual, expected) => actual.equals(expected) + } } @@ -429,7 +436,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val ex = intercept[SparkException] { val model = encoder.fit(df_non_binary) - print(model.encodings) + print(model.encodings.mkString) } assert(ex.isInstanceOf[SparkException]) @@ -438,6 +445,30 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { } + test("TargetEncoder - features renamed") { + + val df = spark + .createDataFrame(sc.parallelize(data), schema) + + val encoder = new TargetEncoder() + .setLabelCol("continuousLabel") + .setTargetType(TargetEncoder.TARGET_CONTINUOUS) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + .setInputCols(Array("input1", "renamed_input2", null)) + .setOutputCols(Array(null, "renamed_output2", "renamed_output3")) + + val df_renamed = df + .drop("input1", "input3") + .withColumnRenamed("input2", "renamed_input2") + + val output = model.transform(df_renamed) + assert(output.filter(col("renamed_output2") === col("continuousExpected2")).count() ==0) + + } + test("TargetEncoder - R/W single-column") { val encoder = new TargetEncoder() diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 98fc6dc690880..c409a330cd8a8 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -5493,7 +5493,7 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel": @property @since("4.0.0") - def encodings(self) -> dict[str, dict[float, float]]: + def encodings(self) -> List[Dict[float, float]]: """ Fitted mappings for each feature to being encoded. The dictionary contains a dictionary for each input column. 
From bb95c8d8df3b754399e44720894168fb0b9c132a Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Wed, 23 Oct 2024 19:45:45 +0200 Subject: [PATCH 6/9] [SPARK-37178][ML] passing raw stats to model, building encodings in transform() --- .../spark/ml/feature/TargetEncoder.scala | 106 +++--- .../ml/feature/JavaTargetEncoderSuite.java | 74 ++-- .../spark/ml/feature/TargetEncoderSuite.scala | 342 ++++++++++-------- python/pyspark/ml/feature.py | 18 +- python/pyspark/ml/tests/test_feature.py | 91 +++++ 5 files changed, 397 insertions(+), 234 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index dd76bfa286031..f7f848b12f17a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -69,8 +69,8 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol @Since("4.0.0") val smoothing: DoubleParam = new DoubleParam(this, "smoothing", "Smoothing factor for encodings. Smoothing blends in-class estimates with overall estimates " + - "according to the relative size of the particular class on the whole dataset, reducing the " + - "risk of overfitting due to unreliable estimates", + "according to the relative size of the particular class on the whole dataset, reducing the " + + "risk of overfitting due to unreliable estimates", ParamValidators.gtEq(0.0)) setDefault(smoothing -> 0.0) @@ -78,16 +78,16 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol final def getSmoothing: Double = $(smoothing) private[feature] lazy val inputFeatures = if (isSet(inputCol)) Array($(inputCol)) - else if (isSet(inputCols)) $(inputCols) - else Array.empty[String] + else if (isSet(inputCols)) $(inputCols) + else Array.empty[String] private[feature] lazy val outputFeatures = if (isSet(outputCol)) Array($(outputCol)) - else if (isSet(outputCols)) $(outputCols) - else inputFeatures.map{field: String => s"${field}_indexed"} + else if (isSet(outputCols)) $(outputCols) + else inputFeatures.map{field: String => s"${field}_indexed"} private[feature] def validateSchema( - schema: StructType, - fitting: Boolean): StructType = { + schema: StructType, + fitting: Boolean): StructType = { require(inputFeatures.length > 0, s"At least one input column must be specified.") @@ -185,6 +185,7 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) override def fit(dataset: Dataset[_]): TargetEncoderModel = { validateSchema(dataset.schema, fitting = true) + // stats: Array[Map[Some(category), (counter,stat)]] val stats = dataset .select((inputFeatures :+ $(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*) .rdd.treeAggregate( @@ -249,27 +250,9 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) } }.toArray) - // encodings: Array[Map[Some(category), encoding]] - val encodings: Array[Map[Option[Double], Double]] = - stats.map { - stat => - val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get - stat.map { - case (cat, (class_count, class_stat)) => cat -> { - val weight = class_count / (class_count + $(smoothing)) // smoothing weight - $(targetType) match { - case TargetEncoder.TARGET_BINARY => - // calculate conditional probabilities and blend - weight * (class_stat/ class_count) + (1 - weight) * (global_stat / global_count) - case TargetEncoder.TARGET_CONTINUOUS => - // blend means - weight * class_stat + (1 - weight) 
* global_stat - } - } - } - } - val model = new TargetEncoderModel(uid, encodings).setParent(this) + + val model = new TargetEncoderModel(uid, stats).setParent(this) copyValues(model) } @@ -297,13 +280,13 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { } /** - * @param encodings Original number of categories for each feature being encoded. - * The array contains one value for each input column, in order. + * @param stats Array of statistics for each input feature. + * Array( Map( Some(category), (counter, stat) ) ) */ @Since("4.0.0") class TargetEncoderModel private[ml] ( - @Since("4.0.0") override val uid: String, - @Since("4.0.0") val encodings: Array[Map[Option[Double], Double]]) + @Since("4.0.0") override val uid: String, + @Since("4.0.0") val stats: Array[Map[Option[Double], (Double, Double)]]) extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { /** @group setParam */ @@ -326,22 +309,46 @@ class TargetEncoderModel private[ml] ( @Since("4.0.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("4.0.0") + def setSmoothing(value: Double): this.type = set(smoothing, value) + @Since("4.0.0") override def transformSchema(schema: StructType): StructType = { - if (outputFeatures.length == encodings.length) { + if (outputFeatures.length == stats.length) { outputFeatures.filter(_ != null) .foldLeft(validateSchema(schema, fitting = false)) { case (newSchema, outputField) => newSchema.add(StructField(outputField, DoubleType, nullable = false)) } } else throw new SparkException("The number of features does not match the number of " + - s"encodings in the model (${encodings.length}). " + - s"Found ${outputFeatures.length} features)") + s"encodings in the model (${stats.length}). 
" + + s"Found ${outputFeatures.length} features)") } @Since("4.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - validateSchema(dataset.schema, fitting = false) + transformSchema(dataset.schema) + + // encodings: Array[Map[Some(category), encoding]] + val encodings: Array[Map[Option[Double], Double]] = + stats.map { + stat => + val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get + stat.map { + case (cat, (class_count, class_stat)) => cat -> { + val weight = class_count / (class_count + $(smoothing)) // smoothing weight + $(targetType) match { + case TargetEncoder.TARGET_BINARY => + // calculate conditional probabilities and blend + weight * (class_stat/ class_count) + (1 - weight) * (global_stat / global_count) + case TargetEncoder.TARGET_CONTINUOUS => + // blend means + weight * class_stat + (1 - weight) * global_stat + } + } + } + } // builds a column-to-column function from a map of encodings val apply_encodings: Map[Option[Double], Double] => (Column => Column) = @@ -380,24 +387,21 @@ class TargetEncoderModel private[ml] ( dataset.withColumns( inputFeatures.zip(outputFeatures).zip(encodings) - .filter{ - case ((featureIn, featureOut), _) => (featureIn != null) && (featureOut != null) - }.map { + .map { case ((featureIn, featureOut), mapping) => featureOut -> - apply_encodings(mapping)(col(featureIn)) - .as(featureOut, NominalAttribute.defaultAttr - .withName(featureOut) - .withNumValues(mapping.values.toSet.size) - .withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata()) + apply_encodings(mapping)(col(featureIn)) + .as(featureOut, NominalAttribute.defaultAttr + .withName(featureOut) + .withNumValues(mapping.values.toSet.size) + .withValues(mapping.values.toSet.toArray.map(_.toString)).toMetadata()) }.toMap) - } - + } @Since("4.0.0") override def copy(extra: ParamMap): TargetEncoderModel = { - val copied = new TargetEncoderModel(uid, encodings) + val copied = new TargetEncoderModel(uid, stats) copyValues(copied, extra).setParent(parent) } @@ -420,11 +424,11 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(encodings: Array[Map[Option[Double], Double]]) + private case class Data(stats: Array[Map[Option[Double], (Double, Double)]]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) - val data = Data(instance.encodings) + val data = Data(instance.stats) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).write.parquet(dataPath) } @@ -440,8 +444,8 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { val data = sparkSession.read.parquet(dataPath) .select("encodings") .head() - val encodings = data.getAs[Array[Map[Option[Double], Double]]](0) - val model = new TargetEncoderModel(metadata.uid, encodings) + val stats = data.getAs[Array[Map[Option[Double], (Double, Double)]]](0) + val model = new TargetEncoderModel(metadata.uid, stats) metadata.getAndSetParams(model) model } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java index 44e38543c515e..ae78780d1f067 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java @@ -36,16 +36,27 @@ public 
class JavaTargetEncoderSuite extends SharedSparkSession { @Test public void testTargetEncoderBinary() { + // checkstyle.off: LineLength List data = Arrays.asList( - RowFactory.create((short)0, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), - RowFactory.create((short)1, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3), - RowFactory.create((short)2, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3), - RowFactory.create((short)0, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), - RowFactory.create((short)1, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3), - RowFactory.create((short)2, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3), - RowFactory.create((short)0, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0), - RowFactory.create((short)1, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0), - RowFactory.create((short)2, 3, null, 0.0, 1.0/3, 0.0, 0.0)); + RowFactory.create((short)0, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)1, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), + (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)2, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)0, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)1, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), + (1-5.0/6)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)2, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + RowFactory.create((short)0, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9)), + RowFactory.create((short)1, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), + (4.0/5)*1+(1-4.0/5)*(4.0/9), (1.0/2)+(1-1.0/2)*(4.0/9)), + RowFactory.create((short)2, 3, null, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), + (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9))); + // checkstyle.off: LineLength StructType schema = createStructType(new StructField[]{ createStructField("input1", ShortType, true), createStructField("input2", IntegerType, true), @@ -53,8 +64,12 @@ public void testTargetEncoderBinary() { createStructField("label", DoubleType, false), createStructField("expected1", DoubleType, false), createStructField("expected2", DoubleType, false), - createStructField("expected3", DoubleType, false) + createStructField("expected3", DoubleType, false), + createStructField("smoothing1", DoubleType, false), + createStructField("smoothing2", DoubleType, false), + createStructField("smoothing3", DoubleType, false) }); + Dataset dataset = spark.createDataFrame(data, schema); TargetEncoder encoder = new TargetEncoder() @@ -62,27 +77,33 @@ public void testTargetEncoderBinary() { .setOutputCols(new String[]{"output1", "output2", "output3"}) .setTargetType("binary"); TargetEncoderModel model = encoder.fit(dataset); - Dataset output = model.transform(dataset); + Dataset output = model.transform(dataset); Assertions.assertEquals( output.select("output1", "output2", "output3").collectAsList(), output.select("expected1", "expected2", "expected3").collectAsList()); + Dataset output_smoothing = model.setSmoothing(1.0).transform(dataset); + Assertions.assertEquals( + output_smoothing.select("output1", "output2", "output3").collectAsList(), + output_smoothing.select("smoothing1", "smoothing2", 
"smoothing3").collectAsList()); + } @Test public void testTargetEncoderContinuous() { List data = Arrays.asList( - RowFactory.create((short)0, 3, 5.0, 10.0, 40.0, 50.0, 20.0), - RowFactory.create((short)1, 4, 5.0, 20.0, 50.0, 50.0, 20.0), - RowFactory.create((short)2, 3, 5.0, 30.0, 60.0, 50.0, 20.0), - RowFactory.create((short)0, 4, 6.0, 40.0, 40.0, 50.0, 50.0), - RowFactory.create((short)1, 3, 6.0, 50.0, 50.0, 50.0, 50.0), - RowFactory.create((short)2, 4, 6.0, 60.0, 60.0, 50.0, 50.0), - RowFactory.create((short)0, 3, 7.0, 70.0, 40.0, 50.0, 70.0), - RowFactory.create((short)1, 4, 8.0, 80.0, 50.0, 50.0, 80.0), - RowFactory.create((short)2, 3, null, 90.0, 60.0, 50.0, 90.0)); + RowFactory.create((short)0, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + RowFactory.create((short)1, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + RowFactory.create((short)2, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + RowFactory.create((short)0, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + RowFactory.create((short)1, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + RowFactory.create((short)2, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + RowFactory.create((short)0, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + RowFactory.create((short)1, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + RowFactory.create((short)2, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)); + StructType schema = createStructType(new StructField[]{ createStructField("input1", ShortType, true), createStructField("input2", IntegerType, true), @@ -90,8 +111,12 @@ public void testTargetEncoderContinuous() { createStructField("label", DoubleType, false), createStructField("expected1", DoubleType, false), createStructField("expected2", DoubleType, false), - createStructField("expected3", DoubleType, false) + createStructField("expected3", DoubleType, false), + createStructField("smoothing1", DoubleType, false), + createStructField("smoothing2", DoubleType, false), + createStructField("smoothing3", DoubleType, false) }); + Dataset dataset = spark.createDataFrame(data, schema); TargetEncoder encoder = new TargetEncoder() @@ -99,12 +124,17 @@ public void testTargetEncoderContinuous() { .setOutputCols(new String[]{"output1", "output2", "output3"}) .setTargetType("continuous"); TargetEncoderModel model = encoder.fit(dataset); - Dataset output = model.transform(dataset); + Dataset output = model.transform(dataset); Assertions.assertEquals( output.select("output1", "output2", "output3").collectAsList(), output.select("expected1", "expected2", "expected3").collectAsList()); + Dataset output_smoothing = model.setSmoothing(1.0).transform(dataset); + Assertions.assertEquals( + output_smoothing.select("output1", "output2", "output3").collectAsList(), + output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList()); + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index aa4155ae4d658..c7560eaedc44a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -30,40 +30,48 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ - @transient var data: Seq[Row] = _ + @transient var data_binary: Seq[Row] = _ + @transient var data_continuous: Seq[Row] = _ @transient var schema: StructType = _ override def beforeAll(): Unit = { 
super.beforeAll() // scalastyle:off - data = Seq( - Row(0.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), - Row(1.toShort, 4, 5.0, 1.toByte, 2.0/3, 1.0, 1.0/3, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), - Row(2.toShort, 3, 5.0, 0.toByte, 1.0/3, 0.0, 1.0/3, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), - Row(0.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), - Row(1.toShort, 3, 6.0, 0.toByte, 2.0/3, 0.0, 2.0/3, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), - Row(2.toShort, 4, 6.0, 1.toByte, 1.0/3, 1.0, 2.0/3, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), - Row(0.toShort, 3, 7.0, 0.toByte, 1.0/3, 0.0, 0.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), - Row(1.toShort, 4, 8.0, 1.toByte, 2.0/3, 1.0, 1.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), - Row(2.toShort, 3, 9.0, 0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) + data_binary = Seq( + Row(0.toShort, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + Row(1.toShort, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + Row(2.toShort, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), + Row(0.toShort, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + Row(1.toShort, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + Row(2.toShort, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), + Row(0.toShort, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9)), + Row(1.toShort, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), (4.0/5)*1+(1-4.0/5)*(4.0/9), (1.0/2) +(1-1.0/2)*(4.0/9)), + Row(2.toShort, 3, 9.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9))) + + data_continuous = Seq( + Row(0.toShort, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + Row(1.toShort, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + Row(2.toShort, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + Row(0.toShort, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + Row(1.toShort, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + Row(2.toShort, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + Row(0.toShort, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + Row(1.toShort, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + Row(2.toShort, 3, 9.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) // scalastyle:on schema = StructType(Array( StructField("input1", ShortType, nullable = true), StructField("input2", IntegerType, nullable = true), StructField("input3", DoubleType, nullable = true), - StructField("binaryLabel", ByteType), - StructField("binaryExpected1", DoubleType), - StructField("binaryExpected2", DoubleType), - StructField("binaryExpected3", DoubleType), - StructField("continuousLabel", DoubleType), - StructField("continuousExpected1", DoubleType), - StructField("continuousExpected2", DoubleType), - StructField("continuousExpected3", DoubleType), - StructField("smoothingExpected1", DoubleType), - StructField("smoothingExpected2", DoubleType), - StructField("smoothingExpected3", DoubleType))) + 
StructField("label", DoubleType), + StructField("expected1", DoubleType), + StructField("expected2", DoubleType), + StructField("expected3", DoubleType), + StructField("smoothing1", DoubleType), + StructField("smoothing2", DoubleType), + StructField("smoothing3", DoubleType))) } test("params") { @@ -72,33 +80,34 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - binary target") { - val df = spark.createDataFrame(sc.parallelize(data), schema) + val df = spark.createDataFrame(sc.parallelize(data_binary), schema) val encoder = new TargetEncoder() - .setLabelCol("binaryLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_BINARY) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) val model = encoder.fit(df) - val expected_encodings = Array( - Map(Some(0.0) -> 1.0/3, Some(1.0) -> 2.0/3, Some(2.0) -> 1.0/3, Some(-1.0) -> 4.0/9), - Map(Some(3.0) -> 0.0, Some(4.0) -> 1.0, Some(-1.0) -> 4.0/9), - HashMap(Some(5.0) -> 1.0/3, Some(6.0) -> 2.0/3, Some(7.0) -> 0.0, - Some(8.0) -> 1.0, Some(9.0) -> 0.0, Some(-1.0) -> 4.0/9)) + val expected_stats = Array( + Map(Some(0.0) -> (3.0, 1.0), Some(1.0) -> (3.0, 2.0), Some(2.0) -> (3.0, 1.0), + Some(-1.0) -> (9.0, 4.0)), + Map(Some(3.0) -> (5.0, 0.0), Some(4.0) -> (4.0, 4.0), Some(-1.0) -> (9.0, 4.0)), + HashMap(Some(5.0) -> (3.0, 1.0), Some(6.0) -> (3.0, 2.0), Some(7.0) -> (1.0, 0.0), + Some(8.0) -> (1.0, 1.0), Some(9.0) -> (1.0, 0.0), Some(-1.0) -> (9.0, 4.0))) - model.encodings.zip(expected_encodings).foreach{ - case (actual, expected) => actual.equals(expected) + model.stats.zip(expected_stats).foreach{ + case (actual, expected) => assert(actual.equals(expected)) } testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", - "binaryExpected1", "binaryExpected2", "binaryExpected3"), + "expected1", "expected2", "expected3"), model, - "output1", "binaryExpected1", - "output2", "binaryExpected2", - "output3", "binaryExpected3") { + "output1", "expected1", + "output2", "expected2", + "output3", "expected3") { case Row(output1: Double, expected1: Double, output2: Double, expected2: Double, output3: Double, expected3: Double) => @@ -107,38 +116,57 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { assert(output3 === expected3) } + val model_smooth = model.setSmoothing(1.0) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df.select("input1", "input2", "input3", + "smoothing1", "smoothing2", "smoothing3"), + model_smooth, + "output1", "smoothing1", + "output2", "smoothing2", + "output3", "smoothing3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } + + } test("TargetEncoder - continuous target") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_continuous), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) val model = encoder.fit(df) - val expected_encodings = Array( - Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - HashMap(Some(5.0) -> 
20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, - Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) + val expected_stats = Array( + Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), + Some(-1.0) -> (9.0, 50.0)), + Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), + HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), + Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) - model.encodings.zip(expected_encodings).foreach{ - case (actual, expected) => actual.equals(expected) + model.stats.zip(expected_stats).foreach{ + case (actual, expected) => assert(actual.equals(expected)) } testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", - "continuousExpected1", "continuousExpected2", "continuousExpected3"), + "expected1", "expected2", "expected3"), model, - "output1", "continuousExpected1", - "output2", "continuousExpected2", - "output3", "continuousExpected3") { + "output1", "expected1", + "output2", "expected2", + "output3", "expected3") { case Row(output1: Double, expected1: Double, output2: Double, expected2: Double, output3: Double, expected3: Double) => @@ -147,39 +175,15 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { assert(output3 === expected3) } - } - - test("TargetEncoder - smoothing") { - - val df = spark - .createDataFrame(sc.parallelize(data), schema) - - val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") - .setTargetType(TargetEncoder.TARGET_CONTINUOUS) - .setInputCols(Array("input1", "input2", "input3")) - .setOutputCols(Array("output1", "output2", "output3")) - .setSmoothing(1) - - val model = encoder.fit(df) - - val expected_encodings = Array( - Map(Some(0.0) -> 42.5, Some(1.0) -> 50.0, Some(2.0) -> 57.5, Some(-1.0) -> 50.0), - Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - HashMap(Some(5.0) -> 27.5, Some(6.0) -> 50.0, Some(7.0) -> 60.0, - Some(8.0) -> 65.0, Some(9.0) -> 70.0, Some(-1.0) -> 50.0)) - - model.encodings.zip(expected_encodings).foreach{ - case (actual, expected) => actual.equals(expected) - } + val model_smooth = model.setSmoothing(1.0) testTransformer[(Double, Double, Double, Double, Double, Double)]( df.select("input1", "input2", "input3", - "smoothingExpected1", "smoothingExpected2", "smoothingExpected3"), - model, - "output1", "smoothingExpected1", - "output2", "smoothingExpected2", - "output3", "smoothingExpected3") { + "smoothing1", "smoothing2", "smoothing3"), + model_smooth, + "output1", "smoothing1", + "output2", "smoothing2", + "output3", "smoothing3") { case Row(output1: Double, expected1: Double, output2: Double, expected2: Double, output3: Double, expected3: Double) => @@ -193,10 +197,10 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - unseen value - keep") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_continuous), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setHandleInvalid(TargetEncoder.KEEP_INVALID) .setInputCols(Array("input1", "input2", "input3")) @@ -204,19 +208,18 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val data_unseen = Row(0.toShort, 3, 10.0, - 0.toByte, 0.0, 0.0, 0.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) + val data_unseen = Row(0.toShort, 
3, 10.0, 0.0, 40.0, 50.0, 50.0, 0.0, 0.0, 0.0) val df_unseen = spark - .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + .createDataFrame(sc.parallelize(data_continuous :+ data_unseen), schema) testTransformer[(Double, Double, Double, Double, Double, Double)]( df_unseen.select("input1", "input2", "input3", - "continuousExpected1", "continuousExpected2", "continuousExpected3"), + "expected1", "expected2", "expected3"), model, - "output1", "continuousExpected1", - "output2", "continuousExpected2", - "output3", "continuousExpected3") { + "output1", "expected1", + "output2", "expected2", + "output3", "expected3") { case Row(output1: Double, expected1: Double, output2: Double, expected2: Double, output3: Double, expected3: Double) => @@ -229,10 +232,10 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - unseen value - error") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_continuous), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setHandleInvalid(TargetEncoder.ERROR_INVALID) .setInputCols(Array("input1", "input2", "input3")) @@ -240,11 +243,10 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val data_unseen = Row(0.toShort, 3, 10.0, - 0.toByte, 0.0, 0.0, 0.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) + val data_unseen = Row(0.toShort, 3, 10.0, 0.0, 4.0/9, 4.0/9, 4.0/9, 0.0, 0.0, 0.0) val df_unseen = spark - .createDataFrame(sc.parallelize(data :+ data_unseen), schema) + .createDataFrame(sc.parallelize(data_continuous :+ data_unseen), schema) val ex = intercept[SparkRuntimeException] { val out = model.transform(df_unseen) @@ -259,17 +261,17 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - missing feature") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_binary), schema) val encoder = new TargetEncoder() - .setLabelCol("binaryLabel") + .setLabelCol("label") .setInputCols(Array("input1", "input2", "input3")) .setTargetType(TargetEncoder.TARGET_BINARY) .setOutputCols(Array("output1", "output2", "output3")) val ex = intercept[SparkException] { val model = encoder.fit(df.drop("input3")) - print(model.encodings) + print(model.stats) } assert(ex.isInstanceOf[SparkException]) @@ -281,22 +283,21 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val wrong_schema = new StructType( schema.map{ field: StructField => if (field.name != "input3") field - else new StructField(field.name, StringType, field.nullable, field.metadata) + else StructField(field.name, StringType, field.nullable, field.metadata) }.toArray) val df = spark - .createDataFrame(sc.parallelize(data), wrong_schema) - .drop("continuousLabel") + .createDataFrame(sc.parallelize(data_binary), wrong_schema) val encoder = new TargetEncoder() - .setLabelCol("binaryLabel") + .setLabelCol("label") .setInputCols(Array("input1", "input2", "input3")) .setTargetType(TargetEncoder.TARGET_BINARY) .setOutputCols(Array("output1", "output2", "output3")) val ex = intercept[SparkException] { val model = encoder.fit(df) - print(model.encodings) + print(model.stats) } assert(ex.isInstanceOf[SparkException]) @@ -305,85 +306,83 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - seen null category") { - val data_null = Row(2.toShort, 3, null, - 
0.toByte, 1.0/3, 0.0, 0.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) + val data_null = Row(2.toShort, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) val df_null = spark - .createDataFrame(sc.parallelize(data.dropRight(1) :+ data_null), schema) + .createDataFrame(sc.parallelize(data_continuous.dropRight(1) :+ data_null), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) val model = encoder.fit(df_null) - val expected_encodings = Array( - Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, - Some(8.0) -> 80.0, None -> 90.0, Some(-1.0) -> 50.0)) + val expected_stats = Array( + Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), + Some(-1.0) -> (9.0, 50.0)), + Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), + HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), + Some(8.0) -> (1.0, 80.0), None -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) - model.encodings.zip(expected_encodings).foreach{ - case (actual, expected) => actual.equals(expected) + model.stats.zip(expected_stats).foreach{ + case (actual, expected) => assert(actual.equals(expected)) } val output = model.transform(df_null) assert_true( - output("output1") === output("continuousExpected1") && - output("output2") === output("continuousExpected2") && - output("output3") === output("continuousExpected3")) + output("output1") === output("expected1") && + output("output2") === output("expected2") && + output("output3") === output("expected3")) } test("TargetEncoder - unseen null category") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_continuous), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setHandleInvalid(TargetEncoder.KEEP_INVALID) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) - val data_null = Row(null, null, null, - 0.toByte, 1.0/3, 0.0, 0.0, 90.0, 50.0, 50.0, 50.0, 57.5, 50.0, 70.0) + val data_null = Row(null, null, null, 90.0, 50.0, 50.0, 50.0, 57.5, 50.0, 70.0) val df_null = spark - .createDataFrame(sc.parallelize(data :+ data_null), schema) + .createDataFrame(sc.parallelize(data_continuous :+ data_null), schema) val model = encoder.fit(df) val output = model.transform(df_null) assert_true( - output("output1") === output("continuousExpected1") && - output("output2") === output("continuousExpected2") && - output("output3") === output("continuousExpected3")) + output("output1") === output("expected1") && + output("output2") === output("expected2") && + output("output3") === output("expected3")) } test("TargetEncoder - non-indexed categories") { val encoder = new TargetEncoder() - .setLabelCol("binaryLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_BINARY) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) - val data_noindex = Row( - 0.toShort, 3, 5.1, 0.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + val data_noindex = Row(0.toShort, 3, 5.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 
0.0) val df_noindex = spark - .createDataFrame(sc.parallelize(data :+ data_noindex), schema) + .createDataFrame(sc.parallelize(data_binary :+ data_noindex), schema) val ex = intercept[SparkException] { val model = encoder.fit(df_noindex) - print(model.encodings) + print(model.stats) } assert(ex.isInstanceOf[SparkException]) @@ -394,28 +393,28 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - null label") { - val data_nolabel = Row(2.toShort, 3, 5.0, - null, 1.0/3, 0.0, 0.0, null, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) + val data_nolabel = Row(2.toShort, 3, 5.0, null, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) val df_nolabel = spark - .createDataFrame(sc.parallelize(data :+ data_nolabel), schema) + .createDataFrame(sc.parallelize(data_continuous :+ data_nolabel), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) val model = encoder.fit(df_nolabel) - val expected_encodings = Array( - Map(Some(0.0) -> 40.0, Some(1.0) -> 50.0, Some(2.0) -> 60.0, Some(-1.0) -> 50.0), - Map(Some(3.0) -> 50.0, Some(4.0) -> 50.0, Some(-1.0) -> 50.0), - HashMap(Some(5.0) -> 20.0, Some(6.0) -> 50.0, Some(7.0) -> 70.0, - Some(8.0) -> 80.0, Some(9.0) -> 90.0, Some(-1.0) -> 50.0)) + val expected_stats = Array( + Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), + Some(-1.0) -> (9.0, 50.0)), + Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), + HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), + Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) - model.encodings.zip(expected_encodings).foreach{ - case (actual, expected) => actual.equals(expected) + model.stats.zip(expected_stats).foreach{ + case (actual, expected) => assert(actual.equals(expected)) } } @@ -423,49 +422,88 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { test("TargetEncoder - non-binary labels") { val encoder = new TargetEncoder() - .setLabelCol("binaryLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_BINARY) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) - val data_non_binary = Row( - 0.toShort, 3, 5.0, 2.toByte, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) + val data_non_binary = Row(0.toShort, 3, 5.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) val df_non_binary = spark - .createDataFrame(sc.parallelize(data :+ data_non_binary), schema) + .createDataFrame(sc.parallelize(data_binary :+ data_non_binary), schema) val ex = intercept[SparkException] { val model = encoder.fit(df_non_binary) - print(model.encodings.mkString) + print(model.stats) } assert(ex.isInstanceOf[SparkException]) assert(ex.getMessage.contains( - "Values from column binaryLabel must be binary (0,1) but got 2.0")) + "Values from column label must be binary (0,1) but got 2.0")) } test("TargetEncoder - features renamed") { val df = spark - .createDataFrame(sc.parallelize(data), schema) + .createDataFrame(sc.parallelize(data_continuous), schema) val encoder = new TargetEncoder() - .setLabelCol("continuousLabel") + .setLabelCol("label") .setTargetType(TargetEncoder.TARGET_CONTINUOUS) .setInputCols(Array("input1", "input2", "input3")) .setOutputCols(Array("output1", "output2", "output3")) val model = encoder.fit(df) - 
.setInputCols(Array("input1", "renamed_input2", null)) - .setOutputCols(Array(null, "renamed_output2", "renamed_output3")) + .setInputCols(Array("renamed_input1", "renamed_input2", "renamed_input3")) + .setOutputCols(Array("renamed_output1", "renamed_output2", "renamed_output3")) val df_renamed = df - .drop("input1", "input3") - .withColumnRenamed("input2", "renamed_input2") + .withColumnsRenamed((1 to 3).map{ + f => s"input${f}" -> s"renamed_input${f}"}.toMap) + + testTransformer[(Double, Double, Double, Double, Double, Double)]( + df_renamed + .select("renamed_input1", "renamed_input2", "renamed_input3", + "expected1", "expected2", "expected3"), + model, + "renamed_output1", "expected1", + "renamed_output2", "expected2", + "renamed_output3", "expected3") { + case Row(output1: Double, expected1: Double, + output2: Double, expected2: Double, + output3: Double, expected3: Double) => + assert(output1 === expected1) + assert(output2 === expected2) + assert(output3 === expected3) + } - val output = model.transform(df_renamed) - assert(output.filter(col("renamed_output2") === col("continuousExpected2")).count() ==0) + } + + test("TargetEncoder - wrong number of features") { + + val df = spark + .createDataFrame(sc.parallelize(data_binary), schema) + + val encoder = new TargetEncoder() + .setLabelCol("label") + .setTargetType(TargetEncoder.TARGET_BINARY) + .setInputCols(Array("input1", "input2", "input3")) + .setOutputCols(Array("output1", "output2", "output3")) + + val model = encoder.fit(df) + + val ex = intercept[SparkException] { + val output = model + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + .transform(df) + output.show() + } + + assert(ex.isInstanceOf[SparkException]) + assert(ex.getMessage.contains( + "does not match the number of encodings in the model (3). Found 2 features")) } diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c409a330cd8a8..29f49d51ee36d 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -5380,6 +5380,13 @@ def setParams( kwargs = self._input_kwargs return self._set(**kwargs) + @since("4.0.0") + def setLabelCol(self, value: str) -> "TargetEncoder": + """ + Sets the value of :py:attr:`labelCol`. + """ + return self._set(labelCol=value) + @since("4.0.0") def setInputCols(self, value: List[str]) -> "TargetEncoder": """ @@ -5477,13 +5484,6 @@ def setHandleInvalid(self, value: str) -> "TargetEncoderModel": """ return self._set(handleInvalid=value) - @since("4.0.0") - def setTargetType(self, value: str) -> "TargetEncoderModel": - """ - Sets the value of :py:attr:`targetType`. - """ - return self._set(targetType=value) - @since("4.0.0") def setSmoothing(self, value: float) -> "TargetEncoderModel": """ @@ -5493,12 +5493,12 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel": @property @since("4.0.0") - def encodings(self) -> List[Dict[float, float]]: + def stats(self) -> List[Dict[float, float]]: """ Fitted mappings for each feature to being encoded. The dictionary contains a dictionary for each input column. 
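        Each inner dictionary maps a category to the (count, statistic) pair
        accumulated for that category during fitting.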
""" - return self._call_java("encodings") + return self._call_java("stats") @inherit_doc diff --git a/python/pyspark/ml/tests/test_feature.py b/python/pyspark/ml/tests/test_feature.py index 666ed1c4269e1..92919adecd069 100644 --- a/python/pyspark/ml/tests/test_feature.py +++ b/python/pyspark/ml/tests/test_feature.py @@ -383,6 +383,83 @@ def test_target_encoder_binary(self): Row(input1=2, input2=3, input3=9.0, output1=1.0 / 3, output2=0.0, output3=0.0), ] self.assertEqual(actual, expected) + te = model.setSmoothing(1.0).transform(df) + actual = te.drop("label").collect() + expected = [ + Row( + input1=0, + input2=3, + input3=5.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(1 - 5 / 6) * (4 / 9), + output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=1, + input2=4, + input3=5.0, + output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), + output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=2, + input2=3, + input3=5.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(1 - 5 / 6) * (4 / 9), + output3=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=0, + input2=4, + input3=6.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), + output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=1, + input2=3, + input3=6.0, + output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(1 - 5 / 6) * (4 / 9), + output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=2, + input2=4, + input3=6.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), + output3=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + ), + Row( + input1=0, + input2=3, + input3=7.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(1 - 5 / 6) * (4 / 9), + output3=(1 - 1 / 2) * (4 / 9), + ), + Row( + input1=1, + input2=4, + input3=8.0, + output1=(3 / 4) * (2 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(4 / 5) * 1 + (1 - 4 / 5) * (4 / 9), + output3=(1 / 2) + (1 - 1 / 2) * (4 / 9), + ), + Row( + input1=2, + input2=3, + input3=9.0, + output1=(3 / 4) * (1 / 3) + (1 - 3 / 4) * (4 / 9), + output2=(1 - 5 / 6) * (4 / 9), + output3=(1 - 1 / 2) * (4 / 9), + ), + ] + self.assertEqual(actual, expected) def test_target_encoder_continuous(self): df = self.spark.createDataFrame( @@ -420,6 +497,20 @@ def test_target_encoder_continuous(self): Row(input1=2, input2=3, input3=9.0, output1=60.0, output2=50.0, output3=90.0), ] self.assertEqual(actual, expected) + te = model.setSmoothing(1.0).transform(df) + actual = te.drop("label").collect() + expected = [ + Row(input1=0, input2=3, input3=5.0, output1=42.5, output2=50.0, output3=27.5), + Row(input1=1, input2=4, input3=5.0, output1=50.0, output2=50.0, output3=27.5), + Row(input1=2, input2=3, input3=5.0, output1=57.5, output2=50.0, output3=27.5), + Row(input1=0, input2=4, input3=6.0, output1=42.5, output2=50.0, output3=50.0), + Row(input1=1, input2=3, input3=6.0, output1=50.0, output2=50.0, output3=50.0), + Row(input1=2, input2=4, input3=6.0, output1=57.5, output2=50.0, output3=50.0), + Row(input1=0, input2=3, input3=7.0, output1=42.5, output2=50.0, output3=60.0), + Row(input1=1, input2=4, input3=8.0, output1=50.0, output2=50.0, output3=65.0), + Row(input1=2, input2=3, input3=9.0, output1=57.5, output2=50.0, output3=70.0), + ] + self.assertEqual(actual, expected) def test_vector_size_hint(self): df = self.spark.createDataFrame( From 
32adb85313cbbd65e8547337dbbdd8deeaee5f21 Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Mon, 28 Oct 2024 20:44:30 +0100 Subject: [PATCH 7/9] [SPARK-37178][ML] disregard NaN-labeled observations --- .../spark/ml/feature/TargetEncoder.scala | 98 +++++++++++-------- .../spark/ml/feature/TargetEncoderSuite.scala | 57 +++++------ 2 files changed, 79 insertions(+), 76 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index f7f848b12f17a..575c535f6ada2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -77,13 +77,21 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol final def getSmoothing: Double = $(smoothing) - private[feature] lazy val inputFeatures = if (isSet(inputCol)) Array($(inputCol)) - else if (isSet(inputCols)) $(inputCols) - else Array.empty[String] - - private[feature] lazy val outputFeatures = if (isSet(outputCol)) Array($(outputCol)) - else if (isSet(outputCols)) $(outputCols) - else inputFeatures.map{field: String => s"${field}_indexed"} + private[feature] lazy val inputFeatures = if (isSet(inputCol)) { + Array($(inputCol)) + } else if (isSet(inputCols)) { + $(inputCols) + } else { + Array.empty[String] + } + + private[feature] lazy val outputFeatures = if (isSet(outputCol)) { + Array($(outputCol)) + } else if (isSet(outputCols)) { + $(outputCols) + } else { + inputFeatures.map{field: String => s"${field}_indexed"} + } private[feature] def validateSchema( schema: StructType, @@ -192,48 +200,52 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) Array.fill(inputFeatures.length) { Map.empty[Option[Double], (Double, Double)] })( + (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) { val label = row.getDouble(inputFeatures.length) - inputFeatures.indices.map { - feature => { - val category: Option[Double] = { - if (row.isNullAt(feature)) None // null category - else { - val value = row.getDouble(feature) - if (value < 0.0 || value != value.toInt) throw new SparkException( - s"Values from column ${inputFeatures(feature)} must be indices, " + - s"but got $value.") - else Some(value) // non-null category + if (!label.equals(Double.NaN)) { + inputFeatures.indices.map { + feature => { + val category: Option[Double] = { + if (row.isNullAt(feature)) None // null category + else { + val value = row.getDouble(feature) + if (value < 0.0 || value != value.toInt) throw new SparkException( + s"Values from column ${inputFeatures(feature)} must be indices, " + + s"but got $value.") + else Some(value) // non-null category + } } - } - val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0)) - val (global_count, global_stat) = - agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0)) - $(targetType) match { - case TargetEncoder.TARGET_BINARY => // counting - if (label == 1.0) { - // positive => increment both counters for current & unseen categories - agg(feature) + - (category -> (1 + class_count, 1 + class_stat)) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat)) - } else if (label == 0.0) { - // negative => increment only global counter for current & unseen categories + val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0)) + val (global_count, global_stat) = + agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0)) + 
$(targetType) match { + case TargetEncoder.TARGET_BINARY => // counting + if (label == 1.0) { + // positive => increment both counters for current & unseen categories + agg(feature) + + (category -> (1 + class_count, 1 + class_stat)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, 1 + global_stat)) + } else if (label == 0.0) { + // negative => increment only global counter for current & unseen categories + agg(feature) + + (category -> (1 + class_count, class_stat)) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat)) + } else throw new SparkException( + s"Values from column ${getLabelCol} must be binary (0,1) but got $label.") + case TargetEncoder.TARGET_CONTINUOUS => // incremental mean + // increment counter and iterate on mean for current & unseen categories agg(feature) + - (category -> (1 + class_count, class_stat)) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, global_stat)) - } else throw new SparkException( - s"Values from column ${getLabelCol} must be binary (0,1) but got $label.") - case TargetEncoder.TARGET_CONTINUOUS => // incremental mean - // increment counter and iterate on mean for current & unseen categories - agg(feature) + - (category -> (1 + class_count, - class_stat + ((label - class_stat) / (1 + class_count)))) + - (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, - global_stat + ((label - global_stat) / (1 + global_count)))) + (category -> (1 + class_count, + class_stat + ((label - class_stat) / (1 + class_count)))) + + (TargetEncoder.UNSEEN_CATEGORY -> (1 + global_count, + global_stat + ((label - global_stat) / (1 + global_count)))) + } } - } - }.toArray + }.toArray + } else agg // ignore NaN-labeled observations } else agg, // ignore null-labeled observations + (agg1, agg2) => inputFeatures.indices.map { feature => { val categories = agg1(feature).keySet ++ agg2(feature).keySet diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index c7560eaedc44a..f53cda625f155 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -33,6 +33,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { @transient var data_binary: Seq[Row] = _ @transient var data_continuous: Seq[Row] = _ @transient var schema: StructType = _ + @transient var expected_stats_binary: Array[Map[Option[Double], (Double, Double)]] = _ + @transient var expected_stats_continuous: Array[Map[Option[Double], (Double, Double)]] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -59,7 +61,6 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { Row(0.toShort, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), Row(1.toShort, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), Row(2.toShort, 3, 9.0, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)) - // scalastyle:on schema = StructType(Array( StructField("input1", ShortType, nullable = true), @@ -72,6 +73,17 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { StructField("smoothing1", DoubleType), StructField("smoothing2", DoubleType), StructField("smoothing3", DoubleType))) + + expected_stats_binary = Array( + Map(Some(0.0) -> (3.0, 1.0), Some(1.0) -> (3.0, 2.0), Some(2.0) -> (3.0, 1.0), Some(-1.0) -> (9.0, 4.0)), + Map(Some(3.0) -> (5.0, 0.0), Some(4.0) -> (4.0, 4.0), Some(-1.0) -> (9.0, 4.0)), + HashMap(Some(5.0) -> (3.0, 1.0), Some(6.0) -> (3.0, 
2.0), Some(7.0) -> (1.0, 0.0), Some(8.0) -> (1.0, 1.0), Some(9.0) -> (1.0, 0.0), Some(-1.0) -> (9.0, 4.0))) + + expected_stats_continuous = Array( + Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), Some(-1.0) -> (9.0, 50.0)), + Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), + HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) + // scalastyle:on } test("params") { @@ -90,14 +102,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val expected_stats = Array( - Map(Some(0.0) -> (3.0, 1.0), Some(1.0) -> (3.0, 2.0), Some(2.0) -> (3.0, 1.0), - Some(-1.0) -> (9.0, 4.0)), - Map(Some(3.0) -> (5.0, 0.0), Some(4.0) -> (4.0, 4.0), Some(-1.0) -> (9.0, 4.0)), - HashMap(Some(5.0) -> (3.0, 1.0), Some(6.0) -> (3.0, 2.0), Some(7.0) -> (1.0, 0.0), - Some(8.0) -> (1.0, 1.0), Some(9.0) -> (1.0, 0.0), Some(-1.0) -> (9.0, 4.0))) - - model.stats.zip(expected_stats).foreach{ + model.stats.zip(expected_stats_binary).foreach{ case (actual, expected) => assert(actual.equals(expected)) } @@ -149,14 +154,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df) - val expected_stats = Array( - Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), - Some(-1.0) -> (9.0, 50.0)), - Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), - HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), - Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) - - model.stats.zip(expected_stats).foreach{ + model.stats.zip(expected_stats_continuous).foreach{ case (actual, expected) => assert(actual.equals(expected)) } @@ -320,11 +318,9 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df_null) val expected_stats = Array( - Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), - Some(-1.0) -> (9.0, 50.0)), - Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), - HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), - Some(8.0) -> (1.0, 80.0), None -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) + expected_stats_continuous(0), + expected_stats_continuous(1), + expected_stats_continuous(2) - Some(9.0) + (None -> (1.0, 90.0))) model.stats.zip(expected_stats).foreach{ case (actual, expected) => assert(actual.equals(expected)) @@ -391,12 +387,14 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { } - test("TargetEncoder - null label") { + test("TargetEncoder - invalid label") { - val data_nolabel = Row(2.toShort, 3, 5.0, null, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0) + val data_null = Row(2.toShort, 3, 5.0, null, 160.0, 150.0, 190.0, 57.5, 50.0, 70.0) + val data_nan = Row(1.toShort, 2, 6.0, Double.NaN, 160.0, 150.0, 190.0, 57.5, 50.0, 70.0) val df_nolabel = spark - .createDataFrame(sc.parallelize(data_continuous :+ data_nolabel), schema) + .createDataFrame(sc.parallelize( + data_continuous :+ data_null :+ data_nan), schema) val encoder = new TargetEncoder() .setLabelCol("label") @@ -406,14 +404,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val model = encoder.fit(df_nolabel) - val expected_stats = Array( - Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), - 
Some(-1.0) -> (9.0, 50.0)), - Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), - HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), - Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) - - model.stats.zip(expected_stats).foreach{ + model.stats.zip(expected_stats_continuous).foreach{ case (actual, expected) => assert(actual.equals(expected)) } From 7ca04f34626988cc823003f7ce81330d97245379 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 5 Nov 2024 14:09:40 -0800 Subject: [PATCH 8/9] nits --- .../examples/ml/JavaTargetEncoderExample.java | 8 +- .../spark/ml/feature/TargetEncoder.scala | 38 +++--- .../ml/feature/JavaTargetEncoderSuite.java | 122 +++++++++--------- 3 files changed, 84 insertions(+), 84 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java index da391bd469192..460f0d5a51e69 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTargetEncoderExample.java @@ -74,10 +74,10 @@ public static void main(String[] args) { // continuous target TargetEncoder cont_encoder = new TargetEncoder() - .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) - .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) - .setLabelCol("continuousLabel") - .setTargetType("continuous"); + .setInputCols(new String[] {"categoryIndex1", "categoryIndex2"}) + .setOutputCols(new String[] {"categoryIndex1Target", "categoryIndex2Target"}) + .setLabelCol("continuousLabel") + .setTargetType("continuous"); TargetEncoderModel cont_model = cont_encoder.fit(df); Dataset cont_encoded = cont_model.transform(df); diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 575c535f6ada2..85f4a0d566f6e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -77,25 +77,25 @@ private[ml] trait TargetEncoderBase extends Params with HasLabelCol final def getSmoothing: Double = $(smoothing) - private[feature] lazy val inputFeatures = if (isSet(inputCol)) { - Array($(inputCol)) - } else if (isSet(inputCols)) { - $(inputCols) - } else { - Array.empty[String] - } - - private[feature] lazy val outputFeatures = if (isSet(outputCol)) { - Array($(outputCol)) - } else if (isSet(outputCols)) { - $(outputCols) - } else { - inputFeatures.map{field: String => s"${field}_indexed"} - } - - private[feature] def validateSchema( - schema: StructType, - fitting: Boolean): StructType = { + private[feature] lazy val inputFeatures = + if (isSet(inputCol)) { + Array($(inputCol)) + } else if (isSet(inputCols)) { + $(inputCols) + } else { + Array.empty[String] + } + + private[feature] lazy val outputFeatures = + if (isSet(outputCol)) { + Array($(outputCol)) + } else if (isSet(outputCols)) { + $(outputCols) + } else { + inputFeatures.map{field: String => s"${field}_indexed"} + } + + private[feature] def validateSchema(schema: StructType, fitting: Boolean): StructType = { require(inputFeatures.length > 0, s"At least one input column must be specified.") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java 
b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java index ae78780d1f067..8044d3a1cb4df 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java @@ -38,55 +38,55 @@ public void testTargetEncoderBinary() { // checkstyle.off: LineLength List data = Arrays.asList( - RowFactory.create((short)0, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)1, 4, 5.0, 1.0, 2.0/3, 1.0, 1.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), - (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)2, 3, 5.0, 0.0, 1.0/3, 0.0, 1.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (1-5.0/6)*(4.0/9), (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)0, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)1, 3, 6.0, 0.0, 2.0/3, 0.0, 2.0/3, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), - (1-5.0/6)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)2, 4, 6.0, 1.0, 1.0/3, 1.0, 2.0/3, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (4.0/5)*1+(1-4.0/5)*(4.0/9), (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9)), - RowFactory.create((short)0, 3, 7.0, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9)), - RowFactory.create((short)1, 4, 8.0, 1.0, 2.0/3, 1.0, 1.0, (3.0/4)*(2.0/3)+(1-3.0/4)*(4.0/9), - (4.0/5)*1+(1-4.0/5)*(4.0/9), (1.0/2)+(1-1.0/2)*(4.0/9)), - RowFactory.create((short)2, 3, null, 0.0, 1.0/3, 0.0, 0.0, (3.0/4)*(1.0/3)+(1-3.0/4)*(4.0/9), - (1-5.0/6)*(4.0/9), (1-1.0/2)*(4.0/9))); + RowFactory.create((short) 0, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 1, 4, 5.0, 1.0, 2.0 / 3, 1.0, 1.0 / 3, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 2, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 0, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 1, 3, 6.0, 0.0, 2.0 / 3, 0.0, 2.0 / 3, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 2, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 0, 3, 7.0, 0.0, 1.0 / 3, 0.0, 0.0, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), (1 - 1.0 / 2) * (4.0 / 9)), + RowFactory.create((short) 1, 4, 8.0, 1.0, 2.0 / 3, 1.0, 1.0, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (1.0 / 2) + (1 - 1.0 / 2) * (4.0 / 9)), + RowFactory.create((short) 2, 3, null, 0.0, 1.0 / 3, 0.0, 0.0, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), (1 - 1.0 / 2) * 
(4.0 / 9))); // checkstyle.off: LineLength StructType schema = createStructType(new StructField[]{ - createStructField("input1", ShortType, true), - createStructField("input2", IntegerType, true), - createStructField("input3", DoubleType, true), - createStructField("label", DoubleType, false), - createStructField("expected1", DoubleType, false), - createStructField("expected2", DoubleType, false), - createStructField("expected3", DoubleType, false), - createStructField("smoothing1", DoubleType, false), - createStructField("smoothing2", DoubleType, false), - createStructField("smoothing3", DoubleType, false) + createStructField("input1", ShortType, true), + createStructField("input2", IntegerType, true), + createStructField("input3", DoubleType, true), + createStructField("label", DoubleType, false), + createStructField("expected1", DoubleType, false), + createStructField("expected2", DoubleType, false), + createStructField("expected3", DoubleType, false), + createStructField("smoothing1", DoubleType, false), + createStructField("smoothing2", DoubleType, false), + createStructField("smoothing3", DoubleType, false) }); Dataset dataset = spark.createDataFrame(data, schema); TargetEncoder encoder = new TargetEncoder() - .setInputCols(new String[]{"input1", "input2", "input3"}) - .setOutputCols(new String[]{"output1", "output2", "output3"}) - .setTargetType("binary"); + .setInputCols(new String[]{"input1", "input2", "input3"}) + .setOutputCols(new String[]{"output1", "output2", "output3"}) + .setTargetType("binary"); TargetEncoderModel model = encoder.fit(dataset); Dataset output = model.transform(dataset); Assertions.assertEquals( - output.select("output1", "output2", "output3").collectAsList(), - output.select("expected1", "expected2", "expected3").collectAsList()); + output.select("output1", "output2", "output3").collectAsList(), + output.select("expected1", "expected2", "expected3").collectAsList()); Dataset output_smoothing = model.setSmoothing(1.0).transform(dataset); Assertions.assertEquals( - output_smoothing.select("output1", "output2", "output3").collectAsList(), - output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList()); + output_smoothing.select("output1", "output2", "output3").collectAsList(), + output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList()); } @@ -94,46 +94,46 @@ public void testTargetEncoderBinary() { public void testTargetEncoderContinuous() { List data = Arrays.asList( - RowFactory.create((short)0, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), - RowFactory.create((short)1, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), - RowFactory.create((short)2, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), - RowFactory.create((short)0, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), - RowFactory.create((short)1, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), - RowFactory.create((short)2, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), - RowFactory.create((short)0, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), - RowFactory.create((short)1, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), - RowFactory.create((short)2, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)); + RowFactory.create((short) 0, 3, 5.0, 10.0, 40.0, 50.0, 20.0, 42.5, 50.0, 27.5), + RowFactory.create((short) 1, 4, 5.0, 20.0, 50.0, 50.0, 20.0, 50.0, 50.0, 27.5), + RowFactory.create((short) 2, 3, 5.0, 30.0, 60.0, 50.0, 20.0, 57.5, 50.0, 27.5), + RowFactory.create((short) 0, 4, 6.0, 40.0, 40.0, 50.0, 50.0, 42.5, 50.0, 50.0), + 
RowFactory.create((short) 1, 3, 6.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0, 50.0), + RowFactory.create((short) 2, 4, 6.0, 60.0, 60.0, 50.0, 50.0, 57.5, 50.0, 50.0), + RowFactory.create((short) 0, 3, 7.0, 70.0, 40.0, 50.0, 70.0, 42.5, 50.0, 60.0), + RowFactory.create((short) 1, 4, 8.0, 80.0, 50.0, 50.0, 80.0, 50.0, 50.0, 65.0), + RowFactory.create((short) 2, 3, null, 90.0, 60.0, 50.0, 90.0, 57.5, 50.0, 70.0)); StructType schema = createStructType(new StructField[]{ - createStructField("input1", ShortType, true), - createStructField("input2", IntegerType, true), - createStructField("input3", DoubleType, true), - createStructField("label", DoubleType, false), - createStructField("expected1", DoubleType, false), - createStructField("expected2", DoubleType, false), - createStructField("expected3", DoubleType, false), - createStructField("smoothing1", DoubleType, false), - createStructField("smoothing2", DoubleType, false), - createStructField("smoothing3", DoubleType, false) + createStructField("input1", ShortType, true), + createStructField("input2", IntegerType, true), + createStructField("input3", DoubleType, true), + createStructField("label", DoubleType, false), + createStructField("expected1", DoubleType, false), + createStructField("expected2", DoubleType, false), + createStructField("expected3", DoubleType, false), + createStructField("smoothing1", DoubleType, false), + createStructField("smoothing2", DoubleType, false), + createStructField("smoothing3", DoubleType, false) }); Dataset dataset = spark.createDataFrame(data, schema); TargetEncoder encoder = new TargetEncoder() - .setInputCols(new String[]{"input1", "input2", "input3"}) - .setOutputCols(new String[]{"output1", "output2", "output3"}) - .setTargetType("continuous"); + .setInputCols(new String[]{"input1", "input2", "input3"}) + .setOutputCols(new String[]{"output1", "output2", "output3"}) + .setTargetType("continuous"); TargetEncoderModel model = encoder.fit(dataset); Dataset output = model.transform(dataset); Assertions.assertEquals( - output.select("output1", "output2", "output3").collectAsList(), - output.select("expected1", "expected2", "expected3").collectAsList()); + output.select("output1", "output2", "output3").collectAsList(), + output.select("expected1", "expected2", "expected3").collectAsList()); Dataset output_smoothing = model.setSmoothing(1.0).transform(dataset); Assertions.assertEquals( - output_smoothing.select("output1", "output2", "output3").collectAsList(), - output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList()); + output_smoothing.select("output1", "output2", "output3").collectAsList(), + output_smoothing.select("smoothing1", "smoothing2", "smoothing3").collectAsList()); } From 6236bd095d9c5e641f91f3b1712d6919030e4a46 Mon Sep 17 00:00:00 2001 From: Enrique Rebollo Date: Wed, 6 Nov 2024 20:19:26 +0100 Subject: [PATCH 9/9] [SPARK-37178][ML] changed category datatype to Double --- .../spark/ml/feature/TargetEncoder.scala | 42 +++++++-------- .../ml/feature/JavaTargetEncoderSuite.java | 51 ++++++++++++------- .../spark/ml/feature/TargetEncoderSuite.scala | 22 ++++---- python/pyspark/ml/feature.py | 6 +-- 4 files changed, 70 insertions(+), 51 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala index 85f4a0d566f6e..9afb88afec932 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/feature/TargetEncoder.scala @@ -193,12 +193,12 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) override def fit(dataset: Dataset[_]): TargetEncoderModel = { validateSchema(dataset.schema, fitting = true) - // stats: Array[Map[Some(category), (counter,stat)]] + // stats: Array[Map[category, (counter,stat)]] val stats = dataset .select((inputFeatures :+ $(labelCol)).map(col(_).cast(DoubleType)).toIndexedSeq: _*) .rdd.treeAggregate( Array.fill(inputFeatures.length) { - Map.empty[Option[Double], (Double, Double)] + Map.empty[Double, (Double, Double)] })( (agg, row: Row) => if (!row.isNullAt(inputFeatures.length)) { @@ -206,14 +206,14 @@ class TargetEncoder @Since("4.0.0") (@Since("4.0.0") override val uid: String) if (!label.equals(Double.NaN)) { inputFeatures.indices.map { feature => { - val category: Option[Double] = { - if (row.isNullAt(feature)) None // null category + val category: Double = { + if (row.isNullAt(feature)) TargetEncoder.NULL_CATEGORY // null category else { val value = row.getDouble(feature) if (value < 0.0 || value != value.toInt) throw new SparkException( s"Values from column ${inputFeatures(feature)} must be indices, " + s"but got $value.") - else Some(value) // non-null category + else value // non-null category } } val (class_count, class_stat) = agg(feature).getOrElse(category, (0.0, 0.0)) @@ -285,7 +285,8 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { private[feature] val TARGET_CONTINUOUS: String = "continuous" private[feature] val supportedTargetTypes: Array[String] = Array(TARGET_BINARY, TARGET_CONTINUOUS) - private[feature] val UNSEEN_CATEGORY: Option[Double] = Some(-1) + private[feature] val UNSEEN_CATEGORY: Double = Int.MaxValue + private[feature] val NULL_CATEGORY: Double = -1 @Since("4.0.0") override def load(path: String): TargetEncoder = super.load(path) @@ -293,12 +294,12 @@ object TargetEncoder extends DefaultParamsReadable[TargetEncoder] { /** * @param stats Array of statistics for each input feature. 
- * Array( Map( Some(category), (counter, stat) ) ) + * Array( Map( category, (counter, stat) ) ) */ @Since("4.0.0") class TargetEncoderModel private[ml] ( @Since("4.0.0") override val uid: String, - @Since("4.0.0") val stats: Array[Map[Option[Double], (Double, Double)]]) + @Since("4.0.0") val stats: Array[Map[Double, (Double, Double)]]) extends Model[TargetEncoderModel] with TargetEncoderBase with MLWritable { /** @group setParam */ @@ -342,8 +343,8 @@ class TargetEncoderModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - // encodings: Array[Map[Some(category), encoding]] - val encodings: Array[Map[Option[Double], Double]] = + // encodings: Array[Map[category, encoding]] + val encodings: Array[Map[Double, Double]] = stats.map { stat => val (global_count, global_stat) = stat.get(TargetEncoder.UNSEEN_CATEGORY).get @@ -363,11 +364,11 @@ class TargetEncoderModel private[ml] ( } // builds a column-to-column function from a map of encodings - val apply_encodings: Map[Option[Double], Double] => (Column => Column) = - (mappings: Map[Option[Double], Double]) => { + val apply_encodings: Map[Double, Double] => (Column => Column) = + (mappings: Map[Double, Double]) => { (col: Column) => { val nullWhen = when(col.isNull, - mappings.get(None) match { + mappings.get(TargetEncoder.NULL_CATEGORY) match { case Some(code) => lit(code) case None => if ($(handleInvalid) == TargetEncoder.KEEP_INVALID) { lit(mappings.get(TargetEncoder.UNSEEN_CATEGORY).get) @@ -375,15 +376,16 @@ class TargetEncoderModel private[ml] ( s"Unseen null value in feature ${col.toString}. To handle unseen values, " + s"set Param handleInvalid to ${TargetEncoder.KEEP_INVALID}.")) }) - val ordered_mappings = (mappings - None).toList.sortWith { - (a, b) => (b._1 == TargetEncoder.UNSEEN_CATEGORY) || - ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1.get < b._1.get)) + val ordered_mappings = (mappings - TargetEncoder.NULL_CATEGORY).toList.sortWith { + (a, b) => + (b._1 == TargetEncoder.UNSEEN_CATEGORY) || + ((a._1 != TargetEncoder.UNSEEN_CATEGORY) && (a._1 < b._1)) } ordered_mappings .foldLeft(nullWhen)( (new_col: Column, mapping) => { - val (Some(original), encoded) = mapping - if (original != TargetEncoder.UNSEEN_CATEGORY.get) { + val (original, encoded) = mapping + if (original != TargetEncoder.UNSEEN_CATEGORY) { new_col.when(col === original, lit(encoded)) } else { // unseen category new_col.otherwise( @@ -436,7 +438,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { private[TargetEncoderModel] class TargetEncoderModelWriter(instance: TargetEncoderModel) extends MLWriter { - private case class Data(stats: Array[Map[Option[Double], (Double, Double)]]) + private case class Data(stats: Array[Map[Double, (Double, Double)]]) override protected def saveImpl(path: String): Unit = { DefaultParamsWriter.saveMetadata(instance, path, sparkSession) @@ -456,7 +458,7 @@ object TargetEncoderModel extends MLReadable[TargetEncoderModel] { val data = sparkSession.read.parquet(dataPath) .select("encodings") .head() - val stats = data.getAs[Array[Map[Option[Double], (Double, Double)]]](0) + val stats = data.getAs[Array[Map[Double, (Double, Double)]]](0) val model = new TargetEncoderModel(metadata.uid, stats) metadata.getAndSetParams(model) model diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java index 8044d3a1cb4df..c488cc0dfca14 100644 --- 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTargetEncoderSuite.java @@ -38,24 +38,41 @@ public void testTargetEncoderBinary() { // checkstyle.off: LineLength List data = Arrays.asList( - RowFactory.create((short) 0, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 1, 4, 5.0, 1.0, 2.0 / 3, 1.0, 1.0 / 3, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 2, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 0, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 1, 3, 6.0, 0.0, 2.0 / 3, 0.0, 2.0 / 3, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (1 - 5.0 / 6) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 2, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), - RowFactory.create((short) 0, 3, 7.0, 0.0, 1.0 / 3, 0.0, 0.0, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + RowFactory.create((short) 0, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 1, 4, 5.0, 1.0, 2.0 / 3, 1.0, 1.0 / 3, + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 2, 3, 5.0, 0.0, 1.0 / 3, 0.0, 1.0 / 3, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 0, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 1, 3, 6.0, 0.0, 2.0 / 3, 0.0, 2.0 / 3, + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 2, 4, 6.0, 1.0, 1.0 / 3, 1.0, 2.0 / 3, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9)), + RowFactory.create((short) 0, 3, 7.0, 0.0, 1.0 / 3, 0.0, 0.0, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), (1 - 5.0 / 6) * (4.0 / 9), (1 - 1.0 / 2) * (4.0 / 9)), - RowFactory.create((short) 1, 4, 8.0, 1.0, 2.0 / 3, 1.0, 1.0, (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), (1.0 / 2) + (1 - 1.0 / 2) * (4.0 / 9)), - RowFactory.create((short) 2, 3, null, 0.0, 1.0 / 3, 0.0, 0.0, (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), - (1 - 5.0 / 6) * (4.0 / 9), (1 - 1.0 / 2) * (4.0 / 9))); + RowFactory.create((short) 1, 4, 8.0, 1.0, 2.0 / 3, 1.0, 1.0, + (3.0 / 4) * (2.0 / 3) + (1 - 3.0 / 4) * (4.0 / 
9), + (4.0 / 5) * 1 + (1 - 4.0 / 5) * (4.0 / 9), + (1.0 / 2) + (1 - 1.0 / 2) * (4.0 / 9)), + RowFactory.create((short) 2, 3, null, 0.0, 1.0 / 3, 0.0, 0.0, + (3.0 / 4) * (1.0 / 3) + (1 - 3.0 / 4) * (4.0 / 9), + (1 - 5.0 / 6) * (4.0 / 9), + (1 - 1.0 / 2) * (4.0 / 9))); // checkstyle.off: LineLength StructType schema = createStructType(new StructField[]{ createStructField("input1", ShortType, true), diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala index f53cda625f155..869be94ff1273 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TargetEncoderSuite.scala @@ -33,8 +33,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { @transient var data_binary: Seq[Row] = _ @transient var data_continuous: Seq[Row] = _ @transient var schema: StructType = _ - @transient var expected_stats_binary: Array[Map[Option[Double], (Double, Double)]] = _ - @transient var expected_stats_continuous: Array[Map[Option[Double], (Double, Double)]] = _ + @transient var expected_stats_binary: Array[Map[Double, (Double, Double)]] = _ + @transient var expected_stats_continuous: Array[Map[Double, (Double, Double)]] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -75,14 +75,14 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { StructField("smoothing3", DoubleType))) expected_stats_binary = Array( - Map(Some(0.0) -> (3.0, 1.0), Some(1.0) -> (3.0, 2.0), Some(2.0) -> (3.0, 1.0), Some(-1.0) -> (9.0, 4.0)), - Map(Some(3.0) -> (5.0, 0.0), Some(4.0) -> (4.0, 4.0), Some(-1.0) -> (9.0, 4.0)), - HashMap(Some(5.0) -> (3.0, 1.0), Some(6.0) -> (3.0, 2.0), Some(7.0) -> (1.0, 0.0), Some(8.0) -> (1.0, 1.0), Some(9.0) -> (1.0, 0.0), Some(-1.0) -> (9.0, 4.0))) + Map(0.0 -> (3.0, 1.0), 1.0 -> (3.0, 2.0), 2.0 -> (3.0, 1.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0)), + Map(3.0 -> (5.0, 0.0), 4.0 -> (4.0, 4.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0)), + HashMap(5.0 -> (3.0, 1.0), 6.0 -> (3.0, 2.0), 7.0 -> (1.0, 0.0), 8.0 -> (1.0, 1.0), 9.0 -> (1.0, 0.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 4.0))) expected_stats_continuous = Array( - Map(Some(0.0) -> (3.0, 40.0), Some(1.0) -> (3.0, 50.0), Some(2.0) -> (3.0, 60.0), Some(-1.0) -> (9.0, 50.0)), - Map(Some(3.0) -> (5.0, 50.0), Some(4.0) -> (4.0, 50.0), Some(-1.0) -> (9.0, 50.0)), - HashMap(Some(5.0) -> (3.0, 20.0), Some(6.0) -> (3.0, 50.0), Some(7.0) -> (1.0, 70.0), Some(8.0) -> (1.0, 80.0), Some(9.0) -> (1.0, 90.0), Some(-1.0) -> (9.0, 50.0))) + Map(0.0 -> (3.0, 40.0), 1.0 -> (3.0, 50.0), 2.0 -> (3.0, 60.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0)), + Map(3.0 -> (5.0, 50.0), 4.0 -> (4.0, 50.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0)), + HashMap(5.0 -> (3.0, 20.0), 6.0 -> (3.0, 50.0), 7.0 -> (1.0, 70.0), 8.0 -> (1.0, 80.0), 9.0 -> (1.0, 90.0), TargetEncoder.UNSEEN_CATEGORY -> (9.0, 50.0))) // scalastyle:on } @@ -247,8 +247,8 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { .createDataFrame(sc.parallelize(data_continuous :+ data_unseen), schema) val ex = intercept[SparkRuntimeException] { - val out = model.transform(df_unseen) - out.show(false) + val output = model.transform(df_unseen) + output.show() } assert(ex.isInstanceOf[SparkRuntimeException]) @@ -320,7 +320,7 @@ class TargetEncoderSuite extends MLTest with DefaultReadWriteTest { val expected_stats = Array( expected_stats_continuous(0), 
      expected_stats_continuous(1),
-      expected_stats_continuous(2) - Some(9.0) + (None -> (1.0, 90.0)))
+      expected_stats_continuous(2) + (TargetEncoder.NULL_CATEGORY -> (1.0, 90.0)) - 9.0)
 
     model.stats.zip(expected_stats).foreach{
       case (actual, expected) => assert(actual.equals(expected))
 
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 29f49d51ee36d..e053ea273140c 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -5493,10 +5493,10 @@ def setSmoothing(self, value: float) -> "TargetEncoderModel":
 
     @property
     @since("4.0.0")
-    def stats(self) -> List[Dict[float, float]]:
+    def stats(self) -> List[Dict[float, Tuple[float, float]]]:
         """
-        Fitted mappings for each feature to being encoded.
-        The dictionary contains a dictionary for each input column.
+        Fitted statistics for each feature to be encoded.
+        The list contains a dictionary for each input column.
         """
         return self._call_java("stats")
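
Aside (not part of any patch in this series): every hard-coded smoothing expectation in the test hunks above instantiates the same blend of in-category and global statistics, weighted by category frequency. A minimal standalone sketch of that formula, assuming only the (count, stat) pair layout visible in `TargetEncoderModel.stats`; the object and method names below are illustrative, not the patch's internal API:

// Sketch of the smoothed encoding the tests expect; names are illustrative.
object SmoothingSketch {
  // weight = n / (n + smoothing); the encoding blends the in-category stat
  // with the global stat, so rarely seen categories shrink toward the
  // global value while frequent ones keep their own statistic
  def encode(classCount: Double, classStat: Double,
             globalStat: Double, smoothing: Double): Double = {
    val weight = classCount / (classCount + smoothing)
    weight * classStat + (1 - weight) * globalStat
  }

  def main(args: Array[String]): Unit = {
    // continuous target, category 0 of input1: 3 rows with mean 40.0,
    // global mean 50.0, smoothing 1.0 => (3/4)*40 + (1/4)*50 = 42.5,
    // matching the `smoothing1` column in the continuous test data
    println(encode(3.0, 40.0, 50.0, 1.0))

    // binary target: the stored stat is a positive count, so a category
    // seen 3 times with 1 positive has rate 1/3, and the global rate is
    // 4/9; smoothing 1.0 reproduces (3.0/4)*(1.0/3) + (1 - 3.0/4)*(4.0/9)
    println(encode(3.0, 1.0 / 3, 4.0 / 9, 1.0))
  }
}

With smoothing 0 the weight is 1 and the transform falls back to the plain per-category statistic, which is why the unsmoothed expectations in the earlier hunks are the raw in-category rates and means.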