diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py index 339e5d6af52a..cb869201ad30 100644 --- a/python/pyspark/ml/base.py +++ b/python/pyspark/ml/base.py @@ -59,6 +59,12 @@ def fit(self, dataset, params=None): return [self.fit(dataset, paramMap) for paramMap in params] elif isinstance(params, dict): if params: + if isinstance(params.keys()[0],str): + param_new = dict() + for param, value in params.items(): + p = getattr(self, param) + param_new.update({p: value}) + params=param_new return self.copy(params)._fit(dataset) else: return self._fit(dataset)