[SPARK-37178][ML] Add Target Encoding to ml.feature #48347
Conversation
Good start, just need to clarify the implementation and consider supporting a few more cases
docs/ml-features.md
Outdated
@@ -855,6 +855,46 @@ for more details on the API.

</div>

## TargetEncoder

Target Encoding maps a column of categorical indices into a numerical feature derived from the target. For string type input data, it is common to encode categorical features using [StringIndexer](ml-features.html#stringindexer) first.
Let's drop at least a link to information on what target encoding is here.
Also, the explanation you give in the PR about what this actually does to which types of input is valuable and should probably be here too, either here or below in the discussion of what the parameters do in some detail.
i think it's ok now. what do you think?
feature => {
  try {
    val field = schema(feature)
    if (field.dataType != DoubleType) {
Do the features have to be floats? I'd imagine they aren't if they're categorical representations you're encoding. I think it's OK to demand they're not strings and are already passed through StringIndexer in that case, but it feels like any numeric type works here
i mimic this behavior from other encoders (e.g. OneHotEncoder)
what would be your approach? accepting integers? checking for a nominal attribute in the metadata?
I think that's OK if it's what other encoders do. But I see checks like https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala#L93 - maybe follow that?
you're right
now accepting any subclass of NumericType for features & label
(maybe it doesn't make much sense in the continuous case, but it can be done anyway)
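For context, a minimal sketch of the relaxed check (not the PR's exact code): accept any NumericType subclass instead of strictly DoubleType.

```scala
import org.apache.spark.sql.types.{NumericType, StructType}

// Hedged sketch: require any numeric column rather than strictly DoubleType.
def validateNumeric(schema: StructType, feature: String): Unit = {
  val field = schema(feature)
  require(field.dataType.isInstanceOf[NumericType],
    s"Column $feature must be of numeric type, but was ${field.dataType}.")
}
```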
validateSchema(dataset.schema, fitting = true)

val stats = dataset
  .select(ArraySeq.unsafeWrapArray(
Is the ArraySeq business necessary? You're just selecting columns with `: _*` syntax, so any Seq would do.
it doesn't work in Scala 2.13
Passing an explicit array value to a Scala varargs method is deprecated (since 2.13.0) and will result in a defensive copy; Use the more efficient non-copying ArraySeq.unsafeWrapArray or an explicit toIndexedSeq call
OK, I'd say .toIndexedSeq is simpler
you're right. done
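For reference, a small illustration of the Scala 2.13 varargs issue discussed above (function and variable names are hypothetical):

```scala
import org.apache.spark.sql.{Column, DataFrame}

def selectAll(dataset: DataFrame, cols: Array[Column]): DataFrame = {
  // dataset.select(cols: _*)  // deprecated in Scala 2.13: makes a defensive copy
  // dataset.select(scala.collection.immutable.ArraySeq.unsafeWrapArray(cols): _*)  // non-copying but verbose
  dataset.select(cols.toIndexedSeq: _*)  // the simpler form settled on
}
```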
      globalCounter._2 + ((label - globalCounter._2) / (1 + globalCounter._1))))
    }
  } catch {
    case e: SparkRuntimeException =>
Indent
got resolved in the overall refactor
case e: SparkRuntimeException =>
  if (e.getErrorClass == "ROW_VALUE_IS_NULL") {
    throw new SparkException(s"Null value found in feature ${inputFeatures(feature)}." +
      s" See Imputer estimator for completing missing values.")
It seems like you can still target-encode null; it's just another possible value, no?
yes
it will be encoded as an unseen category (global statistics)
we could raise an error (as we do while fitting)
But this throws an exception? how is it handled as unseen but also raises an exception? it shouldn't, right?
Actually, it raises an exception while fitting and encodes as unseen category while transforming.
I'll check scikit-learn behavior
Right, I mean while fitting. I don't feel like this is necessary
following scikit approach, now treating null as another category
(category becomes Option[Double])
encodings: Map[String, Map[Option[Double], Double]]
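To make the resulting shape concrete, an illustrative (made-up) example of the new encodings structure, where None keys the null category:

```scala
// Values are invented for illustration; None represents the null category.
val encodings: Map[String, Map[Option[Double], Double]] = Map(
  "colorIndex" -> Map(
    Some(0.0) -> 0.35,  // category index 0
    Some(1.0) -> 0.70,  // category index 1
    None      -> 0.50)) // nulls, fitted like any other category
```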
val value = row.getDouble(feature)
if (value < 0.0 || value != value.toInt) throw new SparkException(
  s"Values from column ${inputFeatures(feature)} must be indices, but got $value.")
val counter = agg(feature).getOrElse(value, (0.0, 0.0))
Use `val (foo, bar) = ...` syntax so you don't have to use the more cryptic ._1 references later.
done
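The suggested style, with a hypothetical map for illustration:

```scala
// Destructure the tuple once instead of referencing ._1/._2 later.
val counters: Map[Double, (Double, Double)] = Map(1.0 -> (3.0, 2.0))
val (count, stat) = counters.getOrElse(1.0, (0.0, 0.0))
```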
val globalCounter = agg(feature).getOrElse(TargetEncoder.UNSEEN_CATEGORY, (0.0, 0.0))
$(targetType) match {
  case TargetEncoder.TARGET_BINARY =>
    if (label == 1.0) agg(feature) +
These if-else clauses need to be indented and with braces around them for clarity
done
})(
  (agg, row: Row) => {
    val label = row.getDouble(inputFeatures.length)
    Range(0, inputFeatures.length).map {
(0 until inputFeatures.length) feels a little more idiomatic, or even for ... yield
finally changed to inputFeatures.indices
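Side by side, the three variants discussed (a trivial sketch):

```scala
val inputFeatures = Array("f1", "f2")          // hypothetical
Range(0, inputFeatures.length).map(identity)   // original
(0 until inputFeatures.length).map(identity)   // reviewer's suggestion
inputFeatures.indices.map(identity)            // what the PR settled on
```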
val values = agg1(feature).keySet ++ agg2(feature).keySet
values.map(value =>
  value -> {
    val stat1 = agg1(feature).getOrElse(value, (0.0, 0.0))
Same, let's give names to the elements of this tuple
A comment or two in these blocks about what this sum is doing would help too
done
rest
  .foldLeft(when(col === first._1, first._2))(
    (new_col: Column, encoding) =>
      if (encoding._1 != TargetEncoder.UNSEEN_CATEGORY) {
And same again around here - some comments and more descriptive var names are important, as I have trouble evaluating the logic
done
…ding into sparkml-target-encoding
Map.empty[Option[Double], (Double, Double)]
})(
  (agg, row: Row) => {
    val label = label_type match {
i haven't handled null labels yet
i checked scikit and it fails on this (encoding everything to NaN)
we could:
- raise an exception
- do not consider the observation and keep going
what do you think?
I see. I guess I think it's most sensible to ignore nulls then
done
docs/ml-features.md
Outdated
`TargetEncoder` supports the `targetType` parameter to choose the label type when fitting data, affecting how statistics are calculated.
Available options include 'binary' and 'continuous' (mean-encoding).
When set to 'binary', encodings will be fitted from target conditional probabilities (a.k.a. bin-counting).
When set to 'continuous', encodings will be fitted according to the target mean (a.k.a. mean-encoding).
I think you can describe this a little bit more somewhere, could be here or at the top - what does target encoding actually do? with a simplistic example of a few rows?
Just want to make it immediately clear in 1 paragraph what this is doing for binary vs continuous targets
done
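As an illustration of the distinction (numbers invented, no smoothing applied):

```scala
// One categorical feature with two categories, 0.0 and 1.0.
val feature = Seq(0.0, 0.0, 1.0, 1.0)

// Binary target: encoding = P(label == 1 | category)
val binaryLabels = Seq(1.0, 0.0, 1.0, 1.0)
// category 0.0 -> 1/2 = 0.5, category 1.0 -> 2/2 = 1.0

// Continuous target: encoding = mean(label | category)
val continuousLabels = Seq(10.0, 20.0, 30.0, 50.0)
// category 0.0 -> 15.0, category 1.0 -> 40.0
```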
case (cat, (class_count, class_stat)) => cat -> {
  val weight = class_count / (class_count + $(smoothing))
  $(targetType) match {
    case TargetEncoder.TARGET_BINARY =>
This all might be worth a few lines of comments explaining the math here
done
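The math being commented is the usual smoothing blend between the in-category statistic and the global statistic; a sketch with illustrative names (not the PR's exact code):

```scala
def smoothedEncoding(
    classCount: Double,  // observations of this category
    classStat: Double,   // in-category statistic (positive rate or mean)
    globalStat: Double,  // same statistic over the whole dataset
    smoothing: Double): Double = {
  val weight = classCount / (classCount + smoothing)
  // rare categories are shrunk toward the global statistic
  weight * classStat + (1 - weight) * globalStat
}
```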
Let me call in @zhengruifeng for a look at this too. I think it's pretty good
@Since("4.0.0") | ||
val targetType: Param[String] = new Param[String](this, "targetType", | ||
"How to handle invalid data during transform(). " + | ||
"Options are 'keep' (invalid data presented as an extra categorical feature) " + |
targetType's description is the same as handleInvalid's?
Oops! Fixed for targetType & smoothing
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
  "How to handle invalid data during transform(). " +
  "Options are 'keep' (invalid data presented as an extra categorical feature) " +
  "or error (throw an error). Note that this Param is only used during transform; " +
"or error (throw an error). Note that this Param is only used during transform; " + | |
"or 'error' (throw an error). Note that this Param is only used during transform; " + |
done
private[feature] def validateSchema(schema: StructType,
                                    fitting: Boolean): StructType = {
- private[feature] def validateSchema(schema: StructType,
-                                     fitting: Boolean): StructType = {
+ private[feature] def validateSchema(
+     schema: StructType,
+     fitting: Boolean): StructType = {
done
case ShortType => row.getShort(inputFeatures.length).toDouble
case IntegerType => row.getInt(inputFeatures.length).toDouble
case LongType => row.getLong(inputFeatures.length).toDouble
case DoubleType => row.getDouble(inputFeatures.length)
I would suggest making the casting happen before the aggregation (the dataset.select above) to simplify the process
done
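Roughly what the suggestion amounts to (a sketch, column names hypothetical): cast everything up front so the aggregation never branches on the label's type:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

def prepare(dataset: DataFrame, inputFeatures: Array[String], labelCol: String): DataFrame =
  dataset.select(
    (inputFeatures :+ labelCol).map(c => col(c).cast(DoubleType)).toIndexedSeq: _*)
```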
}.toArray)

// encodings: Map[feature, Map[Some(category), encoding]]
val encodings: Map[String, Map[Option[Double], Double]] =
I feel this computation is not very complex and could maybe be implemented with SQL functions, but I am also fine to start with an RDD implementation.
dataset.withColumns(
  inputFeatures.zip(outputFeatures).map {
    feature =>
      feature._2 -> (encodings.get(feature._1) match {
The model coefficients `encodings` store the column names from `inputCols` used in `fit`; does `encodings.get(feature._1)` require that `inputCols` in `transform` be exactly the same as `inputCols` in `fit`?
you're right. Fixed.
now encodings: Array[Map[Some(category), encoding]]
also cc @WeichenXu123 for visibility
I think we should pass raw estimates to the model and calculate encodings in transform()
…ding into sparkml-target-encoding
done!
}

test("TargetEncoder - null label") {
how does it handle NaN? treat it as a normal value or an invalid value?
- NaN features => invalid (only accepting indices and null)
- NaN labels => was failing, it's fixed now (observation not considered)
> - NaN features => invalid (only accepting indices and null)
> - NaN labels => was failing, it's fixed now (observation not considered)

hi @rebo16v
From the scikit-learn docs: "TargetEncoder considers missing values, such as np.nan or None, as another category and encodes them like any other category. Categories that are not seen during fit are encoded with the target mean, i.e. target_mean_."
It seems scikit-learn's implementation also treats NaN as a valid missing value?
else if (isSet(outputCols)) $(outputCols)
else inputFeatures.map { field: String => s"${field}_indexed" }

private[feature] def validateSchema(
would you mind checking the Scala style according to the Code style guide section in https://spark.apache.org/contributing.html?
fixed
I think it looks good. There are 'failing' tests but it looks like a timeout. I'll run again to see if they complete. Anyone know about issues with the builder at the moment?
…ding into sparkml-target-encoding
@Since("4.0.0") | ||
class TargetEncoderModel private[ml] ( | ||
@Since("4.0.0") override val uid: String, | ||
@Since("4.0.0") val stats: Array[Map[Option[Double], (Double, Double)]]) |
nit: is it possible to avoid the usage of Option in the model coefficients?
e.g.
val stats: Array[Map[Double, (Double, Double)]]  // for valid values
val statForInvalid: (Double, Double)             // for invalid values
None is reserved for the null category.
It's possible to avoid it by reserving another double value (e.g. -2).
done
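A sketch of the sentinel approach (constants and object name are illustrative; the thread mentions -1.0 for unseen categories and -2 as a possible null marker):

```scala
object TargetEncoderConstants {  // hypothetical name
  val UNSEEN_CATEGORY: Double = -1.0  // category never observed during fit
  val NULL_CATEGORY: Double = -2.0    // null feature value
}
// stats can then be Array[Map[Double, (Double, Double)]] with no Option keys
```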
The failed test passes locally. Let's merge this and see if the test failure persists.
Merged to master.
What changes were proposed in this pull request?
Adds support for target encoding of ml features.
Target Encoding maps a column of categorical indices into a numerical feature derived from the target.
Leveraging the relationship between categorical variables and the target variable, target encoding usually performs better than one-hot encoding (while avoiding the need to add extra columns).
Why are the changes needed?
Target Encoding is a well-known encoding technique for categorical features.
It's supported in most ML frameworks:
https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.TargetEncoder.html
https://search.r-project.org/CRAN/refmans/dataPreparation/html/target_encode.html
Does this PR introduce any user-facing change?
Spark API now includes 2 new classes (TargetEncoder and TargetEncoderModel) in the org.apache.spark.ml.feature package.
How was this patch tested?
Scala => org.apache.spark.ml.feature.TargetEncoderSuite
Java => org.apache.spark.ml.feature.JavaTargetEncoderSuite
Python => python.pyspark.ml.tests.test_feature.FeatureTests (added 2 tests)
Was this patch authored or co-authored using generative AI tooling?
No
Some design notes:
- binary and continuous target types (no multi-label yet)
- available in Scala, Java and Python APIs
- fitting implemented on the RDD API (treeAggregate)
- transformation implemented on the DataFrame API (no UDFs)
- categorical features must be indices (integers) in double-typed columns (as if StringIndexer were used before)
- unseen categories in training are represented as class -1.0
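Putting the pieces of the thread together, a hedged usage sketch (setter names inferred from the parameters discussed above: inputCols/outputCols, labelCol, targetType, smoothing; column names are hypothetical):

```scala
import org.apache.spark.ml.feature.TargetEncoder

val encoder = new TargetEncoder()
  .setInputCols(Array("categoryIndex1", "categoryIndex2"))
  .setOutputCols(Array("encoded1", "encoded2"))
  .setLabelCol("label")
  .setTargetType("binary")  // or "continuous" for mean-encoding
  .setSmoothing(1.0)

// val model = encoder.fit(df)        // df: indexed categorical columns + label
// val encoded = model.transform(df)  // unseen categories get the global statistic
```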
Encodings structure
Parameters