-
Notifications
You must be signed in to change notification settings - Fork 29k
[Spark-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as OneVsRest #18428
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Spark-21221][ML] CrossValidator and TrainValidationSplit Persist Nested Estimators such as OneVsRest #18428
Conversation
…rsist nested estimators such as OneVsRest.
|
@jkbradley @thunterdb Could you please review this? |
|
Test build #78664 has finished for PR 18428 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR! There's one catch we may be able to address later, but overall, I think my comments are all small.
| .setClassifier(new LogisticRegression) | ||
| val evaluator = new BinaryClassificationEvaluator() | ||
| .setMetricName("areaUnderPR") // not default metric | ||
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: remove extra newline
| val ova = new OneVsRest() | ||
| .setClassifier(new LogisticRegression) | ||
| val evaluator = new BinaryClassificationEvaluator() | ||
| .setMetricName("areaUnderPR") // not default metric |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed for this unit test?
|
|
||
| test("read/write: CrossValidator with nested estimator") { | ||
| val ova = new OneVsRest() | ||
| .setClassifier(new LogisticRegression) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style: fix indentation
| val classifier1 = new LogisticRegression().setRegParam(2.0) | ||
| val classifier2 = new LogisticRegression().setRegParam(3.0) | ||
| val paramMaps = new ParamGridBuilder() | ||
| .addGrid(ova.classifier, Array(classifier1, classifier2)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add comment that it is important to test Param values which inherit from Params.
| cv2.getEstimator match { | ||
| case ova2: OneVsRest => | ||
| assert(ova.uid === ova2.uid) | ||
| assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check type of classifier before casting
| Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) | ||
| v match { | ||
| case writeableObj: MLWritable => | ||
| numParamsNotJson += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move this down 1 line to index from 0
| v match { | ||
| case writeableObj: MLWritable => | ||
| numParamsNotJson += 1 | ||
| val paramPath = new Path(path, "param" + p.name + numParamsNotJson).toString |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about changing the prefix "param" -> "epm_"?
| param -> value | ||
| } else { | ||
| val path = param.jsonDecode(pInfo("value")).toString | ||
| val value = DefaultParamsReader.loadParamsInstance[MLWritable](path, sc) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is OK with me for now since it will address all cases I've seen. In the future, it'd be great to make this more general by allowing it to read any MLReadable type (not just DefaultParamsReadable). I'll comment in the save() section above about this too.
| paramMap.toSeq.map { case ParamPair(p, v) => | ||
| Map("parent" -> p.parent, "name" -> p.name, "value" -> p.jsonEncode(v)) | ||
| v match { | ||
| case writeableObj: MLWritable => |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Per my comment below in the load() section, this should be restricted to DefaultParamsWritable for now. Could you please do so, but also add a check which throws an error if v is MLWritable but not DefaultParamsWritable?
…sistence. Implemented python persistence for meta-algorithms. OneVsRest overrides necessary persistence functions. Code still has prints and comments that need to be cleaned up.
…d TrainValidationSplit now persist estimators in both Scala and Python.
|
Test build #78935 has finished for PR 18428 at commit
|
|
Test build #78936 has finished for PR 18428 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! The fixes for my original comments look good.
I did a pass over the new parts as well. My main question is whether we can eliminate more of the duplicated code.
I may be out of touch for a week, so please ping others as well. E.g. @yinxusen who worked on this long ago or @thunterdb or @sueann
| classifier match { | ||
| case lr: LogisticRegression => | ||
| assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter | ||
| === lr.asInstanceOf[LogisticRegression].getMaxIter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lr is already of type LogisticRegression (no need to cast)
| classifier match { | ||
| case lr: LogisticRegression => | ||
| assert(ova.getClassifier.asInstanceOf[LogisticRegression].getMaxIter | ||
| === lr.asInstanceOf[LogisticRegression].getMaxIter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lr is already of type LogisticRegression (no need to cast)
| * Assert sequences of estimatorParamMaps are identical. | ||
| * Params must be simple types comparable with `===`. | ||
| */ | ||
| def compareParamMaps(pMaps: Array[ParamMap], pMaps2: Array[ParamMap]): Unit = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is the same as in CrossValidatorSuite, then can you please move them to a shared file (maybe ValidatorParamsSuite)?
| .setEstimator(ova) | ||
| .setEvaluator(evaluator) | ||
| .setNumFolds(20) | ||
| .setEstimatorParamMaps(paramMaps) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please compare the original + the loaded estimatorParamMaps
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for the TrainValidationSplitSuite
python/pyspark/ml/classification.py
Outdated
|
|
||
| def _make_java_param_pair(self, param, value): | ||
| """ | ||
| Makes a Java parm pair. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct typo: parm -> param (and in original please)
python/pyspark/ml/tests.py
Outdated
| loadedModel = CrossValidatorModel.load(cvModelPath) | ||
| self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid) | ||
|
|
||
| def test_save_load_nested_stimator(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix typo "stimator"
python/pyspark/ml/tuning.py
Outdated
| """ | ||
| return self.getOrDefault(self.evaluator) | ||
|
|
||
| def getEvaluator(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate of above method?
python/pyspark/ml/tuning.py
Outdated
| return JavaMLWriter(self) | ||
|
|
||
| @since("2.3.0") | ||
| def save(self, path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
no need to copy this here; it can use the one in MLWritable
python/pyspark/ml/tuning.py
Outdated
| return JavaMLReader(cls) | ||
|
|
||
| @classmethod | ||
| @since("2.3.0") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't add since annotations to private methods
python/pyspark/ml/wrapper.py
Outdated
| param = self._resolveParam(param) | ||
| java_param = self._java_obj.getParam(param.name) | ||
| java_value = _py2java(sc, value) | ||
| if isinstance(value, Estimator) or isinstance(value, Model): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
check for instances of JavaParams instead?
…eVsRest. Does not work because the make java param pair function in wrapper.py does not recognize the uid set in self._java_obj in the OneVsRest constructor.
|
Test build #79297 has finished for PR 18428 at commit
|
|
Test build #79299 has finished for PR 18428 at commit
|
|
I couldn't think of a great way to reduce code duplication between JavaWrapper and OneVsRest. One thing I realized: This make break backwards compatibility. Let's fix that. We unfortunately don't have a good way to test backwards compatibility, so I'd recommend testing manually (saving a model before your patch and loading it back after your patch). |
| val param = est.getParam(pInfo("name")) | ||
| val value = param.jsonDecode(pInfo("value")) | ||
| param -> value | ||
| if (pInfo("isJson").toBoolean.booleanValue()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think fixing backwards compatibility will just mean testing for whether the field "isJson" is present here
|
Also, can you please add "OneVsRest" to the PR and JIRA titles since this touches that class? |
|
Test build #79568 has finished for PR 18428 at commit
|
| val value = param.jsonDecode(pInfo("value")) | ||
| param -> value | ||
| if (!pInfo.contains("isJson") || | ||
| (pInfo.contains("isJson") && pInfo("isJson").toBoolean.booleanValue())) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
style nit: indent line 202 +1 space
Also, could you please add a comment saying that SPARK-21221 introduced the "isJson" field?
|
Rats, one more thing: We need to use relative paths, not absolute ones, when we put paths in the persisted file. Could you please add a unit test which checks this, perhaps by saving a model, moving it, and then loading it? |
|
Test build #79574 has finished for PR 18428 at commit
|
…r as a relative path instead of absolute path.
|
Test build #79596 has finished for PR 18428 at commit
|
jkbradley
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update! Just the style nit remains
| Files.move(subDirWithUid.toPath, newSubdirWithUid.toPath, StandardCopyOption.ATOMIC_MOVE) | ||
|
|
||
| val loader = instance.getClass.getMethod("read") | ||
| .invoke(null).asInstanceOf[MLReader[T]] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix indentation
|
Test build #79622 has finished for PR 18428 at commit
|
|
Test build #79623 has finished for PR 18428 at commit
|
|
LGTM |
What changes were proposed in this pull request?
Added functionality for CrossValidator and TrainValidationSplit to persist nested estimators such as OneVsRest. Also added CrossValidator and TrainValidation split persistence to pyspark.
How was this patch tested?
Performed both cross validation and train validation split with a one vs. rest estimator and tested read/write functionality of the estimator parameter maps required by these meta-algorithms.