diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 40d830062581..dbcf2cf691aa 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -357,7 +357,7 @@ def getOrDefault(self, param): return self._defaultParamMap[param] @since("1.4.0") - def extractParamMap(self, extra=None): + def extractParamMap(self, extra=None, default=False): """ Extracts the embedded default param values and user-supplied values, and then merges them with extra values from input into @@ -366,17 +366,19 @@ def extractParamMap(self, extra=None): user-supplied values < extra. :param extra: extra param values + :param default: if just copy the default param map :return: merged param map """ - if extra is None: + if extra is None and not default: extra = dict() paramMap = self._defaultParamMap.copy() - paramMap.update(self._paramMap) - paramMap.update(extra) + if not default: + paramMap.update(self._paramMap) + paramMap.update(extra) return paramMap @since("1.4.0") - def copy(self, extra=None): + def copy(self, extra=None, default=False): """ Creates a copy of this instance with the same uid and some extra params. The default implementation creates a @@ -386,13 +388,14 @@ def copy(self, extra=None): is not sufficient. :param extra: Extra parameters to copy to the new instance + :param default: if just copy the default param map :return: Copy of this instance """ if extra is None: extra = dict() that = copy.copy(self) that._paramMap = {} - return self._copyValues(that, extra) + return self._copyValues(that, extra, default) def _shouldOwn(self, param): """ @@ -463,18 +466,19 @@ def _setDefault(self, **kwargs): self._defaultParamMap[p] = value return self - def _copyValues(self, to, extra=None): + def _copyValues(self, to, extra=None, default=False): """ Copies param values from this instance to another instance for params shared by them. :param to: the target instance :param extra: extra params to be copied + :param default: if just copy the default param map :return: the target instance with param values copied """ if extra is None: extra = dict() - paramMap = self.extractParamMap(extra) + paramMap = self.extractParamMap(extra, default) for p in self.params: if p in paramMap and to.hasParam(p.name): to._set(**{p.name: paramMap[p]})