Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-37178][ML] Add Target Encoding to ml.feature #48347

Closed
wants to merge 18 commits into from

Conversation

rebo16v
Copy link

@rebo16v rebo16v commented Oct 4, 2024

What changes were proposed in this pull request?

Adds support for target encoding of ml features.
Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
Leveraging the relationship between categorical variables and the target variable, target encoding usually performs better than one-hot encoding (while avoiding the need to add extra columns)

Why are the changes needed?

Target Encoding is a well-known encoding technique for categorical features.
It's supported on most ml frameworks
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
https://search.r-project.org/CRAN/refmans/dataPreparation/html/target_encode.html

Does this PR introduce any user-facing change?

Spark API now includes 2 new classes in package org.apache.spark.ml

  • TargetEncoder (estimator)
  • TargetEncoderModel (transformer)

How was this patch tested?

Scala => org.apache.spark.ml.feature.TargetEncoderSuite
Java => org.apache.spark.ml.feature.JavaTargetEncoderSuite
Python => python.pyspark.ml.tests.test_feature.FeatureTests (added 2 tests)

Was this patch authored or co-authored using generative AI tooling?

No

Some design notes ... |-

  • binary and continuous target types (no multi-label yet)

  • available in Scala, Java and Python APIs

  • fitting implemented on RDD API (treeAggregate)

  • transformation implemented on Dataframe API (no UDFs)

  • categorical features must be indices (integers) in Double-typed columns (as if StringIndexer were used before)

  • unseen categories in training are represented as class -1.0

  • Encodings structure

    • Map[String, Map[Double, Double]]) => Map[ feature_name, Map[ original_category, encoded category ] ]
  • Parameters

    • inputCol(s) / outputCol(s) / labelCol => as usual
    • targetType
      • binary => encodings calculated as in-category conditional probability (counting)
      • continuous => encodings calculated as in-category target mean (incrementally)
    • handleInvalid
      • error => raises an error if trying to encode an unseen category
      • keep => encodes an unseen category with the overall statistics
    • smoothing => controls how in-category stats and overall stats are weighted to calculate final encodings (to avoid overfitting)

Copy link
Member

@srowen srowen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good start, just need to clarify the implementation and consider supporting a few more cases

@@ -855,6 +855,46 @@ for more details on the API.

</div>

## TargetEncoder

Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's drop at least a link to information on what target encoding is here.
Also, the explanation you give in the PR about what this actually does to which types of input is valuable and should probably be here too, either here or below in discussion of what the parameters do in some detail.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's ok now. what do you think?

feature => {
try {
val field = schema(feature)
if (field.dataType != DoubleType) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the features have to be floats? I'd imagine they aren't if they're categorical representations you're encoding. I think it's OK to demand they're not strings and are already passed through StringIndexer in that case, but it feels like any numeric type works here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i mimic this behavior from other encoders (i.e. OneHotEncoder)
what would be your approach? accepting Integers? checking for nominal attribute in metadata?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's OK if it's what other encoders do. But I see checks like https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala#L93 - maybe follow that?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right
now accepting any subclass of NumericType for features & label
(maybe it doesn't make much sense the continuous case, it could be done anyway)

validateSchema(dataset.schema, fitting = true)

val stats = dataset
.select(ArraySeq.unsafeWrapArray(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the ArraySeq business necessary? you're just selecting columns with : _* syntax so any seq would do

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it doesn't work in Scala 2.13
Passing an explicit array value to a Scala varargs method is deprecated (since 2.13.0) and will result in a defensive copy; Use the more efficient non-copying ArraySeq.unsafeWrapArray or an explicit toIndexedSeq call

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I'd say .toIndexedSeq is simpler

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right. done

globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1))))
}
} catch {
case e: SparkRuntimeException =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indent

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

got resolved in the overall refactor

case e: SparkRuntimeException =>
if (e.getErrorClass == "ROW_VALUE_IS_NULL") {
throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." +
s" See Imputer estimator for completing missing values.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like you can still target-encode null; it's just another possible value, no?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes
it will be encoded as an unseen category (global statistics)
we could raise an error (as we do while fitting)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this throws an exception? how is it handled as unseen but also raises an exception? it shouldn't, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, it raises an exception while fitting and encodes as unseen category while transforming.
I´ll check scikit-learn behavior

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I mean while fitting. I don't feel like this is necessary

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following scikit approach, now treating null as another category
(category becomes Option[Double])
encodings: Map[String, Map[Option[Double], Double]]

val value = row.getDouble(feature)
if (value < 0.0 || value != value.toInt) throw new SparkException(
s"Values from column ${inputFeatures(feature)} must be indices, but got $value.")
val counter = agg(feature).getOrElse(value, (0.0, 0.0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use val (foo, bar) = syntax so you don't have to use more cryptic ._1 references later

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
$(targetType) match {
case TargetEncoder.TARGET_BINARY =>
if (label == 1.0) agg(feature) +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These if-else clauses need to be indented and with braces around them for clarity

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

})(
(agg, row: Row) => {
val label = row.getDouble(inputFeatures.length)
Range(0, inputFeatures.length).map {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(1 until inputFeatures.length) feels a little more idiomatic, or even for ... yield

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

finally changed for inputFeatures.indices

val values = agg1(feature).keySet ++ agg2(feature).keySet
values.map(value =>
value -> {
val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, let's give names to the elements of this tuple
A comment or two in these blocks about what this sum is doing would help too

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

rest
.foldLeft(when(col === first._1, first._2))(
(new_col: Column, encoding) =>
if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And same again around here - some comments and more descriptive var names are important, as I have trouble evaluating the logic

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Map.empty[Option[Double], (Double, Double)]
})(
(agg, row: Row) => {
val label = label_type match {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

didn't work yet on handling null labels
i checked scikit and it fails at this (encoding all to NaN)
we could

  1. raise an exception
  2. do not consider the observation and keep going
    what do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I guess I think it's most sensible to ignore nulls then

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
Available options include 'binary' and 'continuous' (mean-encoding).
When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a bin-counting).
When set to 'continuous', encodings will be fitted from according to target mean (a.k.a. mean-encoding).
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can describe this a little bit more somewhere, could be here or at the top - what does target encoding actually do? with a simplistic example of a few rows?

Just want to make it immediate clearly in 1 paragraph what this is doing for binary vs continuous targets

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

case (cat, (class_count, class_stat)) => cat -> {
val weight = class_count / (class_count + $(smoothing))
$(targetType) match {
case TargetEncoder.TARGET_BINARY =>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This all might be worth a few lines of comments explaining the math here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@srowen
Copy link
Member

srowen commented Oct 19, 2024

Let me call in @zhengruifeng for a look at this too. I think it's pretty good

@Since("4.0.0")
val targetType: Param[String] = new Param[String](this, "targetType",
"How to handle invalid data during transform(). " +
"Options are 'keep' (invalid data presented as an extra categorical feature) " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

targetType's description is same as handleInvalid?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ups! Fixed for targetType & smoothing

override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"How to handle invalid data during transform(). " +
"Options are 'keep' (invalid data presented as an extra categorical feature) " +
"or error (throw an error). Note that this Param is only used during transform; " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"or error (throw an error). Note that this Param is only used during transform; " +
"or 'error' (throw an error). Note that this Param is only used during transform; " +

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Comment on lines 85 to 86
private[feature] def validateSchema(schema: StructType,
fitting: Boolean): StructType = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
private[feature] def validateSchema(schema: StructType,
fitting: Boolean): StructType = {
private[feature] def validateSchema(
schema: StructType,
fitting: Boolean): StructType = {

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

case ShortType => row.getShort(inputFeatures.length).toDouble
case IntegerType => row.getInt(inputFeatures.length).toDouble
case LongType => row.getLong(inputFeatures.length).toDouble
case DoubleType => row.getDouble(inputFeatures.length)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest make the casting happen before the aggregation (dataset.select in the above ) to simplify the process

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

}.toArray)

// encodings: Map[feature, Map[Some(category), encoding]]
val encodings: Map[String, Map[Option[Double], Double]] =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel this is computation is not very complex and maybe implemented with sql functions.
but I am also fine to start with a RDD implementation.

dataset.withColumns(
inputFeatures.zip(outputFeatures).map {
feature =>
feature._2 -> (encodings.get(feature._1) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

model coefficients encodings stores the column names inputCols used in fit, does encodings.get(feature._1) requires inputCols in transform should be exactly the same as inputCols in fit?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right. Fixed.
now encodings: Array[Map[Some(category), encoding]]

@zhengruifeng
Copy link
Contributor

also cc @WeichenXu123 for visibility

@rebo16v
Copy link
Author

rebo16v commented Oct 21, 2024

I think we should pass raw estimates to the model and calculate encodings in transform()
So we can apply different smoothing factors without having to re-fit
Makes sense? Will work on this ...

val encodings: Array[Map[Option[Double], Double]] =

@rebo16v
Copy link
Author

rebo16v commented Oct 23, 2024

I think we should pass raw estimates to the model and calculate encodings in transform() So we can apply different smoothing factors without having to re-fit Makes sense? Will work on this ...

val encodings: Array[Map[Option[Double], Double]] =

done!

@rebo16v
Copy link
Author

rebo16v commented Oct 27, 2024

@srowen @zhengruifeng


}

test("TargetEncoder - null label") {
Copy link
Contributor

@zhengruifeng zhengruifeng Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how does it handle NaN? treat as a normal value or invalid value?

Copy link
Author

@rebo16v rebo16v Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • NaN features => invalid (only accepting indices and null)
  • NaN labels => was failing, it's fixed now (observation not considered)

Copy link
Contributor

@zhengruifeng zhengruifeng Nov 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • NaN features => invalid (only accepting indices and null)
  • NaN labels => was failing, it's fixed now (observation not considered)

hi @rebo16v

https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html#targetencoder

TargetEncoder considers missing values, such as np.nan or None, as another category and encodes them like any other category. Categories that are not seen during fit are encoded with the target mean, i.e. target_mean_.

It seems scikit-learn's implementation also treat NaN as a valid missing value?

else if (isSet(outputCols)) $(outputCols)
else inputFeatures.map{field: String => s"${field}_indexed"}

private[feature] def validateSchema(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would you mind checking the scala style according to the section Code style guide in https://spark.apache.org/contributing.html?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@rebo16v
Copy link
Author

rebo16v commented Oct 28, 2024

@zhengruifeng

@srowen
Copy link
Member

srowen commented Nov 2, 2024

I think it looks good. There are 'failing' tests but it looks like a timeout. I'll run again to see if they complete. Anyone know about issues with the builder at the moment?

@Since("4.0.0")
class TargetEncoderModel private[ml] (
@Since("4.0.0") override val uid: String,
@Since("4.0.0") val stats: Array[Map[Option[Double], (Double, Double)]])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is it possible to avoid the usage of Option in the model coefficient?

e.g.

val stats: Array[Map[Double, (Double, Double)]]), # for valid values;
val statForInvalid: (Double, Double), # for invalid values;

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

None is reserved for null category
It's possible to avoid it by reserving another double value (i.e. -2)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Member

@HyukjinKwon HyukjinKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failed test passes locally. Let's merge this and see if the test failure persists.

@rebo16v
Copy link
Author

rebo16v commented Nov 6, 2024

@HyukjinKwon @zhengruifeng

@HyukjinKwon
Copy link
Member

Merged to master.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants