-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17498][ML] StringIndexer enhancement for handling unseen labels #16883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Test build #72687 has finished for PR 16883 at commit
|
This PR is an enhancement to ML StringIndexer. Before this PR, String Indexer only supports "skip"/"error" options
to deal with unseen records. But sometimes those unseen records might still be useful in certain use cases, so user
would like to keep the unseen labels. This PR enables StringIndexer to support keeping unseen labels as indices
[numLabels].
'''Before
StringIndexer().setHandleInvalid("skip")
StringIndexer().setHandleInvalid("error")
'''After
support the third option "keep"
StringIndexer().setHandleInvalid("keep")
Signed-off-by: VinceShieh <vincent.xie@intel.com>
30f3ba3 to
b970728
Compare
|
Test build #72688 has finished for PR 16883 at commit
|
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #72690 has finished for PR 16883 at commit
|
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #72695 has finished for PR 16883 at commit
|
|
Test build #72696 has finished for PR 16883 at commit
|
Signed-off-by: VinceShieh <vincent.xie@intel.com>
2d6da1c to
9a41745
Compare
Signed-off-by: VinceShieh <vincent.xie@intel.com>
This reverts commit 9a41745.
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #72697 has finished for PR 16883 at commit
|
|
Test build #72698 has finished for PR 16883 at commit
|
|
Test build #72700 has finished for PR 16883 at commit
|
|
@srowen @jkbradley do u have time to take a look? |
|
I'll take a look now, thanks! |
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done with review pass. Thanks!
|
|
||
| package org.apache.spark.ml.feature | ||
|
|
||
| import scala.language.existentials |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
local build&test are fine, but will get compilation error on line 193 on Jenkins
| private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol | ||
| with HasHandleInvalid { | ||
| private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { | ||
| val SKIP_UNSEEN_LABEL: String = "skip" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Constants like this should go in a static object class. I'd also like to keep them private for now.
|
|
||
| /** | ||
| * Param for how to handle unseen labels. Options are 'skip' (filter out rows with | ||
| * unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rephrase: "map unseen labels with indices [numLabels]" --> "put unseen labels in a special additional bucket, at index numLabels"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(same elsewhere in docs)
|
|
||
| /** @group getParam */ | ||
| @Since("2.1.0") | ||
| def getHandleInvalid: String = $(handleInvalid) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should go in the trait.
| def getHandleInvalid: String = $(handleInvalid) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.1.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update Since annotations to 2.2.0
| transformSchema(dataset.schema, logging = true) | ||
|
|
||
| val metadata = NominalAttribute.defaultAttr | ||
| .withName($(outputCol)).withValues(labels).toMetadata() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
withValues should include a special field for invalid, if we are keeping invalid labels. How about calling that field "_invalidLabels" (and defining this constant in the StringIndexer object)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, I cannot fully get the point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think he means that "labels" above should also include the invalid bucket. In previous ML frameworks I've worked on we've just called this "unknown".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, that's what I meant: In withValues(labels), labels can be set as:
val labels = getHandleInvalid match {
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
case _ => labels
}
I'm adding underscores to the attribute name to make it a little less likely to hit conflicts.
| } else if (keepInvalid) { | ||
| labels.length | ||
| } else { | ||
| throw new SparkException(s"Unseen label: $label.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update with recommendation to set handleInvalid to "keep" to handle unseen labels.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you improve the error message?
throw new SparkException(s"Unseen label: $label. To handle unseen labels, set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
| // Verify that we skip the c record | ||
| val transformed = indexerSkipInvalid.transform(df2) | ||
| val attr = Attribute.fromStructField(transformed.schema("labelIndex")) | ||
| var transformed = indexer.transform(df2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Keep as val, and just define new vals below. Vals make the code clearer.
docs/ml-features.md
Outdated
|
|
||
| `StringIndexer` encodes a string column of labels to a column of label indices. | ||
| The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. | ||
| The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is not correct, except when keeping invalid ones.
docs/ml-features.md
Outdated
| Notice that the row containing "d" does not appear. | ||
| Notice that the rows containing "d" or "e" do not appear. | ||
|
|
||
| If you had called `setHandleInvalid("keep")`, the following dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"If you had called" --> "If you call"
|
Btw, we're near the time when the 2.2 branch will be cut, and I'd like to get this into 2.2. Let me know if you're busy, and I'd be happy to help finalize the PR. Thanks! |
|
gotcha, will update soon. |
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #73637 has finished for PR 16883 at commit
|
val (filteredDataset, keepInvalid) = getHandleInvalid match {
case ..
}
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #73639 has finished for PR 16883 at commit
|
|
Test build #73643 has finished for PR 16883 at commit
|
docs/ml-features.md
Outdated
|
|
||
| - throw an exception (which is the default) | ||
| - skip the row containing the unseen label entirely | ||
| - map the unseen labels with indices [numLabels] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc suggestion: "map the unseen labels to their own index"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or just match the phrasing in the doc param
docs/ml-features.md
Outdated
| 4 | e | 3.0 | ||
| ~~~~ | ||
|
|
||
| Notice that the rows containing "d" or "e" are mapped with indices "3.0" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc suggestion: rows containing "d" or "e" are mapped with indices "3.0" => rows containing "d" and "e" are mapped to index "3.0"
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you keep the order of the params same as before? also, minor style comment -- keep the setDefault(handleInvalid) below the set method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for maintaining order.
setDefault will go in the trait (except in cases where it belongs in just one of the Estimator or Model)
|
|
||
| private[feature] val SKIP_UNSEEN_LABEL: String = "skip" | ||
| private[feature] val ERROR_UNSEEN_LABEL: String = "error" | ||
| private[feature] val KEEP_UNSEEN_LABEL: String = "keep" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is very nice, good use of constants, I really like to see this type of code :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would make me even happier if these were public and could be used by the test code, but I think it's up to the committers (jkbradley)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At some point, let's do that, but not yet. I like keeping things private at first in case we find mistakes after release and need to change things.
| val metadata = NominalAttribute.defaultAttr | ||
| .withName($(outputCol)).withValues(labels).toMetadata() | ||
| // If we are skipping invalid records, filter them out. | ||
| val (filteredDataset, keepInvalid) = getHandleInvalid match { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor style comment: instead of keepInvalid, do you think that indexInvalid might be a better name (?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually, I think returning a tuple here just makes things more confusing. Maybe you can move the check outside of the match.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?
|
@VinceShieh I added some minor comments. This is a nice feature! |
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates! I made a few more comments.
| @Since("2.1.0") | ||
| val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + | ||
| "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + | ||
| "error (throw an error), or 'keep' (put unseen labels in a special additional bucket," + |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need space after comma: "bucket, "
|
|
||
| private[feature] val SKIP_UNSEEN_LABEL: String = "skip" | ||
| private[feature] val ERROR_UNSEEN_LABEL: String = "error" | ||
| private[feature] val KEEP_UNSEEN_LABEL: String = "keep" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At some point, let's do that, but not yet. I like keeping things private at first in case we find mistakes after release and need to change things.
| private[feature] val KEEP_UNSEEN_LABEL: String = "keep" | ||
| private[feature] val supportedHandleInvalids: Array[String] = | ||
| Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) | ||
| @Since("1.6.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: add newline here
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) | ||
| setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to set default here since it's set in the trait
| transformSchema(dataset.schema, logging = true) | ||
|
|
||
| val metadata = NominalAttribute.defaultAttr | ||
| .withName($(outputCol)).withValues(labels).toMetadata() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, that's what I meant: In withValues(labels), labels can be set as:
val labels = getHandleInvalid match {
case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
case _ => labels
}
I'm adding underscores to the attribute name to make it a little less likely to hit conflicts.
| } else if (keepInvalid) { | ||
| labels.length | ||
| } else { | ||
| throw new SparkException(s"Unseen label: $label.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you improve the error message?
throw new SparkException(s"Unseen label: $label. To handle unseen labels, set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
docs/ml-features.md
Outdated
|
|
||
| - throw an exception (which is the default) | ||
| - skip the row containing the unseen label entirely | ||
| - map the unseen labels with indices [numLabels] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or just match the phrasing in the doc param
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") | ||
| def setHandleInvalid(value: String): this.type = set(handleInvalid, value) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for maintaining order.
setDefault will go in the trait (except in cases where it belongs in just one of the Estimator or Model)
| val metadata = NominalAttribute.defaultAttr | ||
| .withName($(outputCol)).withValues(labels).toMetadata() | ||
| // If we are skipping invalid records, filter them out. | ||
| val (filteredDataset, keepInvalid) = getHandleInvalid match { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?
|
updated. Thank you both @imatiach-msft @jkbradley |
|
Test build #74007 has finished for PR 16883 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes look good, thanks! Only the Since annotation issue remains.
| * Default: "error" | ||
| * @group param | ||
| */ | ||
| @Since("2.1.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed this before, but these Since annotations should stay set to 1.6.0 since handleInvalid and the get/set methods were added in 1.6.0
| setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) | ||
|
|
||
| /** @group getParam */ | ||
| @Since("2.2.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
|
|
||
| /** @group setParam */ | ||
| @Since("1.6.0") | ||
| @Since("2.2.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
| def setOutputCol(value: String): this.type = set(outputCol, value) | ||
|
|
||
| /** @group setParam */ | ||
| @Since("2.2.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Signed-off-by: VinceShieh <vincent.xie@intel.com>
|
Test build #74055 has finished for PR 16883 at commit
|
|
Nice work! I'm just planning to improve |
|
LGTM |
|
Btw, are you interested in updating the Python API too? https://issues.apache.org/jira/browse/SPARK-19852 |
|
Sure, I can work on that :) @jkbradley |
|
Thanks! |
What changes were proposed in this pull request?
This PR is an enhancement to ML StringIndexer.
Before this PR, String Indexer only supports "skip"/"error" options to deal with unseen records.
But those unseen records might still be useful and user would like to keep the unseen labels in
certain use cases, This PR enables StringIndexer to support keeping unseen labels as
indices [numLabels].
'''Before
StringIndexer().setHandleInvalid("skip")
StringIndexer().setHandleInvalid("error")
'''After
support the third option "keep"
StringIndexer().setHandleInvalid("keep")
How was this patch tested?
Test added in StringIndexerSuite
Signed-off-by: VinceShieh vincent.xie@intel.com
(Please fill in changes proposed in this fix)