 
 from pyspark import since, keyword_only
 from pyspark.ml import Estimator, Model
-from pyspark.ml.common import _py2java
+from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.evaluation import JavaEvaluator
 from pyspark.ml.param import Params, Param, TypeConverters
 from pyspark.ml.param.shared import HasParallelism, HasSeed
 from pyspark.ml.util import *
-from pyspark.ml.wrapper import JavaParams
+from pyspark.ml.wrapper import JavaEstimator, JavaParams
 from pyspark.sql.functions import rand
 
 __all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
@@ -247,9 +248,14 @@ def getNumFolds(self):
 
     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return CrossValidatorModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         nFolds = self.getOrDefault(self.numFolds)
         seed = self.getOrDefault(self.seed)
         h = 1.0 / nFolds
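
The guard added above lets CrossValidator._fit (and, in the matching change to TrainValidationSplit._fit further down) delegate the whole parameter search to the JVM whenever both the estimator and the evaluator are JVM-backed. A minimal usage sketch that would take this fast path, assuming an existing DataFrame `training` with "features" and "label" columns:

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    lr = LogisticRegression(maxIter=10)
    grid = ParamGridBuilder().addGrid(lr.regParam, [0.01, 0.1]).build()
    evaluator = BinaryClassificationEvaluator()

    # LogisticRegression is a JavaEstimator and BinaryClassificationEvaluator a
    # JavaEvaluator, so with this patch fit() hands the cross-validation to the
    # Scala CrossValidator instead of looping over folds in Python.
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=evaluator, numFolds=3)
    cvModel = cv.fit(training)
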
@@ -266,15 +272,15 @@ def _fit(self, dataset):
             validation = df.filter(condition).cache()
             train = df.filter(~condition).cache()
 
-            def singleTrain(paramMap):
-                model = est.fit(train, paramMap)
-                # TODO: duplicate evaluator to take extra params from input
-                metric = eva.evaluate(model.transform(validation, paramMap))
-                return metric
+            currentFoldMetrics = [0.0] * numModels
+            def modelCallback(model, paramMapIndex):
+                metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+                currentFoldMetrics[paramMapIndex] = metric
+            est.parallelFit(train, epm, pool, modelCallback)
 
-            currentFoldMetrics = pool.map(singleTrain, epm)
             for j in range(numModels):
                 metrics[j] += (currentFoldMetrics[j] / nFolds)
+
             validation.unpersist()
             train.unpersist()
 
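
The old singleTrain/pool.map pattern is replaced by an est.parallelFit(...) hook, which the patch also calls from TrainValidationSplit._fit below. parallelFit is not an existing pyspark API; a sketch of the contract the call site assumes, written to mirror the behaviour of the pool.map code it replaces:

    # Hypothetical default implementation on Estimator (an assumption, not the
    # actual Spark API): fit one model per param map on the supplied thread
    # pool and hand each fitted model back through the callback together with
    # the index of the param map it was trained with.
    def parallelFit(self, dataset, paramMaps, pool, modelCallback):
        def fitOne(index):
            model = self.fit(dataset, paramMaps[index])
            modelCallback(model, index)
        # map() blocks until every fit has finished and its callback has run,
        # so the caller can read the metrics collected by the callbacks.
        pool.map(fitOne, range(len(paramMaps)))
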
@@ -409,10 +415,12 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        avgMetrics = _java2py(sc, java_stage.avgMetrics())
         estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
 
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics).setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -426,11 +434,10 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persist average metrics as well
         _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
                                              self.uid,
                                              self.bestModel._to_java(),
-                                             _py2java(sc, []))
+                                             _py2java(sc, self.avgMetrics))
         estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
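
With avgMetrics now threaded through _to_java/_from_java, the per-param-map metrics survive ML persistence instead of coming back as an empty list. A round-trip sketch (the path is hypothetical; cvModel is the fitted model from the earlier sketch):

    from pyspark.ml.tuning import CrossValidatorModel

    print(cvModel.avgMetrics)              # one averaged metric per param map
    cvModel.save("/tmp/cv_model")          # hypothetical path
    loaded = CrossValidatorModel.load("/tmp/cv_model")
    print(loaded.avgMetrics)               # previously [], now matches cvModel.avgMetrics
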
@@ -512,9 +519,14 @@ def getTrainRatio(self):
 
     def _fit(self, dataset):
         est = self.getOrDefault(self.estimator)
+        eva = self.getOrDefault(self.evaluator)
+
+        if isinstance(est, JavaEstimator) and isinstance(eva, JavaEvaluator):
+            java_model = self._to_java().fit(dataset._jdf)
+            return TrainValidationSplitModel._from_java(java_model)
+
         epm = self.getOrDefault(self.estimatorParamMaps)
         numModels = len(epm)
-        eva = self.getOrDefault(self.evaluator)
         tRatio = self.getOrDefault(self.trainRatio)
         seed = self.getOrDefault(self.seed)
         randCol = self.uid + "_rand"
@@ -523,13 +535,14 @@ def _fit(self, dataset):
         validation = df.filter(condition).cache()
         train = df.filter(~condition).cache()
 
-        def singleTrain(paramMap):
-            model = est.fit(train, paramMap)
-            metric = eva.evaluate(model.transform(validation, paramMap))
-            return metric
+        def modelCallback(model, paramMapIndex):
+            metric = eva.evaluate(model.transform(validation, epm[paramMapIndex]))
+            metrics[paramMapIndex] = metric
 
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
-        metrics = pool.map(singleTrain, epm)
+        metrics = [0.0] * numModels
+        est.parallelFit(train, epm, pool, modelCallback)
+
         train.unpersist()
         validation.unpersist()
 
@@ -663,12 +676,15 @@ def _from_java(cls, java_stage):
         Used for ML persistence.
         """
 
+        sc = SparkContext._active_spark_context
         # Load information from java_stage to the instance.
         bestModel = JavaParams._from_java(java_stage.bestModel())
+        validationMetrics = _java2py(sc, java_stage.validationMetrics())
         estimator, epms, evaluator = super(TrainValidationSplitModel,
                                            cls)._from_java_impl(java_stage)
         # Create a new instance of this stage.
-        py_stage = cls(bestModel=bestModel).setEstimator(estimator)
+        py_stage = cls(bestModel=bestModel, validationMetrics=validationMetrics)\
+            .setEstimator(estimator)
         py_stage = py_stage.setEstimatorParamMaps(epms).setEvaluator(evaluator)
 
         py_stage._resetUid(java_stage.uid())
@@ -681,12 +697,11 @@ def _to_java(self):
         """
 
         sc = SparkContext._active_spark_context
-        # TODO: persst validation metrics as well
         _java_obj = JavaParams._new_java_obj(
             "org.apache.spark.ml.tuning.TrainValidationSplitModel",
             self.uid,
             self.bestModel._to_java(),
-            _py2java(sc, []))
+            _py2java(sc, self.validationMetrics))
         estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
 
         _java_obj.set("evaluator", evaluator)
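
The same round trip now holds for TrainValidationSplitModel: validationMetrics is written out through _py2java here and read back in _from_java above, so a model reloaded with TrainValidationSplitModel.load(path) keeps the metrics it was trained with instead of an empty list.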