@@ -2701,7 +2701,8 @@ def setParams(self, inputCol=None, outputCol=None):
27012701
27022702
27032703@inherit_doc
2704- class VectorAssembler (JavaTransformer , HasInputCols , HasOutputCol , JavaMLReadable , JavaMLWritable ):
2704+ class VectorAssembler (JavaTransformer , HasInputCols , HasOutputCol , HasHandleInvalid , JavaMLReadable ,
2705+ JavaMLWritable ):
27052706 """
27062707 A feature transformer that merges multiple columns into a vector column.
27072708
@@ -2719,14 +2720,44 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
27192720 >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
27202721 >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
27212722 True
2723+ >>> dfWithNullsAndNaNs = spark.createDataFrame(
2724+ ... [(1.0, 2.0, None), (3.0, float("nan"), 4.0), (5.0, 6.0, 7.0)], ["a", "b", "c"])
2725+ >>> vecAssembler2 = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features",
2726+ ... handleInvalid="keep")
2727+ >>> vecAssembler2.transform(dfWithNullsAndNaNs).show()
2728+ +---+---+----+-------------+
2729+ | a| b| c| features|
2730+ +---+---+----+-------------+
2731+ |1.0|2.0|null|[1.0,2.0,NaN]|
2732+ |3.0|NaN| 4.0|[3.0,NaN,4.0]|
2733+ |5.0|6.0| 7.0|[5.0,6.0,7.0]|
2734+ +---+---+----+-------------+
2735+ ...
2736+ >>> vecAssembler2.setParams(handleInvalid="skip").transform(dfWithNullsAndNaNs).show()
2737+ +---+---+---+-------------+
2738+ | a| b| c| features|
2739+ +---+---+---+-------------+
2740+ |5.0|6.0|7.0|[5.0,6.0,7.0]|
2741+ +---+---+---+-------------+
2742+ ...
27222743
27232744 .. versionadded:: 1.4.0
27242745 """
27252746
2747+ handleInvalid = Param (Params ._dummy (), "handleInvalid" , "How to handle invalid data (NULL " +
2748+ "values). Options are 'skip' (filter out rows with invalid data), " +
2749+ "'error' (throw an error), or 'keep' (return relevant number of NaN in " +
2750+ "the output). Column lengths are taken from the size of ML Attribute " +
2751+ "Group, which can be set using `VectorSizeHint` in a pipeline before " +
2752+ "`VectorAssembler`. Column lengths can also be inferred from first " +
2753+ "rows of the data since it is safe to do so but only in case of " +
2754+ "'error' or 'skip')." ,
2755+ typeConverter = TypeConverters .toString )
2756+
27262757 @keyword_only
2727- def __init__ (self , inputCols = None , outputCol = None ):
2758+ def __init__ (self , inputCols = None , outputCol = None , handleInvalid = "error" ):
27282759 """
2729- __init__(self, inputCols=None, outputCol=None)
2760+ __init__(self, inputCols=None, outputCol=None, handleInvalid="error" )
27302761 """
27312762 super (VectorAssembler , self ).__init__ ()
27322763 self ._java_obj = self ._new_java_obj ("org.apache.spark.ml.feature.VectorAssembler" , self .uid )
@@ -2735,9 +2766,9 @@ def __init__(self, inputCols=None, outputCol=None):
27352766
27362767 @keyword_only
27372768 @since ("1.4.0" )
2738- def setParams (self , inputCols = None , outputCol = None ):
2769+ def setParams (self , inputCols = None , outputCol = None , handleInvalid = "error" ):
27392770 """
2740- setParams(self, inputCols=None, outputCol=None)
2771+ setParams(self, inputCols=None, outputCol=None, handleInvalid="error" )
27412772 Sets params for this VectorAssembler.
27422773 """
27432774 kwargs = self ._input_kwargs
0 commit comments