Skip to content

Commit 9d45ec4

Browse files
sethahmengxr
authored andcommitted
[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error
Pyspark Params class has a method `hasParam(paramName)` which returns `True` if the class has a parameter by that name, but throws an `AttributeError` otherwise. There is not currently a way of getting a Boolean to indicate if a class has a parameter. With Spark 2.0 we could modify the existing behavior of `hasParam` or add an additional method with this functionality. In Python: ```python from pyspark.ml.classification import NaiveBayes nb = NaiveBayes() print nb.hasParam("smoothing") print nb.hasParam("notAParam") ``` produces: > True > AttributeError: 'NaiveBayes' object has no attribute 'notAParam' However, in Scala: ```scala import org.apache.spark.ml.classification.NaiveBayes val nb = new NaiveBayes() nb.hasParam("smoothing") nb.hasParam("notAParam") ``` produces: > true > false cc holdenk Author: sethah <seth.hendrickson16@gmail.com> Closes #10962 from sethah/SPARK-13047. (cherry picked from commit b354673) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 91a5ca5 commit 9d45ec4

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

python/pyspark/ml/param/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,11 @@ def hasParam(self, paramName):
156156
Tests whether this instance contains a param with a given
157157
(string) name.
158158
"""
159-
param = self._resolveParam(paramName)
160-
return param in self.params
159+
if isinstance(paramName, str):
160+
p = getattr(self, paramName, None)
161+
return isinstance(p, Param)
162+
else:
163+
raise TypeError("hasParam(): paramName must be a string")
161164

162165
@since("1.4.0")
163166
def getOrDefault(self, param):

python/pyspark/ml/tests.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ def test_param(self):
170170
self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
171171
self.assertTrue(maxIter.parent == testParams.uid)
172172

173+
def test_hasparam(self):
174+
testParams = TestParams()
175+
self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
176+
self.assertFalse(testParams.hasParam("notAParameter"))
177+
173178
def test_params(self):
174179
testParams = TestParams()
175180
maxIter = testParams.maxIter
@@ -179,7 +184,7 @@ def test_params(self):
179184
params = testParams.params
180185
self.assertEqual(params, [inputCol, maxIter, seed])
181186

182-
self.assertTrue(testParams.hasParam(maxIter))
187+
self.assertTrue(testParams.hasParam(maxIter.name))
183188
self.assertTrue(testParams.hasDefault(maxIter))
184189
self.assertFalse(testParams.isSet(maxIter))
185190
self.assertTrue(testParams.isDefined(maxIter))
@@ -188,7 +193,7 @@ def test_params(self):
188193
self.assertTrue(testParams.isSet(maxIter))
189194
self.assertEqual(testParams.getMaxIter(), 100)
190195

191-
self.assertTrue(testParams.hasParam(inputCol))
196+
self.assertTrue(testParams.hasParam(inputCol.name))
192197
self.assertFalse(testParams.hasDefault(inputCol))
193198
self.assertFalse(testParams.isSet(inputCol))
194199
self.assertFalse(testParams.isDefined(inputCol))

0 commit comments

Comments
 (0)