Skip to content

Commit a0bd273

Browse files
Louiszrhuaxingao
authored andcommitted
[SPARK-32092][ML][PYSPARK][FOLLOWUP] Fixed CrossValidatorModel.copy() to copy models instead of list
### What changes were proposed in this pull request? Fixed `CrossValidatorModel.copy()` so that it correctly calls `.copy()` on the models instead of lists of models. ### Why are the changes needed? `copy()` was first changed in #29445 . The issue was found in CI of #29524 and fixed. This PR introduces the exact same change so that `CrossValidatorModel.copy()` and its related tests are aligned in branch `master` and branch `branch-3.0`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Updated `test_copy` to make sure `copy()` is called on models instead of lists of models. Closes #29553 from Louiszr/fix-cv-copy. Authored-by: Louiszr <zxhst14@gmail.com> Signed-off-by: Huaxin Gao <huaxing@us.ibm.com>
1 parent 58f87b3 commit a0bd273

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

python/pyspark/ml/tests/test_tuning.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def test_copy(self):
127127
'foo',
128128
"Changing the original avgMetrics should not affect the copied model"
129129
)
130-
cvModel.subModels[0] = 'foo'
130+
cvModel.subModels[0][0].getInducedError = lambda: 'foo'
131131
self.assertNotEqual(
132-
cvModelCopied.subModels[0],
132+
cvModelCopied.subModels[0][0].getInducedError(),
133133
'foo',
134134
"Changing the original subModels should not affect the copied model"
135135
)
@@ -852,9 +852,9 @@ def test_copy(self):
852852
'foo',
853853
"Changing the original validationMetrics should not affect the copied model"
854854
)
855-
tvsModel.subModels[0] = 'foo'
855+
tvsModel.subModels[0].getInducedError = lambda: 'foo'
856856
self.assertNotEqual(
857-
tvsModelCopied.subModels[0],
857+
tvsModelCopied.subModels[0].getInducedError(),
858858
'foo',
859859
"Changing the original subModels should not affect the copied model"
860860
)

python/pyspark/ml/tuning.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,10 @@ def copy(self, extra=None):
535535
extra = dict()
536536
bestModel = self.bestModel.copy(extra)
537537
avgMetrics = list(self.avgMetrics)
538-
subModels = [model.copy() for model in self.subModels]
538+
subModels = [
539+
[sub_model.copy() for sub_model in fold_sub_models]
540+
for fold_sub_models in self.subModels
541+
]
539542
return self._copyValues(CrossValidatorModel(bestModel, avgMetrics, subModels), extra=extra)
540543

541544
@since("2.3.0")

0 commit comments

Comments
 (0)