[SPARK-18366][PYSPARK][ML] Add handleInvalid to Pyspark for QuantileDiscretizer and Bucketizer #15817
Changes from all commits
File: python/pyspark/ml/feature.py
@@ -125,10 +125,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
     """
     Maps a column of continuous features to a column of feature buckets.
 
-    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+    >>> df = spark.createDataFrame(values, ["values"])
     >>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
     ...     inputCol="values", outputCol="buckets")
-    >>> bucketed = bucketizer.transform(df).collect()
+    >>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
+    >>> len(bucketed)
+    6
     >>> bucketed[0].buckets
     0.0
     >>> bucketed[1].buckets
@@ -144,6 +147,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
     >>> loadedBucketizer = Bucketizer.load(bucketizerPath)
     >>> loadedBucketizer.getSplits() == bucketizer.getSplits()
     True
+    >>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
+    >>> len(bucketed)
+    4
 
     .. versionadded:: 1.4.0
     """
@@ -158,21 +164,28 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
                    "splits specified will be treated as errors.",
                    typeConverter=TypeConverters.toListFloat)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+                          "Options are skip (filter out rows with invalid values), " +
+                          "error (throw an error), or keep (keep invalid values in a special " +
+                          "additional bucket).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, splits=None, inputCol=None, outputCol=None):
+    def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, splits=None, inputCol=None, outputCol=None)
+        __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(Bucketizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
+        self._setDefault(handleInvalid="error")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, splits=None, inputCol=None, outputCol=None):
+    def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
Review comment (Contributor): Missing
Reply (Author): fixed
         """
-        setParams(self, splits=None, inputCol=None, outputCol=None)
+        setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this Bucketizer.
         """
         kwargs = self.setParams._input_kwargs
@@ -192,6 +205,20 @@ def getSplits(self):
         """
         return self.getOrDefault(self.splits)
 
+    @since("2.1.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("2.1.0")
+    def getHandleInvalid(self):
+        """
+        Gets the value of :py:attr:`handleInvalid` or its default value.
+        """
+        return self.getOrDefault(self.handleInvalid)
+
 
 @inherit_doc
 class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
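As a quick check of the accessors added above, a sketch; note that constructing a `Bucketizer` needs a live SparkContext, since `__init__` creates the JVM-side object:

```python
from pyspark.ml.feature import Bucketizer

b = Bucketizer(splits=[-float("inf"), 0.0, float("inf")],
               inputCol="values", outputCol="buckets")
b.getHandleInvalid()        # 'error' -- the default installed via _setDefault
b.setHandleInvalid("keep")  # _set returns self, so calls chain
b.getHandleInvalid()        # 'keep'
```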
@@ -1161,12 +1188,17 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
     :py:attr:`relativeError` parameter.
     The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.
 
-    >>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
+    >>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
+    >>> df = spark.createDataFrame(values, ["values"])
     >>> qds = QuantileDiscretizer(numBuckets=2,
-    ...     inputCol="values", outputCol="buckets", relativeError=0.01)
+    ...     inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
     >>> qds.getRelativeError()
     0.01
     >>> bucketizer = qds.fit(df)
+    >>> qds.setHandleInvalid("keep").fit(df).transform(df).count()
+    6
+    >>> qds.setHandleInvalid("skip").fit(df).transform(df).count()
+    4
     >>> splits = bucketizer.getSplits()
     >>> splits[0]
     -inf
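A short sketch of the QuantileDiscretizer flow with the new param, again assuming a SparkSession bound to `spark`:

```python
from pyspark.ml.feature import QuantileDiscretizer

values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
df = spark.createDataFrame(values, ["values"])

qds = QuantileDiscretizer(numBuckets=2, inputCol="values", outputCol="buckets",
                          handleInvalid="keep")
model = qds.fit(df)          # NaN values are ignored when the splits are computed
model.transform(df).count()  # 6 -- NaN rows go to the extra bucket
```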
@@ -1194,23 +1226,33 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
                           "Must be in the range [0, 1].",
                           typeConverter=TypeConverters.toFloat)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
+                          "Options are skip (filter out rows with invalid values), " +
+                          "error (throw an error), or keep (keep invalid values in a special " +
+                          "additional bucket).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+    def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+                 handleInvalid="error"):
         """
-        __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+        __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+                 handleInvalid="error")
         """
         super(QuantileDiscretizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
                                             self.uid)
-        self._setDefault(numBuckets=2, relativeError=0.001)
+        self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("2.0.0")
-    def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
+    def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
+                  handleInvalid="error"):
Review comment (Contributor): missing
Reply (Author): fixed
         """
-        setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
+        setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
+                  handleInvalid="error")
         Set the params for the QuantileDiscretizer
         """
         kwargs = self.setParams._input_kwargs
@@ -1244,13 +1286,28 @@ def getRelativeError(self):
         """
         return self.getOrDefault(self.relativeError)
 
+    @since("2.1.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
+    @since("2.1.0")
+    def getHandleInvalid(self):
+        """
+        Gets the value of :py:attr:`handleInvalid` or its default value.
+        """
+        return self.getOrDefault(self.handleInvalid)
+
     def _create_model(self, java_model):
         """
         Private method to convert the java_model to a Python model.
         """
         return Bucketizer(splits=list(java_model.getSplits()),
                           inputCol=self.getInputCol(),
-                          outputCol=self.getOutputCol())
+                          outputCol=self.getOutputCol(),
+                          handleInvalid=self.getHandleInvalid())
 
 
 @inherit_doc
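The `_create_model` change matters because the fitted model is itself a Python `Bucketizer`. A sketch of what it fixes, reusing the hypothetical `qds` and `df` from the sketch above:

```python
# Without the final hunk, the Bucketizer built by _create_model would silently
# revert to the default handleInvalid="error" even after setHandleInvalid("skip").
model = qds.setHandleInvalid("skip").fit(df)
model.getHandleInvalid()     # 'skip', copied over from the estimator
model.transform(df).count()  # 4 -- the two NaN rows are dropped at transform time
```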
Review discussion:

- Can we put the options in single quotes, e.g. "Options are 'skip' ..."?
- @techaddict I don't think you addressed this comment?
- To be fair, we don't have it quoted in the Scala param description, so if we want to make this change we should probably also change it on the Scala side, just for consistency's sake.
- Yeah, it's pretty minor. Maybe we can do it later in a follow-up.
- Cool. Since we've already cut RC1, it would be nice to have these params in sooner rather than later, and @techaddict seems to be a bit busy, so I've created a follow-up JIRA (SPARK-18628) for this so that we can move ahead with this as is.
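For context on why the quoting nit matters: the description string is user-facing via `explainParam`. A sketch, assuming the `bucketizer` from the doctests above:

```python
print(bucketizer.explainParam("handleInvalid"))
# Prints roughly:
# handleInvalid: how to handle invalid entries. Options are skip (filter out
# rows with invalid values), error (throw an error), or keep (keep invalid
# values in a special additional bucket). (default: error)
```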