From b8d022a577d20ac213fb263f973172c23d2f2ee7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Apr 2016 09:12:15 -0700 Subject: [PATCH 1/5] [SPARK-14472] Made JavaCallable a base class to JavaWrapper to define _java_obj and creating Java objects --- python/pyspark/ml/evaluation.py | 2 +- python/pyspark/ml/wrapper.py | 74 +++++++++++++++------------------ 2 files changed, 34 insertions(+), 42 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index c9b95b3bf45d9..8bda32146e9c8 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -81,7 +81,7 @@ def isLargerBetter(self): @inherit_doc -class JavaEvaluator(Evaluator, JavaWrapper): +class JavaEvaluator(JavaWrapper, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index a2cf2296fbe09..0314d6c2cf864 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -25,29 +25,32 @@ from pyspark.mllib.common import inherit_doc, _java2py, _py2java -@inherit_doc -class JavaWrapper(Params): +class JavaCallable(object): """ - Utility class to help create wrapper classes from Java/Scala - implementations of pipeline components. + Wrapper class for a Java companion object """ + def __init__(self, java_obj=None): + super(JavaCallable, self).__init__() + self._java_obj = java_obj - __metaclass__ = ABCMeta - - def __init__(self): + @classmethod + def createFromClassName(cls, java_class, *args): """ - Initialize the wrapped java object to None + Construct this object from given Java classname and arguments """ - super(JavaWrapper, self).__init__() - #: The wrapped Java companion object. Subclasses should initialize - #: it properly. The param values in the Java object should be - #: synced with the Python wrapper in fit/transform/evaluate/copy. - self._java_obj = None + java_obj = JavaCallable._new_java_obj(java_class, *args) + return cls(java_obj) + + def _call_java(self, name, *args): + m = getattr(self._java_obj, name) + sc = SparkContext._active_spark_context + java_args = [_py2java(sc, arg) for arg in args] + return _java2py(sc, m(*java_args)) @staticmethod def _new_java_obj(java_class, *args): """ - Construct a new Java object. + Returns a new Java object. """ sc = SparkContext._active_spark_context java_obj = _jvm() @@ -56,6 +59,18 @@ def _new_java_obj(java_class, *args): java_args = [_py2java(sc, arg) for arg in args] return java_obj(*java_args) + +@inherit_doc +class JavaWrapper(JavaCallable, Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + #: The param values in the Java object should be + #: synced with the Python wrapper in fit/transform/evaluate/copy. + + __metaclass__ = ABCMeta + def _make_java_param_pair(self, param, value): """ Makes a Java parm pair. @@ -166,7 +181,7 @@ def __get_class(clazz): @inherit_doc -class JavaEstimator(Estimator, JavaWrapper): +class JavaEstimator(JavaWrapper, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. @@ -199,7 +214,7 @@ def _fit(self, dataset): @inherit_doc -class JavaTransformer(Transformer, JavaWrapper): +class JavaTransformer(JavaWrapper, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. 
Subclasses should ensure they have the transformer Java object @@ -213,30 +228,8 @@ def _transform(self, dataset): return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx) -class JavaCallable(object): - """ - Wrapper for a plain object in JVM to make Java calls, can be used - as a mixin to another class that defines a _java_obj wrapper - """ - def __init__(self, java_obj=None, sc=None): - super(JavaCallable, self).__init__() - self._sc = sc if sc is not None else SparkContext._active_spark_context - # if this class is a mixin and _java_obj is already defined then don't initialize - if java_obj is not None or not hasattr(self, "_java_obj"): - self._java_obj = java_obj - - def __del__(self): - if self._java_obj is not None: - self._sc._gateway.detach(self._java_obj) - - def _call_java(self, name, *args): - m = getattr(self._java_obj, name) - java_args = [_py2java(self._sc, arg) for arg in args] - return _java2py(self._sc, m(*java_args)) - - @inherit_doc -class JavaModel(Model, JavaCallable, JavaTransformer): +class JavaModel(JavaTransformer, Model): """ Base class for :py:class:`Model`s that wrap Java/Scala implementations. Subclasses should inherit this class before @@ -259,9 +252,8 @@ def __init__(self, java_model=None): these wrappers depend on pyspark.ml.util (both directly and via other ML classes). """ - super(JavaModel, self).__init__() + super(JavaModel, self).__init__(java_model) if java_model is not None: - self._java_obj = java_model self.uid = java_model.uid() def copy(self, extra=None): From c95ba2782c002c531b3d46cb4e770c5ace52215d Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Apr 2016 11:33:51 -0700 Subject: [PATCH 2/5] Renamed classes to better reflect purpose, JavaWrapper to JavaWrapperParams and JavaCallable to JavaWrapper --- python/pyspark/ml/classification.py | 4 ++-- python/pyspark/ml/evaluation.py | 4 ++-- python/pyspark/ml/pipeline.py | 13 +++++++------ python/pyspark/ml/regression.py | 4 ++-- python/pyspark/ml/tests.py | 4 ++-- python/pyspark/ml/tuning.py | 29 ++++++++++++++++------------- python/pyspark/ml/wrapper.py | 14 +++++++------- 7 files changed, 38 insertions(+), 34 deletions(-) diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index d98919b3c6398..546b17762b34b 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -19,7 +19,7 @@ from pyspark import since from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.ml.param import TypeConverters from pyspark.ml.param.shared import * from pyspark.ml.regression import ( @@ -272,7 +272,7 @@ def evaluate(self, dataset): return BinaryLogisticRegressionSummary(java_blr_summary) -class LogisticRegressionSummary(JavaCallable): +class LogisticRegressionSummary(JavaWrapper): """ Abstraction for Logistic Regression Results for a given model. 
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 8bda32146e9c8..37a20b3062be2 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,7 +18,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaWrapperParams from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only @@ -81,7 +81,7 @@ def isLargerBetter(self): @inherit_doc -class JavaEvaluator(JavaWrapper, Evaluator): +class JavaEvaluator(JavaWrapperParams, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 2b5504bc2966a..0a0c277aae46d 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaWrapperParams from pyspark.mllib.common import inherit_doc @@ -177,7 +177,7 @@ def _from_java(cls, java_stage): # Create a new instance of this stage. py_stage = cls() # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stages = [JavaWrapperParams._from_java(s) for s in java_stage.getStages()] py_stage.setStages(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage @@ -195,7 +195,7 @@ def _to_java(self): for idx, stage in enumerate(self.getStages()): java_stages[idx] = stage._to_java() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) _java_obj.setStages(java_stages) return _java_obj @@ -275,7 +275,7 @@ def _from_java(cls, java_stage): Used for ML persistence. """ # Load information from java_stage to the instance. - py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + py_stages = [JavaWrapperParams._from_java(s) for s in java_stage.stages()] # Create a new instance of this stage. py_stage = cls(py_stages) py_stage._resetUid(java_stage.uid()) @@ -294,7 +294,8 @@ def _to_java(self): for idx, stage in enumerate(self.stages): java_stages[idx] = stage._to_java() - _java_obj =\ - JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.PipelineModel", + self.uid, + java_stages) return _java_obj diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index f6c5d130dd856..fccb9f0666897 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -20,7 +20,7 @@ from pyspark import since from pyspark.ml.param.shared import * from pyspark.ml.util import * -from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable +from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper from pyspark.mllib.common import inherit_doc from pyspark.sql import DataFrame @@ -187,7 +187,7 @@ def evaluate(self, dataset): return LinearRegressionSummary(java_lr_summary) -class LinearRegressionSummary(JavaCallable): +class LinearRegressionSummary(JavaWrapper): """ .. 
note:: Experimental diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 2dcd5eeb52c21..dc7a5c12bdee7 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -52,7 +52,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaWrapperParams from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -644,7 +644,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaWrapper): + if isinstance(m1, JavaWrapperParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index da00f317b348c..b47abdd3d7a8c 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapper +from pyspark.ml.wrapper import JavaWrapperParams from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java @@ -148,8 +148,8 @@ def _from_java_impl(cls, java_stage): """ # Load information from java_stage to the instance. - estimator = JavaWrapper._from_java(java_stage.getEstimator()) - evaluator = JavaWrapper._from_java(java_stage.getEvaluator()) + estimator = JavaWrapperParams._from_java(java_stage.getEstimator()) + evaluator = JavaWrapperParams._from_java(java_stage.getEvaluator()) epms = [estimator._transfer_param_map_from_java(epm) for epm in java_stage.getEstimatorParamMaps()] return estimator, epms, evaluator @@ -329,7 +329,8 @@ def _to_java(self): estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) + _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -393,7 +394,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaWrapperParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
py_stage = cls(bestModel=bestModel)\ @@ -410,10 +411,11 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", - self.uid, - self.bestModel._to_java(), - _py2java(sc, [])) + _java_obj = JavaWrapperParams._new_java_obj( + "org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -574,8 +576,9 @@ def _to_java(self): estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() - _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", - self.uid) + _java_obj = JavaWrapperParams._new_java_obj( + "org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -637,7 +640,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapper._from_java(java_stage.bestModel()) + bestModel = JavaWrapperParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = \ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. @@ -655,7 +658,7 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapper._new_java_obj( + _java_obj = JavaWrapperParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0314d6c2cf864..f1603698b3850 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -25,12 +25,12 @@ from pyspark.mllib.common import inherit_doc, _java2py, _py2java -class JavaCallable(object): +class JavaWrapper(object): """ Wrapper class for a Java companion object """ def __init__(self, java_obj=None): - super(JavaCallable, self).__init__() + super(JavaWrapper, self).__init__() self._java_obj = java_obj @classmethod @@ -38,7 +38,7 @@ def createFromClassName(cls, java_class, *args): """ Construct this object from given Java classname and arguments """ - java_obj = JavaCallable._new_java_obj(java_class, *args) + java_obj = JavaWrapper._new_java_obj(java_class, *args) return cls(java_obj) def _call_java(self, name, *args): @@ -61,7 +61,7 @@ def _new_java_obj(java_class, *args): @inherit_doc -class JavaWrapper(JavaCallable, Params): +class JavaWrapperParams(JavaWrapper, Params): """ Utility class to help create wrapper classes from Java/Scala implementations of pipeline components. @@ -166,7 +166,7 @@ def __get_class(clazz): stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. py_type = __get_class(stage_name) - if issubclass(py_type, JavaWrapper): + if issubclass(py_type, JavaWrapperParams): # Load information from java_stage to the instance. py_stage = py_type() py_stage._java_obj = java_stage @@ -181,7 +181,7 @@ def __get_class(clazz): @inherit_doc -class JavaEstimator(JavaWrapper, Estimator): +class JavaEstimator(JavaWrapperParams, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. 
@@ -214,7 +214,7 @@ def _fit(self, dataset): @inherit_doc -class JavaTransformer(JavaWrapper, Transformer): +class JavaTransformer(JavaWrapperParams, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. Subclasses should ensure they have the transformer Java object From c1d41c7de25ba798abd4da5b342df15eede442b5 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 11 Apr 2016 11:37:23 -0700 Subject: [PATCH 3/5] Made alternate constructor for JavaWrapper private so doesn't appear in docs --- python/pyspark/ml/wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index f1603698b3850..8c2e2d7692e81 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -34,7 +34,7 @@ def __init__(self, java_obj=None): self._java_obj = java_obj @classmethod - def createFromClassName(cls, java_class, *args): + def _create_from_java_class(cls, java_class, *args): """ Construct this object from given Java classname and arguments """ From 8121a3dd08d478e7de3fab8d80974beb50592abf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Apr 2016 16:02:18 -0700 Subject: [PATCH 4/5] renamed JavaWrapperParams to just JavaParams and fixed 2 comments --- python/pyspark/ml/evaluation.py | 4 ++-- python/pyspark/ml/pipeline.py | 13 ++++++------- python/pyspark/ml/tests.py | 4 ++-- python/pyspark/ml/tuning.py | 29 +++++++++++++---------------- python/pyspark/ml/util.py | 4 ++-- python/pyspark/ml/wrapper.py | 8 ++++---- 6 files changed, 29 insertions(+), 33 deletions(-) diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 37a20b3062be2..4b0bade102802 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -18,7 +18,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since -from pyspark.ml.wrapper import JavaWrapperParams +from pyspark.ml.wrapper import JavaParams from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only @@ -81,7 +81,7 @@ def isLargerBetter(self): @inherit_doc -class JavaEvaluator(JavaWrapperParams, Evaluator): +class JavaEvaluator(JavaParams, Evaluator): """ Base class for :py:class:`Evaluator`s that wrap Java/Scala implementations. diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 0a0c277aae46d..d95f7dea6094e 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -25,7 +25,7 @@ from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapperParams +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.common import inherit_doc @@ -177,7 +177,7 @@ def _from_java(cls, java_stage): # Create a new instance of this stage. py_stage = cls() # Load information from java_stage to the instance. 
- py_stages = [JavaWrapperParams._from_java(s) for s in java_stage.getStages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()] py_stage.setStages(py_stages) py_stage._resetUid(java_stage.uid()) return py_stage @@ -195,7 +195,7 @@ def _to_java(self): for idx, stage in enumerate(self.getStages()): java_stages[idx] = stage._to_java() - _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) _java_obj.setStages(java_stages) return _java_obj @@ -275,7 +275,7 @@ def _from_java(cls, java_stage): Used for ML persistence. """ # Load information from java_stage to the instance. - py_stages = [JavaWrapperParams._from_java(s) for s in java_stage.stages()] + py_stages = [JavaParams._from_java(s) for s in java_stage.stages()] # Create a new instance of this stage. py_stage = cls(py_stages) py_stage._resetUid(java_stage.uid()) @@ -294,8 +294,7 @@ def _to_java(self): for idx, stage in enumerate(self.stages): java_stages[idx] = stage._to_java() - _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.PipelineModel", - self.uid, - java_stages) + _java_obj = \ + JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index dc7a5c12bdee7..bcbeacbe80491 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -52,7 +52,7 @@ from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only from pyspark.ml.util import MLWritable, MLWriter -from pyspark.ml.wrapper import JavaWrapperParams +from pyspark.ml.wrapper import JavaParams from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -644,7 +644,7 @@ def _compare_pipelines(self, m1, m2): """ self.assertEqual(m1.uid, m2.uid) self.assertEqual(type(m1), type(m2)) - if isinstance(m1, JavaWrapperParams): + if isinstance(m1, JavaParams): self.assertEqual(len(m1.params), len(m2.params)) for p in m1.params: self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index b47abdd3d7a8c..f0420f490911d 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -24,7 +24,7 @@ from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable -from pyspark.ml.wrapper import JavaWrapperParams +from pyspark.ml.wrapper import JavaParams from pyspark.sql.functions import rand from pyspark.mllib.common import inherit_doc, _py2java @@ -148,8 +148,8 @@ def _from_java_impl(cls, java_stage): """ # Load information from java_stage to the instance. 
- estimator = JavaWrapperParams._from_java(java_stage.getEstimator()) - evaluator = JavaWrapperParams._from_java(java_stage.getEvaluator()) + estimator = JavaParams._from_java(java_stage.getEstimator()) + evaluator = JavaParams._from_java(java_stage.getEvaluator()) epms = [estimator._transfer_param_map_from_java(epm) for epm in java_stage.getEstimatorParamMaps()] return estimator, epms, evaluator @@ -329,8 +329,7 @@ def _to_java(self): estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl() - _java_obj = JavaWrapperParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", - self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -394,7 +393,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapperParams._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. py_stage = cls(bestModel=bestModel)\ @@ -411,11 +410,10 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapperParams._new_java_obj( - "org.apache.spark.ml.tuning.CrossValidatorModel", - self.uid, - self.bestModel._to_java(), - _py2java(sc, [])) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel", + self.uid, + self.bestModel._to_java(), + _py2java(sc, [])) estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl() _java_obj.set("evaluator", evaluator) @@ -576,9 +574,8 @@ def _to_java(self): estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl() - _java_obj = JavaWrapperParams._new_java_obj( - "org.apache.spark.ml.tuning.TrainValidationSplit", - self.uid) + _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit", + self.uid) _java_obj.setEstimatorParamMaps(epms) _java_obj.setEvaluator(evaluator) _java_obj.setEstimator(estimator) @@ -640,7 +637,7 @@ def _from_java(cls, java_stage): """ # Load information from java_stage to the instance. - bestModel = JavaWrapperParams._from_java(java_stage.bestModel()) + bestModel = JavaParams._from_java(java_stage.bestModel()) estimator, epms, evaluator = \ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage) # Create a new instance of this stage. 
@@ -658,7 +655,7 @@ def _to_java(self): sc = SparkContext._active_spark_context - _java_obj = JavaWrapperParams._new_java_obj( + _java_obj = JavaParams._new_java_obj( "org.apache.spark.ml.tuning.TrainValidationSplitModel", self.uid, self.bestModel._to_java(), diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index d4411fdfb9dde..9dfcef0e40d67 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -99,7 +99,7 @@ def context(self, sqlContext): @inherit_doc class JavaMLWriter(MLWriter): """ - (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types """ def __init__(self, instance): @@ -178,7 +178,7 @@ def context(self, sqlContext): @inherit_doc class JavaMLReader(MLReader): """ - (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types """ def __init__(self, clazz): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 8c2e2d7692e81..72127e8ec1917 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -61,7 +61,7 @@ def _new_java_obj(java_class, *args): @inherit_doc -class JavaWrapperParams(JavaWrapper, Params): +class JavaParams(JavaWrapper, Params): """ Utility class to help create wrapper classes from Java/Scala implementations of pipeline components. @@ -166,7 +166,7 @@ def __get_class(clazz): stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. py_type = __get_class(stage_name) - if issubclass(py_type, JavaWrapperParams): + if issubclass(py_type, JavaParams): # Load information from java_stage to the instance. py_stage = py_type() py_stage._java_obj = java_stage @@ -181,7 +181,7 @@ def __get_class(clazz): @inherit_doc -class JavaEstimator(JavaWrapperParams, Estimator): +class JavaEstimator(JavaParams, Estimator): """ Base class for :py:class:`Estimator`s that wrap Java/Scala implementations. @@ -214,7 +214,7 @@ def _fit(self, dataset): @inherit_doc -class JavaTransformer(JavaWrapperParams, Transformer): +class JavaTransformer(JavaParams, Transformer): """ Base class for :py:class:`Transformer`s that wrap Java/Scala implementations. Subclasses should ensure they have the transformer Java object From 0fa51bb4c1ae98dd08c4ae46af4426aec2715de8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 12 Apr 2016 16:32:20 -0700 Subject: [PATCH 5/5] fixed typo --- python/pyspark/ml/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index d95f7dea6094e..9d654e8b0f8d0 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -294,7 +294,7 @@ def _to_java(self): for idx, stage in enumerate(self.stages): java_stages[idx] = stage._to_java() - _java_obj = \ + _java_obj =\ JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) return _java_obj
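
For reference, below is a minimal, dependency-free sketch of the class layering these five patches converge on: a plain JavaWrapper base that owns _java_obj and provides _new_java_obj, _call_java, and _create_from_java_class, with JavaParams(JavaWrapper, Params) layered on top for pipeline stages (JavaEstimator, JavaTransformer, JavaModel, JavaEvaluator). The _FakeJavaObject class and the simplified Params stand-in are hypothetical placeholders for the Py4J gateway object and pyspark.ml.param.Params, used only so the sketch runs without a SparkContext; they are not PySpark APIs.

    # Illustrative sketch only: mirrors the layering introduced by this patch
    # series (JavaWrapper owning _java_obj, JavaParams adding Params), using
    # plain-Python stand-ins instead of the Py4J gateway.

    class _FakeJavaObject(object):
        """Stand-in for a Py4J JavaObject; records the class name and args."""
        def __init__(self, class_name, *args):
            self.class_name = class_name
            self.args = args

        def uid(self):
            return "%s_0000" % self.class_name.split(".")[-1]


    class JavaWrapper(object):
        """Wrapper class for a Java companion object (cf. wrapper.py)."""
        def __init__(self, java_obj=None):
            super(JavaWrapper, self).__init__()
            self._java_obj = java_obj

        @classmethod
        def _create_from_java_class(cls, java_class, *args):
            # The real method constructs the object through the Py4J gateway.
            return cls(JavaWrapper._new_java_obj(java_class, *args))

        @staticmethod
        def _new_java_obj(java_class, *args):
            return _FakeJavaObject(java_class, *args)

        def _call_java(self, name, *args):
            # The real method converts arguments with _py2java and results
            # with _java2py against the active SparkContext.
            return getattr(self._java_obj, name)(*args)


    class Params(object):
        """Simplified stand-in for pyspark.ml.param.Params."""
        def __init__(self):
            super(Params, self).__init__()
            self.uid = None


    class JavaParams(JavaWrapper, Params):
        """Java wrapper that also carries Params, as pipeline stages do."""
        pass


    if __name__ == "__main__":
        # A plain wrapper without Params, as LogisticRegressionSummary and
        # LinearRegressionSummary now use:
        summary = JavaWrapper._create_from_java_class("org.example.Summary")
        print(summary._call_java("uid"))

        # A Params-bearing stage; cooperative __init__ initializes both bases.
        stage = JavaParams(JavaWrapper._new_java_obj("org.example.Stage"))
        stage.uid = stage._call_java("uid")  # mirrors JavaModel.__init__ in patch 1
        print(stage._java_obj.class_name, stage.uid)

Running the sketch just prints the fake uid values; in the real classes, _call_java and _new_java_obj additionally route arguments and results through _py2java/_java2py and SparkContext._active_spark_context, as shown in the wrapper.py diffs above.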