
Commit 980c8ec

init pr
1 parent 33d43bf commit 980c8ec

File tree

2 files changed (+52, -21 lines)


python/pyspark/ml/base.py

Lines changed: 16 additions & 0 deletions
@@ -71,6 +71,22 @@ def fit(self, dataset, params=None):
         raise ValueError("Params must be either a param map or a list/tuple of param maps, "
                          "but got %s." % type(params))
 
+    @since("2.3.0")
+    def parallelFit(self, dataset, paramMaps, threadPool, modelCallback):
+        """
+        Fits models in parallel to the input dataset, one per param map in the given list.
+
+        :param dataset: input dataset, an instance of :py:class:`pyspark.sql.DataFrame`
+        :param paramMaps: a list of param maps
+        :param threadPool: a thread pool used to run the fits in parallel
+        :param modelCallback: callback invoked with each fitted model and the index of
+                              its param map.
+        """
+        def singleTrain(paramMapIndex):
+            model = self.fit(dataset, paramMaps[paramMapIndex])
+            modelCallback(model, paramMapIndex)
+        threadPool.map(singleTrain, range(len(paramMaps)))
 
 @inherit_doc
 class Transformer(Params):
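
For context, a minimal sketch of how the new parallelFit hook can be driven directly; the estimator, param values, and the `train` DataFrame below are illustrative, not part of this change:

from multiprocessing.pool import ThreadPool

from pyspark.ml.classification import LogisticRegression

# Illustrative only: `train` is assumed to be an existing DataFrame
# with "features" and "label" columns.
lr = LogisticRegression()
paramMaps = [{lr.regParam: r} for r in (0.01, 0.1, 1.0)]

models = [None] * len(paramMaps)

def modelCallback(model, paramMapIndex):
    # parallelFit tags each fitted model with the index of its param map,
    # so results can be stored in input order regardless of completion order.
    models[paramMapIndex] = model

pool = ThreadPool(processes=len(paramMaps))
lr.parallelFit(train, paramMaps, pool, modelCallback)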

python/pyspark/ml/tuning.py

Lines changed: 36 additions & 21 deletions
@@ -20,11 +20,12 @@
 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.evaluation import JavaEvaluator
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasParallelism, HasSeed
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.wrapper import JavaEstimator, JavaParams
 from pyspark.sql.functions import rand
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -247,9 +248,14 @@ def getNumFolds(self):
 
     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return CrossValidatorModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         nFolds = self.getOrDefault(self.numFolds)
         seed = self.getOrDefault(self.seed)
         h = 1.0 / nFolds

@@ -266,15 +272,15 @@ def _fit(self, dataset):
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        def singleTrain(paramMap):
-            model = est.fit(train, paramMap)
-            # TODO: duplicate evaluator to take extra params from input
-            metric = eva.evaluate(model.transform(validation, paramMap))
-            return metric
+        currentFoldMetrics = [0.0] * numModels
+        def modelCallback(model, paramMapIndex):
+            metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+            currentFoldMetrics[paramMapIndex] = metric
+        est.parallelFit(train, epm, pool, modelCallback)
 
-        currentFoldMetrics = pool.map(singleTrain, epm)
         for j in range(numModels):
             metrics[j] += (currentFoldMetrics[j] / nFolds)
+
         validation.unpersist()
         train.unpersist()
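
Because parallelFit reports results through a callback rather than returning them the way pool.map did, the fold metrics are preallocated and each callback writes into the slot for its param-map index, so completion order does not matter. A standalone sketch of the pattern (helper name and tasks illustrative):

from multiprocessing.pool import ThreadPool

def parallel_collect(tasks, n_threads=4):
    # Illustrative helper: run zero-argument callables on a thread pool and
    # store each result by index, mirroring the callback-plus-slot pattern.
    results = [None] * len(tasks)

    def run(i):
        results[i] = tasks[i]()

    ThreadPool(processes=n_threads).map(run, range(len(tasks)))
    return results

print(parallel_collect([lambda v=v: v * v for v in range(5)]))  # [0, 1, 4, 9, 16]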

@@ -409,10 +415,12 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -426,11 +434,10 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persist average metrics as well
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                              self.uid,
                                              self.bestModel._to_java(),
-                                             _py2java(sc, []))
+                                             _py2java(sc, self.avgMetrics))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
@@ -512,9 +519,14 @@ def getTrainRatio(self):
 
     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return TrainValidationSplitModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         tRatio = self.getOrDefault(self.trainRatio)
         seed = self.getOrDefault(self.seed)
         randCol = self.uid + "_rand"
@@ -523,13 +535,14 @@ def _fit(self, dataset):
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        def singleTrain(paramMap):
-            model = est.fit(train, paramMap)
-            metric = eva.evaluate(model.transform(validation, paramMap))
-            return metric
+        def modelCallback(model, paramMapIndex):
+            metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+            metrics[paramMapIndex] = metric
 
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
-        metrics = pool.map(singleTrain, epm)
+        metrics = [0.0] * numModels
+        est.parallelFit(train, epm, pool, modelCallback)
+
         train.unpersist()
         validation.unpersist()
@@ -663,12 +676,15 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         # Load information from java_stage to the instance.
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics)\
+            .setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -681,12 +697,11 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persst validation metrics as well
         _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
-            _py2java(sc, []))
+            _py2java(sc, self.validationMetrics))
         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
