Skip to content

Commit 53bd9d2

Browse files
committed
change setting of Param to ignore None values
1 parent 37b93f5 commit 53bd9d2

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

python/pyspark/ml/param/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,13 +425,13 @@ def _set(self, **kwargs):
425425
Sets user-supplied params.
426426
"""
427427
for param, value in kwargs.items():
428-
p = getattr(self, param)
429428
if value is not None:
429+
p = getattr(self, param)
430430
try:
431431
value = p.typeConverter(value)
432432
except TypeError as e:
433433
raise TypeError('Invalid param value given for param "%s". %s' % (p.name, e))
434-
self._paramMap[p] = value
434+
self._paramMap[p] = value
435435
return self
436436

437437
def _clear(self, param):

python/pyspark/ml/tests.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,14 +247,14 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed):
247247
A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
248248
"""
249249
@keyword_only
250-
def __init__(self, seed=None):
250+
def __init__(self, maxIter=None, inputCol=None, seed=None):
251251
super(TestParams, self).__init__()
252252
self._setDefault(maxIter=10)
253253
kwargs = self.__init__._input_kwargs
254254
self.setParams(**kwargs)
255255

256256
@keyword_only
257-
def setParams(self, seed=None):
257+
def setParams(self, maxIter=None, inputCol=None, seed=None):
258258
"""
259259
setParams(self, seed=None)
260260
Sets params for this test.
@@ -389,6 +389,16 @@ def test_word2vec_param(self):
389389
# Check windowSize is set properly
390390
self.assertEqual(model.getWindowSize(), 6)
391391

392+
def test_param_value_None(self):
393+
tp = TestParams()
394+
self.assertFalse(tp.isSet(tp.inputCol), "inputCol is not set initially")
395+
tp.setParams(inputCol=None)
396+
self.assertFalse(tp.isSet(tp.inputCol), "Value of None should not change param")
397+
tp.setParams(inputCol="input")
398+
self.assertTrue(tp.isSet(tp.inputCol), "inputCol should now be set")
399+
tp.setParams(inputCol=None)
400+
self.assertTrue(tp.isSet(tp.inputCol), "inputCol should still be set")
401+
392402

393403
class FeatureTests(SparkSessionTestCase):
394404

0 commit comments

Comments
 (0)