Skip to content

Commit 2b588ef

Browse files
committed
[SPARK-23871][ML][PYTHON]add python api for VectorAssembler handleInvalid
1 parent 6ab134c commit 2b588ef

File tree

1 file changed

+36
-5
lines changed

1 file changed

+36
-5
lines changed

python/pyspark/ml/feature.py

Lines changed: 36 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None):
27012701

27022702

27032703
@inherit_doc
2704-
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
2704+
class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, HasHandleInvalid, JavaMLReadable,
2705+
JavaMLWritable):
27052706
"""
27062707
A feature transformer that merges multiple columns into a vector column.
27072708
@@ -2719,14 +2720,44 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
27192720
>>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
27202721
>>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
27212722
True
2723+
>>> dfWithNullsAndNaNs = spark.createDataFrame(
2724+
... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
2725+
>>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
2726+
... handleInvalid="keep")
2727+
>>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
2728+
+---+---+----+-------------+
2729+
| a| b| c| features|
2730+
+---+---+----+-------------+
2731+
|1.0|2.0|null|[1.0,2.0,NaN]|
2732+
|3.0|NaN| 4.0|[3.0,NaN,4.0]|
2733+
|5.0|6.0| 7.0|[5.0,6.0,7.0]|
2734+
+---+---+----+-------------+
2735+
...
2736+
>>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
2737+
+---+---+---+-------------+
2738+
| a| b| c| features|
2739+
+---+---+---+-------------+
2740+
|5.0|6.0|7.0|[5.0,6.0,7.0]|
2741+
+---+---+---+-------------+
2742+
...
27222743
27232744
.. versionadded:: 1.4.0
27242745
"""
27252746

2747+
handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data (NULL " +
2748+
"values). Options are 'skip' (filter out rows with invalid data), " +
2749+
"'error' (throw an error), or 'keep' (return relevant number of NaN in " +
2750+
"the output). Column lengths are taken from the size of ML Attribute " +
2751+
"Group, which can be set using `VectorSizeHint` in a pipeline before " +
2752+
"`VectorAssembler`. Column lengths can also be inferred from first " +
2753+
"rows of the data since it is safe to do so but only in case of " +
2754+
"'error' or 'skip').",
2755+
typeConverter=TypeConverters.toString)
2756+
27262757
@keyword_only
2727-
def __init__(self, inputCols=None, outputCol=None):
2758+
def __init__(self, inputCols=None, outputCol=None, handleInvalid="error"):
27282759
"""
2729-
__init__(self, inputCols=None, outputCol=None)
2760+
__init__(self, inputCols=None, outputCol=None, handleInvalid="error")
27302761
"""
27312762
super(VectorAssembler, self).__init__()
27322763
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
@@ -2735,9 +2766,9 @@ def __init__(self, inputCols=None, outputCol=None):
27352766

27362767
@keyword_only
27372768
@since("1.4.0")
2738-
def setParams(self, inputCols=None, outputCol=None):
2769+
def setParams(self, inputCols=None, outputCol=None, handleInvalid="error"):
27392770
"""
2740-
setParams(self, inputCols=None, outputCol=None)
2771+
setParams(self, inputCols=None, outputCol=None, handleInvalid="error")
27412772
Sets params for this VectorAssembler.
27422773
"""
27432774
kwargs = self._input_kwargs

0 commit comments

Comments (0)