Skip to content

Conversation

@VinceShieh
Copy link

What changes were proposed in this pull request?

This PR is an enhancement to ML StringIndexer.
Before this PR, String Indexer only supports "skip"/"error" options to deal with unseen records.
But those unseen records might still be useful and user would like to keep the unseen labels in
certain use cases, This PR enables StringIndexer to support keeping unseen labels as
indices [numLabels].

'''Before
StringIndexer().setHandleInvalid("skip")
StringIndexer().setHandleInvalid("error")
'''After
support the third option "keep"
StringIndexer().setHandleInvalid("keep")

How was this patch tested?

Test added in StringIndexerSuite

Signed-off-by: VinceShieh vincent.xie@intel.com
(Please fill in changes proposed in this fix)

@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72687 has finished for PR 16883 at commit 30f3ba3.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

This PR is an enhancement to ML StringIndexer. Before this PR, String Indexer only supports "skip"/"error" options
to deal with unseen records. But sometimes those unseen records might still be useful in certain use cases, so user
would like to keep the unseen labels. This PR enables StringIndexer to support keeping unseen labels as indices
[numLabels].

'''Before
StringIndexer().setHandleInvalid("skip")
StringIndexer().setHandleInvalid("error")
'''After
support the third option "keep"
StringIndexer().setHandleInvalid("keep")

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72688 has finished for PR 16883 at commit b970728.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72690 has finished for PR 16883 at commit 5d4b07f.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72695 has finished for PR 16883 at commit 0eb7f07.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72696 has finished for PR 16883 at commit 2d6da1c.

  • This patch fails Scala style tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Signed-off-by: VinceShieh <vincent.xie@intel.com>
VinceShieh added 2 commits February 10, 2017 17:33
Signed-off-by: VinceShieh <vincent.xie@intel.com>
@VinceShieh VinceShieh changed the title [SPARK-17498][ML] enchance StringIndexer to handle unseen labels [SPARK-17498][ML] StringIndexer enhancement for handling unseen labels Feb 10, 2017
Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72697 has finished for PR 16883 at commit 9a41745.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72698 has finished for PR 16883 at commit 1736057.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Feb 10, 2017

Test build #72700 has finished for PR 16883 at commit 27c1b10.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@VinceShieh
Copy link
Author

@srowen @jkbradley do u have time to take a look?

@jkbradley
Copy link
Member

I'll take a look now, thanks!

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done with review pass. Thanks!


package org.apache.spark.ml.feature

import scala.language.existentials
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

local build&test are fine, but will get compilation error on line 193 on Jenkins

private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
with HasHandleInvalid {
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
val SKIP_UNSEEN_LABEL: String = "skip"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constants like this should go in a static object class. I'd also like to keep them private for now.


/**
* Param for how to handle unseen labels. Options are 'skip' (filter out rows with
* unseen labels), 'error' (throw an error), or 'keep' (map unseen labels with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rephrase: "map unseen labels with indices [numLabels]" --> "put unseen labels in a special additional bucket, at index numLabels"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(same elsewhere in docs)


/** @group getParam */
@Since("2.1.0")
def getHandleInvalid: String = $(handleInvalid)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go in the trait.

def getHandleInvalid: String = $(handleInvalid)

/** @group setParam */
@Since("2.1.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update Since annotations to 2.2.0

transformSchema(dataset.schema, logging = true)

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

withValues should include a special field for invalid, if we are keeping invalid labels. How about calling that field "_invalidLabels" (and defining this constant in the StringIndexer object)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I cannot fully get the point.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think he means that "labels" above should also include the invalid bucket. In previous ML frameworks I've worked on we've just called this "unknown".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's what I meant: In withValues(labels), labels can be set as:

val labels = getHandleInvalid match {
  case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
  case _ => labels
}

I'm adding underscores to the attribute name to make it a little less likely to hit conflicts.

} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update with recommendation to set handleInvalid to "keep" to handle unseen labels.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you improve the error message?

throw new SparkException(s"Unseen label: $label.  To handle unseen labels, set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")

// Verify that we skip the c record
val transformed = indexerSkipInvalid.transform(df2)
val attr = Attribute.fromStructField(transformed.schema("labelIndex"))
var transformed = indexer.transform(df2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Keep as val, and just define new vals below. Vals make the code clearer.


`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
The indices are in `[0, numLabels]`, ordered by label frequencies, so the most frequent label gets index `0`.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is not correct, except when keeping invalid ones.

Notice that the row containing "d" does not appear.
Notice that the rows containing "d" or "e" do not appear.

If you had called `setHandleInvalid("keep")`, the following dataset
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"If you had called" --> "If you call"

@jkbradley
Copy link
Member

Btw, we're near the time when the 2.2 branch will be cut, and I'd like to get this into 2.2. Let me know if you're busy, and I'd be happy to help finalize the PR. Thanks!

@VinceShieh
Copy link
Author

gotcha, will update soon.

VinceShieh added 2 commits March 1, 2017 10:09
Signed-off-by: VinceShieh <vincent.xie@intel.com>
Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Mar 1, 2017

Test build #73637 has finished for PR 16883 at commit 9bcaffc.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

val (filteredDataset, keepInvalid) = getHandleInvalid match {
  case ..
}

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Mar 1, 2017

Test build #73639 has finished for PR 16883 at commit 4dc10e6.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Mar 1, 2017

Test build #73643 has finished for PR 16883 at commit fa24e43.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.


- throw an exception (which is the default)
- skip the row containing the unseen label entirely
- map the unseen labels with indices [numLabels]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc suggestion: "map the unseen labels to their own index"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or just match the phrasing in the doc param

4 | e | 3.0
~~~~

Notice that the rows containing "d" or "e" are mapped with indices "3.0"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc suggestion: rows containing "d" or "e" are mapped with indices "3.0" => rows containing "d" and "e" are mapped to index "3.0"


/** @group setParam */
@Since("2.2.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
Copy link
Contributor

@imatiach-msft imatiach-msft Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you keep the order of the params same as before? also, minor style comment -- keep the setDefault(handleInvalid) below the set method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for maintaining order.
setDefault will go in the trait (except in cases where it belongs in just one of the Estimator or Model)


private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very nice, good use of constants, I really like to see this type of code :)

Copy link
Contributor

@imatiach-msft imatiach-msft Mar 3, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make me even happier if these were public and could be used by the test code, but I think it's up to the committers (jkbradley)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point, let's do that, but not yet. I like keeping things private at first in case we find mistakes after release and need to change things.

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor style comment: instead of keepInvalid, do you think that indexInvalid might be a better name (?)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually, I think returning a tuple here just makes things more confusing. Maybe you can move the check outside of the match.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?

@imatiach-msft
Copy link
Contributor

@VinceShieh I added some minor comments. This is a nice feature!

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates! I made a few more comments.

@Since("2.1.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
"unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
"error (throw an error), or 'keep' (put unseen labels in a special additional bucket," +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need space after comma: "bucket, "


private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
private[feature] val ERROR_UNSEEN_LABEL: String = "error"
private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point, let's do that, but not yet. I like keeping things private at first in case we find mistakes after release and need to change things.

private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
@Since("1.6.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: add newline here

/** @group setParam */
@Since("2.2.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to set default here since it's set in the trait

transformSchema(dataset.schema, logging = true)

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, that's what I meant: In withValues(labels), labels can be set as:

val labels = getHandleInvalid match {
  case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
  case _ => labels
}

I'm adding underscores to the attribute name to make it a little less likely to hit conflicts.

} else if (keepInvalid) {
labels.length
} else {
throw new SparkException(s"Unseen label: $label.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you improve the error message?

throw new SparkException(s"Unseen label: $label.  To handle unseen labels, set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")


- throw an exception (which is the default)
- skip the row containing the unseen label entirely
- map the unseen labels with indices [numLabels]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or just match the phrasing in the doc param


/** @group setParam */
@Since("2.2.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for maintaining order.
setDefault will go in the trait (except in cases where it belongs in just one of the Estimator or Model)

val metadata = NominalAttribute.defaultAttr
.withName($(outputCol)).withValues(labels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm OK with returning a tuple; that's a common pattern. Do you mean that it makes the code inside the match statement confusing?

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@VinceShieh
Copy link
Author

updated. Thank you both @imatiach-msft @jkbradley

@SparkQA
Copy link

SparkQA commented Mar 6, 2017

Test build #74007 has finished for PR 16883 at commit d1acfdb.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Copy link
Member

@jkbradley jkbradley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes look good, thanks! Only the Since annotation issue remains.

* Default: "error"
* @group param
*/
@Since("2.1.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this before, but these Since annotations should stay set to 1.6.0 since handleInvalid and the get/set methods were added in 1.6.0

setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)

/** @group getParam */
@Since("2.2.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto


/** @group setParam */
@Since("1.6.0")
@Since("2.2.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
@Since("2.2.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Signed-off-by: VinceShieh <vincent.xie@intel.com>
@SparkQA
Copy link

SparkQA commented Mar 7, 2017

Test build #74055 has finished for PR 16883 at commit c70e003.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@crackcell
Copy link

Nice work! I'm just planning to improve StringIndexer exactly the same way as yours. Now I can have a rest. :-)

@jkbradley
Copy link
Member

LGTM
Merging with master
Thanks a lot!

@asfgit asfgit closed this in 4a9034b Mar 7, 2017
@jkbradley
Copy link
Member

Btw, are you interested in updating the Python API too? https://issues.apache.org/jira/browse/SPARK-19852

@VinceShieh
Copy link
Author

Sure, I can work on that :) @jkbradley

@jkbradley
Copy link
Member

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants