85 changes: 71 additions & 14 deletions python/pyspark/ml/feature.py
@@ -125,10 +125,13 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"""
Maps a column of continuous features to a column of feature buckets.

>>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
>>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
>>> df = spark.createDataFrame(values, ["values"])
>>> bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
... inputCol="values", outputCol="buckets")
>>> bucketed = bucketizer.transform(df).collect()
>>> bucketed = bucketizer.setHandleInvalid("keep").transform(df).collect()
>>> len(bucketed)
6
>>> bucketed[0].buckets
0.0
>>> bucketed[1].buckets
@@ -144,6 +147,9 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
>>> loadedBucketizer = Bucketizer.load(bucketizerPath)
>>> loadedBucketizer.getSplits() == bucketizer.getSplits()
True
>>> bucketed = bucketizer.setHandleInvalid("skip").transform(df).collect()
>>> len(bucketed)
4

.. versionadded:: 1.4.0
"""
@@ -158,21 +164,28 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"splits specified will be treated as errors.",
typeConverter=TypeConverters.toListFloat)

handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
"Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special " +
"additional bucket).",
typeConverter=TypeConverters.toString)

Contributor: can we put the options in single quotes, e.g. "Options are 'skip' ..."

Contributor: @techaddict I don't think you addressed this comment?

Contributor: To be fair, we don't have it quoted in the Scala param description, so if we want to make this change we should probably also change it on the Scala side, just for consistency's sake.

Contributor: Yeah, it's pretty minor. Maybe we can do it later in a follow-up.

Contributor: Cool. Since we've already cut RC1, it would be nice to have these params in sooner rather than later, and @techaddict seems to be a bit busy, so I've created a follow-up JIRA (SPARK-18628) for this so that we can move ahead with this as is.
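
For reference, the reviewer's suggestion would read roughly as below. This is a hypothetical sketch of the re-worded description (the follow-up is tracked in SPARK-18628), not code from this PR, and the Scala description would need the same change:

from pyspark.ml.param import Param, Params, TypeConverters

# Hypothetical re-wording with the option names single-quoted, per the review;
# a drop-in replacement for the handleInvalid class attribute above.
handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
                      "Options are 'skip' (filter out rows with invalid values), " +
                      "'error' (throw an error), or 'keep' (keep invalid values in a " +
                      "special additional bucket).",
                      typeConverter=TypeConverters.toString)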

@keyword_only
def __init__(self, splits=None, inputCol=None, outputCol=None):
def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
"""
__init__(self, splits=None, inputCol=None, outputCol=None)
__init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
"""
super(Bucketizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
self._setDefault(handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("1.4.0")
def setParams(self, splits=None, inputCol=None, outputCol=None):
def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error"):
Contributor: Missing handleInvalid in doc string below.

Contributor Author: fixed

"""
setParams(self, splits=None, inputCol=None, outputCol=None)
setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error")
Sets params for this Bucketizer.
"""
kwargs = self.setParams._input_kwargs
Expand All @@ -192,6 +205,20 @@ def getSplits(self):
"""
return self.getOrDefault(self.splits)

@since("2.1.0")
def setHandleInvalid(self, value):
"""
Sets the value of :py:attr:`handleInvalid`.
"""
return self._set(handleInvalid=value)

@since("2.1.0")
def getHandleInvalid(self):
"""
Gets the value of :py:attr:`handleInvalid` or its default value.
"""
return self.getOrDefault(self.handleInvalid)
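
To exercise the new param outside the doctest, here is a minimal standalone sketch of the three handleInvalid modes; it assumes a local SparkSession, and the data and splits are illustrative, mirroring the doctest above:

from pyspark.sql import SparkSession
from pyspark.ml.feature import Bucketizer

spark = SparkSession.builder.master("local[2]").appName("bucketizer-demo").getOrCreate()

# Two NaN rows make each handleInvalid policy observable.
df = spark.createDataFrame(
    [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)], ["values"])
bucketizer = Bucketizer(splits=[-float("inf"), 0.5, 1.4, float("inf")],
                        inputCol="values", outputCol="buckets")

# "keep": NaN rows land in an extra bucket one past the regular ones.
print(bucketizer.setHandleInvalid("keep").transform(df).count())  # 6
# "skip": NaN rows are filtered out before bucketing.
print(bucketizer.setHandleInvalid("skip").transform(df).count())  # 4
# "error" (the default): transform raises on the first invalid row.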


@inherit_doc
class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
@@ -1161,12 +1188,17 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
:py:attr:`relativeError` parameter.
The lower and upper bin bounds will be `-Infinity` and `+Infinity`, covering all real values.

>>> df = spark.createDataFrame([(0.1,), (0.4,), (1.2,), (1.5,)], ["values"])
>>> values = [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)]
>>> df = spark.createDataFrame(values, ["values"])
>>> qds = QuantileDiscretizer(numBuckets=2,
... inputCol="values", outputCol="buckets", relativeError=0.01)
... inputCol="values", outputCol="buckets", relativeError=0.01, handleInvalid="error")
>>> qds.getRelativeError()
0.01
>>> bucketizer = qds.fit(df)
>>> qds.setHandleInvalid("keep").fit(df).transform(df).count()
6
>>> qds.setHandleInvalid("skip").fit(df).transform(df).count()
4
>>> splits = bucketizer.getSplits()
>>> splits[0]
-inf
@@ -1194,23 +1226,33 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab
"Must be in the range [0, 1].",
typeConverter=TypeConverters.toFloat)

handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. " +
"Options are skip (filter out rows with invalid values), " +
"error (throw an error), or keep (keep invalid values in a special " +
"additional bucket).",
typeConverter=TypeConverters.toString)

@keyword_only
def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
handleInvalid="error"):
"""
__init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
__init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
handleInvalid="error")
"""
super(QuantileDiscretizer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer",
self.uid)
self._setDefault(numBuckets=2, relativeError=0.001)
self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)

@keyword_only
@since("2.0.0")
def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001):
def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001,
handleInvalid="error"):
Contributor: Missing handleInvalid in doc string below.

Contributor Author: fixed

"""
setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001)
setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0.001, \
handleInvalid="error")
Set the params for the QuantileDiscretizer
"""
kwargs = self.setParams._input_kwargs
@@ -1244,13 +1286,28 @@ def getRelativeError(self):
"""
return self.getOrDefault(self.relativeError)

@since("2.1.0")
def setHandleInvalid(self, value):
"""
Sets the value of :py:attr:`handleInvalid`.
"""
return self._set(handleInvalid=value)

@since("2.1.0")
def getHandleInvalid(self):
"""
Gets the value of :py:attr:`handleInvalid` or its default value.
"""
return self.getOrDefault(self.handleInvalid)

def _create_model(self, java_model):
"""
Private method to convert the java_model to a Python model.
"""
return Bucketizer(splits=list(java_model.getSplits()),
inputCol=self.getInputCol(),
outputCol=self.getOutputCol())
outputCol=self.getOutputCol(),
handleInvalid=self.getHandleInvalid())
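
Because _create_model now forwards handleInvalid, the Bucketizer returned by fit() inherits the discretizer's setting. A minimal sketch of that round trip (same illustrative data as the doctest, assuming a live SparkSession named spark):

from pyspark.ml.feature import QuantileDiscretizer

df = spark.createDataFrame(
    [(0.1,), (0.4,), (1.2,), (1.5,), (float("nan"),), (float("nan"),)], ["values"])
qds = QuantileDiscretizer(numBuckets=2, inputCol="values", outputCol="buckets",
                          relativeError=0.01, handleInvalid="keep")

# The fitted model is a Bucketizer; _create_model copies handleInvalid onto it.
model = qds.fit(df)
print(model.getHandleInvalid())     # 'keep'
print(model.transform(df).count())  # 6: NaN rows kept in the extra bucket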


@inherit_doc