diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 64b21caa616ec..a33c3e79453e1 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM formula = Param(Params._dummy(), "formula", "R model formula", typeConverter=TypeConverters.toString) + forceIndexLabel = Param(Params._dummy(), "forceIndexLabel", + "Force to index label whether it is numeric or string", + typeConverter=TypeConverters.toBoolean) + @keyword_only - def __init__(self, formula=None, featuresCol="features", labelCol="label"): + def __init__(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - __init__(self, formula=None, featuresCol="features", labelCol="label") + __init__(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) """ super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) + self._setDefault(forceIndexLabel=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.5.0") - def setParams(self, formula=None, featuresCol="features", labelCol="label"): + def setParams(self, formula=None, featuresCol="features", labelCol="label", + forceIndexLabel=False): """ - setParams(self, formula=None, featuresCol="features", labelCol="label") + setParams(self, formula=None, featuresCol="features", labelCol="label", \ + forceIndexLabel=False) Sets params for RFormula. """ kwargs = self.setParams._input_kwargs @@ -2528,6 +2537,20 @@ def getFormula(self): """ return self.getOrDefault(self.formula) + @since("2.1.0") + def setForceIndexLabel(self, value): + """ + Sets the value of :py:attr:`forceIndexLabel`. + """ + return self._set(forceIndexLabel=value) + + @since("2.1.0") + def getForceIndexLabel(self): + """ + Gets the value of :py:attr:`forceIndexLabel`. + """ + return self.getOrDefault(self.forceIndexLabel) + def _create_model(self, java_model): return RFormulaModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index e233549850888..9d46cc3b4ae64 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -477,6 +477,22 @@ def test_count_vectorizer_with_binary(self): feature, expected = r self.assertEqual(feature, expected) + def test_rformula_force_index_label(self): + df = self.spark.createDataFrame([ + (1.0, 1.0, "a"), + (0.0, 2.0, "b"), + (1.0, 0.0, "a")], ["y", "x", "s"]) + # Does not index label by default since it's numeric type. + rf = RFormula(formula="y ~ x + s") + model = rf.fit(df) + transformedDF = model.transform(df) + self.assertEqual(transformedDF.head().label, 1.0) + # Force to index label. + rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True) + model2 = rf2.fit(df) + transformedDF2 = model2.transform(df) + self.assertEqual(transformedDF2.head().label, 0.0) + class HasInducedError(Params):