Skip to content

Commit 46840fb

Browse files
committed
update wrappers
1 parent b6db1ed commit 46840fb

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

python/pyspark/ml/wrapper.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,20 +69,13 @@ def _transfer_params_to_java(self, java_obj):
6969
java_param = java_obj.getParam(param.name)
7070
java_obj.set(java_param.w(value))
7171

72-
def _empty_java_param_map(self):
72+
@staticmethod
73+
def _empty_java_param_map():
7374
"""
7475
Returns an empty Java ParamMap reference.
7576
"""
7677
return _jvm().org.apache.spark.ml.param.ParamMap()
7778

78-
def _create_java_param_map(self, params, java_obj):
79-
paramMap = self._empty_java_param_map()
80-
for param, value in params.items():
81-
if param.parent is self:
82-
java_param = java_obj.getParam(param.name)
83-
paramMap.put(java_param.w(value))
84-
return paramMap
85-
8679

8780
@inherit_doc
8881
class JavaEstimator(Estimator, JavaWrapper):
@@ -109,7 +102,7 @@ def _fit_java(self, dataset):
109102
"""
110103
java_obj = self._java_obj()
111104
self._transfer_params_to_java(java_obj)
112-
return java_obj.fit(dataset._jdf, self._empty_java_param_map())
105+
return java_obj.fit(dataset._jdf)
113106

114107
def _fit(self, dataset):
115108
java_model = self._fit_java(dataset)
@@ -161,7 +154,7 @@ def copy(self, extra={}):
161154
:return: Copy of this instance
162155
"""
163156
that = Params.copy(self, extra)
164-
that._java_model = that._java_model.copy()
157+
that._java_model = self._java_model.copy(self._empty_java_param_map())
165158
return that
166159

167160

@@ -182,4 +175,4 @@ def _evaluate(self, dataset):
182175
"""
183176
java_obj = self._java_obj()
184177
self._transfer_params_to_java(java_obj)
185-
return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())
178+
return java_obj.evaluate(dataset._jdf)

0 commit comments

Comments
 (0)