python/pyspark/ml/classification.py (4 changes: 2 additions & 2 deletions)
@@ -19,7 +19,7 @@

from pyspark import since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
@@ -272,7 +272,7 @@ def evaluate(self, dataset):
return BinaryLogisticRegressionSummary(java_blr_summary)


class LogisticRegressionSummary(JavaCallable):
class LogisticRegressionSummary(JavaWrapper):
"""
Abstraction for Logistic Regression Results for a given model.

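Note: with this rename, LogisticRegressionSummary wraps the Java summary object through the plain JavaWrapper base (see wrapper.py below) rather than carrying Params machinery, and its accessors delegate through _call_java. A minimal usage sketch, assuming an active SparkSession, a fitted LogisticRegressionModel named model, and a test DataFrame named test_df (all hypothetical names):

# Hedged sketch: model and test_df are assumed to exist.
summary = model.evaluate(test_df)   # returns the Python summary wrapping the Java object
print(summary.areaUnderROC)         # property access is forwarded via JavaWrapper._call_java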
python/pyspark/ml/evaluation.py (4 changes: 2 additions & 2 deletions)
@@ -18,7 +18,7 @@
from abc import abstractmethod, ABCMeta

from pyspark import since
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
@@ -81,7 +81,7 @@ def isLargerBetter(self):


@inherit_doc
class JavaEvaluator(Evaluator, JavaWrapper):
class JavaEvaluator(JavaParams, Evaluator):
"""
Base class for :py:class:`Evaluator`s that wrap Java/Scala
implementations.
python/pyspark/ml/pipeline.py (10 changes: 5 additions & 5 deletions)
@@ -25,7 +25,7 @@
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.common import inherit_doc


@@ -177,7 +177,7 @@ def _from_java(cls, java_stage):
# Create a new instance of this stage.
py_stage = cls()
# Load information from java_stage to the instance.
py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
py_stages = [JavaParams._from_java(s) for s in java_stage.getStages()]
py_stage.setStages(py_stages)
py_stage._resetUid(java_stage.uid())
return py_stage
@@ -195,7 +195,7 @@ def _to_java(self):
for idx, stage in enumerate(self.getStages()):
java_stages[idx] = stage._to_java()

_java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
_java_obj.setStages(java_stages)

return _java_obj
@@ -275,7 +275,7 @@ def _from_java(cls, java_stage):
Used for ML persistence.
"""
# Load information from java_stage to the instance.
py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
py_stages = [JavaParams._from_java(s) for s in java_stage.stages()]
# Create a new instance of this stage.
py_stage = cls(py_stages)
py_stage._resetUid(java_stage.uid())
@@ -295,6 +295,6 @@ def _to_java(self):
java_stages[idx] = stage._to_java()

_java_obj =\
JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
JavaParams._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)

return _java_obj
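Note: the _to_java/_from_java pair above is the path ML persistence takes for pipelines. A hedged usage sketch, assuming a DataFrame df with text and label columns and a writable path (both hypothetical):

from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import HashingTF, Tokenizer

# save() builds the Java PipelineModel via JavaParams._new_java_obj (see _to_java above);
# load() rebuilds the Python stages via JavaParams._from_java.
pipeline = Pipeline(stages=[Tokenizer(inputCol="text", outputCol="words"),
                            HashingTF(inputCol="words", outputCol="features"),
                            LogisticRegression(maxIter=10)])
model = pipeline.fit(df)
model.save(path)
reloaded = PipelineModel.load(path)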
python/pyspark/ml/regression.py (4 changes: 2 additions & 2 deletions)
@@ -20,7 +20,7 @@
from pyspark import since
from pyspark.ml.param.shared import *
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaCallable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaWrapper
from pyspark.mllib.common import inherit_doc
from pyspark.sql import DataFrame

@@ -187,7 +187,7 @@ def evaluate(self, dataset):
return LinearRegressionSummary(java_lr_summary)


class LinearRegressionSummary(JavaCallable):
class LinearRegressionSummary(JavaWrapper):
"""
.. note:: Experimental

python/pyspark/ml/tests.py (4 changes: 2 additions & 2 deletions)
@@ -52,7 +52,7 @@
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaParams
from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -644,7 +644,7 @@ def _compare_pipelines(self, m1, m2):
"""
self.assertEqual(m1.uid, m2.uid)
self.assertEqual(type(m1), type(m2))
if isinstance(m1, JavaWrapper):
if isinstance(m1, JavaParams):
self.assertEqual(len(m1.params), len(m2.params))
for p in m1.params:
self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
python/pyspark/ml/tuning.py (26 changes: 13 additions & 13 deletions)
@@ -24,7 +24,7 @@
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
from pyspark.mllib.common import inherit_doc, _py2java

@@ -148,8 +148,8 @@ def _from_java_impl(cls, java_stage):
"""

# Load information from java_stage to the instance.
estimator = JavaWrapper._from_java(java_stage.getEstimator())
evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
estimator = JavaParams._from_java(java_stage.getEstimator())
evaluator = JavaParams._from_java(java_stage.getEvaluator())
epms = [estimator._transfer_param_map_from_java(epm)
for epm in java_stage.getEstimatorParamMaps()]
return estimator, epms, evaluator
@@ -329,7 +329,7 @@ def _to_java(self):

estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()

_java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -393,7 +393,7 @@ def _from_java(cls, java_stage):
"""

# Load information from java_stage to the instance.
bestModel = JavaWrapper._from_java(java_stage.bestModel())
bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
py_stage = cls(bestModel=bestModel)\
@@ -410,10 +410,10 @@ def _to_java(self):

sc = SparkContext._active_spark_context

_java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
self.uid,
self.bestModel._to_java(),
_py2java(sc, []))
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
self.uid,
self.bestModel._to_java(),
_py2java(sc, []))
estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()

_java_obj.set("evaluator", evaluator)
@@ -574,8 +574,8 @@ def _to_java(self):

estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()

_java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
self.uid)
_java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
self.uid)
_java_obj.setEstimatorParamMaps(epms)
_java_obj.setEvaluator(evaluator)
_java_obj.setEstimator(estimator)
@@ -637,7 +637,7 @@ def _from_java(cls, java_stage):
"""

# Load information from java_stage to the instance.
bestModel = JavaWrapper._from_java(java_stage.bestModel())
bestModel = JavaParams._from_java(java_stage.bestModel())
estimator, epms, evaluator = \
super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
# Create a new instance of this stage.
@@ -655,7 +655,7 @@ def _to_java(self):

sc = SparkContext._active_spark_context

_java_obj = JavaWrapper._new_java_obj(
_java_obj = JavaParams._new_java_obj(
"org.apache.spark.ml.tuning.TrainValidationSplitModel",
self.uid,
self.bestModel._to_java(),
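Note: the tuning classes convert through the same JavaParams helpers. A hedged sketch of the private round trip used by persistence, assuming an active SparkContext and a CrossValidator named cv whose estimator and evaluator are Java-backed (hypothetical setup):

# Hedged sketch of the conversion path shown above; cv is assumed to exist.
java_cv = cv._to_java()                     # org.apache.spark.ml.tuning.CrossValidator,
                                            # created via JavaParams._new_java_obj
py_cv = CrossValidator._from_java(java_cv)  # estimator and evaluator rebuilt via JavaParams._from_java
assert py_cv.uid == cv.uid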
python/pyspark/ml/util.py (4 changes: 2 additions & 2 deletions)
@@ -99,7 +99,7 @@ def context(self, sqlContext):
@inherit_doc
class JavaMLWriter(MLWriter):
"""
(Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
(Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaParams` types
"""

def __init__(self, instance):
@@ -178,7 +178,7 @@ def context(self, sqlContext):
@inherit_doc
class JavaMLReader(MLReader):
"""
(Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
(Private) Specialization of :py:class:`MLReader` for :py:class:`JavaParams` types
"""

def __init__(self, clazz):
python/pyspark/ml/wrapper.py (76 changes: 34 additions & 42 deletions)
@@ -25,29 +25,32 @@
from pyspark.mllib.common import inherit_doc, _java2py, _py2java


@inherit_doc
class JavaWrapper(Params):
class JavaWrapper(object):
"""
Utility class to help create wrapper classes from Java/Scala
implementations of pipeline components.
Wrapper class for a Java companion object
"""
def __init__(self, java_obj=None):
super(JavaWrapper, self).__init__()
self._java_obj = java_obj

__metaclass__ = ABCMeta

def __init__(self):
@classmethod
def _create_from_java_class(cls, java_class, *args):
"""
Initialize the wrapped java object to None
Construct this object from given Java classname and arguments
"""
super(JavaWrapper, self).__init__()
#: The wrapped Java companion object. Subclasses should initialize
#: it properly. The param values in the Java object should be
#: synced with the Python wrapper in fit/transform/evaluate/copy.
self._java_obj = None
java_obj = JavaWrapper._new_java_obj(java_class, *args)
return cls(java_obj)

def _call_java(self, name, *args):
m = getattr(self._java_obj, name)
sc = SparkContext._active_spark_context
java_args = [_py2java(sc, arg) for arg in args]
return _java2py(sc, m(*java_args))

@staticmethod
def _new_java_obj(java_class, *args):
"""
Construct a new Java object.
Returns a new Java object.
"""
sc = SparkContext._active_spark_context
java_obj = _jvm()
@@ -56,6 +59,18 @@ def _new_java_obj(java_class, *args):
java_args = [_py2java(sc, arg) for arg in args]
return java_obj(*java_args)


@inherit_doc
class JavaParams(JavaWrapper, Params):
"""
Utility class to help create wrapper classes from Java/Scala
implementations of pipeline components.
"""
#: The param values in the Java object should be
#: synced with the Python wrapper in fit/transform/evaluate/copy.

__metaclass__ = ABCMeta

def _make_java_param_pair(self, param, value):
"""
Makes a Java param pair.
@@ -151,7 +166,7 @@ def __get_class(clazz):
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
py_type = __get_class(stage_name)
if issubclass(py_type, JavaWrapper):
if issubclass(py_type, JavaParams):
# Load information from java_stage to the instance.
py_stage = py_type()
py_stage._java_obj = java_stage
@@ -166,7 +181,7 @@ def __get_class(clazz):


@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):
class JavaEstimator(JavaParams, Estimator):
"""
Base class for :py:class:`Estimator`s that wrap Java/Scala
implementations.
@@ -199,7 +214,7 @@ def _fit(self, dataset):


@inherit_doc
class JavaTransformer(Transformer, JavaWrapper):
class JavaTransformer(JavaParams, Transformer):
"""
Base class for :py:class:`Transformer`s that wrap Java/Scala
implementations. Subclasses should ensure they have the transformer Java object
@@ -213,30 +228,8 @@ def _transform(self, dataset):
return DataFrame(self._java_obj.transform(dataset._jdf), dataset.sql_ctx)


class JavaCallable(object):
"""
Wrapper for a plain object in JVM to make Java calls, can be used
as a mixin to another class that defines a _java_obj wrapper
"""
def __init__(self, java_obj=None, sc=None):
super(JavaCallable, self).__init__()
self._sc = sc if sc is not None else SparkContext._active_spark_context
# if this class is a mixin and _java_obj is already defined then don't initialize
if java_obj is not None or not hasattr(self, "_java_obj"):
self._java_obj = java_obj

def __del__(self):
if self._java_obj is not None:
self._sc._gateway.detach(self._java_obj)

def _call_java(self, name, *args):
m = getattr(self._java_obj, name)
java_args = [_py2java(self._sc, arg) for arg in args]
return _java2py(self._sc, m(*java_args))


@inherit_doc
class JavaModel(Model, JavaCallable, JavaTransformer):
class JavaModel(JavaTransformer, Model):
"""
Base class for :py:class:`Model`s that wrap Java/Scala
implementations. Subclasses should inherit this class before
@@ -259,9 +252,8 @@ def __init__(self, java_model=None):
these wrappers depend on pyspark.ml.util (both directly and via
other ML classes).
"""
super(JavaModel, self).__init__()
super(JavaModel, self).__init__(java_model)
if java_model is not None:
self._java_obj = java_model
self.uid = java_model.uid()

def copy(self, extra=None):
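Note on the refactor in this file: JavaWrapper is now a thin holder for a py4j object, keeping _new_java_obj, _call_java, and the new _create_from_java_class, while JavaParams(JavaWrapper, Params) retains the Param-syncing behaviour that JavaEstimator, JavaTransformer, and JavaModel build on. A hedged sketch of the resulting internals, assuming an active SparkContext (the JVM class used is only an example, and these helpers are private):

from pyspark.ml.wrapper import JavaWrapper

# Wrap an arbitrary JVM object with the new base class.
jw = JavaWrapper._create_from_java_class("java.lang.StringBuilder", "abc")
print(jw._call_java("length"))   # -> 3, forwarded to the wrapped Java object

# Pipeline stages extend JavaParams (JavaWrapper plus Params), so Param values
# stay synced with the JVM object in fit/transform/evaluate/copy as before.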