mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala (40 changes: 32 additions & 8 deletions)
@@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.param.{Param, ParamMap, Params}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

@@ -204,7 +204,7 @@ object Pipeline extends MLReadable[Pipeline] {
override def save(path: String): Unit =
instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
override protected def saveImpl(path: String): Unit =
SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
SharedReadWrite.saveImpl(instance, instance.getStages, sparkSession, path)
}

private class PipelineReader extends MLReader[Pipeline] {
@@ -213,7 +213,8 @@ object Pipeline extends MLReadable[Pipeline] {
private val className = classOf[Pipeline].getName

override def load(path: String): Pipeline = instrumented(_.withLoadInstanceEvent(this, path) {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
val (uid: String, stages: Array[PipelineStage]) =
SharedReadWrite.load(className, sparkSession, path)
new Pipeline(uid).setStages(stages)
})
}
@@ -241,14 +242,26 @@ object Pipeline extends MLReadable[Pipeline] {
* - save metadata to path/metadata
* - save stages to stages/IDX_UID
*/
@deprecated("use saveImpl with SparkSession", "4.0.0")
def saveImpl(
instance: Params,
stages: Array[PipelineStage],
sc: SparkContext,
path: String): Unit =
saveImpl(
instance,
stages,
SparkSession.builder().sparkContext(sc).getOrCreate(),
path)

def saveImpl(
instance: Params,
stages: Array[PipelineStage],
spark: SparkSession,
path: String): Unit = instrumented { instr =>
val stageUids = stages.map(_.uid)
val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toImmutableArraySeq))))
DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, None, Some(jsonParams))

// Save stages
val stagesDir = new Path(path, "stages").toString
@@ -263,18 +276,28 @@
* Load metadata and stages for a [[Pipeline]] or [[PipelineModel]]
* @return (UID, list of stages)
*/
@deprecated("use load with SparkSession", "4.0.0")
def load(
expectedClassName: String,
sc: SparkContext,
path: String): (String, Array[PipelineStage]) =
load(
expectedClassName,
SparkSession.builder().sparkContext(sc).getOrCreate(),
path)

def load(
expectedClassName: String,
spark: SparkSession,
path: String): (String, Array[PipelineStage]) = instrumented { instr =>
val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)

implicit val format = DefaultFormats
val stagesDir = new Path(path, "stages").toString
val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, sc)
val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, spark)
instr.withLoadInstanceEvent(reader, stagePath)(reader.load(stagePath))
}
(metadata.uid, stages)
@@ -344,7 +367,7 @@ object PipelineModel extends MLReadable[PipelineModel] {
override def save(path: String): Unit =
instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
override protected def saveImpl(path: String): Unit = SharedReadWrite.saveImpl(instance,
instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
instance.stages.asInstanceOf[Array[PipelineStage]], sparkSession, path)
}

private class PipelineModelReader extends MLReader[PipelineModel] {
@@ -354,7 +377,8 @@ object PipelineModel extends MLReadable[PipelineModel] {

override def load(path: String): PipelineModel = instrumented(_.withLoadInstanceEvent(
this, path) {
val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
val (uid: String, stages: Array[PipelineStage]) =
SharedReadWrite.load(className, sparkSession, path)
val transformers = stages map {
case stage: Transformer => stage
case other => throw new RuntimeException(s"PipelineModel.read loaded a stage but found it" +
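The Pipeline.SharedReadWrite changes above keep each existing SparkContext entry point as a deprecated forwarder and add a SparkSession overload beside it, resolving the session through SparkSession.builder().sparkContext(sc).getOrCreate(). A minimal standalone sketch of that bridge pattern — the helper name and body below are illustrative only, not part of this change:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object MetadataIO {
  // Old entry point: kept only for source compatibility, forwards to the new overload.
  @deprecated("use saveJson with SparkSession", "4.0.0")
  def saveJson(json: String, path: String, sc: SparkContext): Unit =
    saveJson(json, path, SparkSession.builder().sparkContext(sc).getOrCreate())

  // New entry point: takes the session directly, so callers no longer need to
  // dig a SparkContext out of it.
  def saveJson(json: String, path: String, spark: SparkSession): Unit =
    spark.createDataFrame(Seq(Tuple1(json))).toDF("value").write.text(path)
}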
mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -28,7 +28,6 @@ import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.internal.{LogKeys, MDC}
import org.apache.spark.ml._
@@ -38,7 +37,7 @@ import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.{HasParallelism, HasWeightCol}
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
@@ -94,7 +93,7 @@ private[ml] object OneVsRestParams extends ClassifierTypeTrait {
def saveImpl(
path: String,
instance: OneVsRestParams,
sc: SparkContext,
spark: SparkSession,
extraMetadata: Option[JObject] = None): Unit = {

val params = instance.extractParamMap().toSeq
@@ -103,20 +102,20 @@
.map { case ParamPair(p, v) => p.name -> parse(p.jsonEncode(v)) }
.toList)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

val classifierPath = new Path(path, "classifier").toString
instance.getClassifier.asInstanceOf[MLWritable].save(classifierPath)
}

def loadImpl(
path: String,
sc: SparkContext,
spark: SparkSession,
expectedClassName: String): (DefaultParamsReader.Metadata, ClassifierType) = {

val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)
val classifierPath = new Path(path, "classifier").toString
val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, sc)
val estimator = DefaultParamsReader.loadParamsInstance[ClassifierType](classifierPath, spark)
(metadata, estimator)
}
}
@@ -282,7 +281,7 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {
override protected def saveImpl(path: String): Unit = {
val extraJson = ("labelMetadata" -> instance.labelMetadata.json) ~
("numClasses" -> instance.models.length)
OneVsRestParams.saveImpl(path, instance, sc, Some(extraJson))
OneVsRestParams.saveImpl(path, instance, sparkSession, Some(extraJson))
instance.models.map(_.asInstanceOf[MLWritable]).zipWithIndex.foreach { case (model, idx) =>
val modelPath = new Path(path, s"model_$idx").toString
model.save(modelPath)
@@ -297,12 +296,12 @@ object OneVsRestModel extends MLReadable[OneVsRestModel] {

override def load(path: String): OneVsRestModel = {
implicit val format = DefaultFormats
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
val labelMetadata = Metadata.fromJson((metadata.metadata \ "labelMetadata").extract[String])
val numClasses = (metadata.metadata \ "numClasses").extract[Int]
val models = Range(0, numClasses).toArray.map { idx =>
val modelPath = new Path(path, s"model_$idx").toString
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sc)
DefaultParamsReader.loadParamsInstance[ClassificationModel[_, _]](modelPath, sparkSession)
}
val ovrModel = new OneVsRestModel(metadata.uid, labelMetadata, models)
metadata.getAndSetParams(ovrModel)
@@ -490,7 +489,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
OneVsRestParams.validateParams(instance)

override protected def saveImpl(path: String): Unit = {
OneVsRestParams.saveImpl(path, instance, sc)
OneVsRestParams.saveImpl(path, instance, sparkSession)
}
}

@@ -500,7 +499,7 @@ object OneVsRest extends MLReadable[OneVsRest] {
private val className = classOf[OneVsRest].getName

override def load(path: String): OneVsRest = {
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sc, className)
val (metadata, classifier) = OneVsRestParams.loadImpl(path, sparkSession, className)
val ovr = new OneVsRest(metadata.uid)
metadata.getAndSetParams(ovr)
ovr.setClassifier(classifier)
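For orientation, the on-disk layout written by the OneVsRestModel writer above (and read back by its reader) is, in sketch form:

  <path>/metadata                       JSON from DefaultParamsWriter.saveMetadata, including labelMetadata and numClasses
  <path>/classifier                     the configured base classifier, saved through its own MLWriter
  <path>/model_0 ... model_{numClasses-1}   one fitted binary classification model per class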
mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -321,7 +321,7 @@ object ImputerModel extends MLReadable[ImputerModel] {
override def load(path: String): ImputerModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sparkSession, className)
val dataPath = new Path(path, "data").toString
val surrogateDF = sqlContext.read.parquet(dataPath)
val surrogateDF = sparkSession.read.parquet(dataPath)
val model = new ImputerModel(metadata.uid, surrogateDF)
metadata.getAndSetParams(model)
model
mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -474,7 +474,7 @@ private[ml] object EnsembleModelReadWrite {
DefaultParamsWriter.saveMetadata(instance, path, sparkSession, Some(extraMetadata))
val treesMetadataWeights = instance.trees.zipWithIndex.map { case (tree, treeID) =>
(treeID,
DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession.sparkContext),
DefaultParamsWriter.getMetadataToSave(tree.asInstanceOf[Params], sparkSession),
instance.treeWeights(treeID))
}
val treesMetadataPath = new Path(path, "treesMetadata").toString
mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -248,7 +248,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit =
ValidatorParams.saveImpl(path, instance, sc)
ValidatorParams.saveImpl(path, instance, sparkSession)
}

private class CrossValidatorReader extends MLReader[CrossValidator] {
@@ -260,7 +260,7 @@ object CrossValidator extends MLReadable[CrossValidator] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val cv = new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -403,7 +403,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
import org.json4s.JsonDSL._
val extraMetadata = ("avgMetrics" -> instance.avgMetrics.toImmutableArraySeq) ~
("persistSubModels" -> persistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (persistSubModels) {
@@ -431,10 +431,10 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)
@@ -448,7 +448,7 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
for (paramIndex <- estimatorParamMaps.indices) {
val modelPath = new Path(splitPath, paramIndex.toString).toString
_subModels(splitIndex)(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
}
}
Some(_subModels)
mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -216,7 +216,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
ValidatorParams.validateParams(instance)

override protected def saveImpl(path: String): Unit =
ValidatorParams.saveImpl(path, instance, sc)
ValidatorParams.saveImpl(path, instance, sparkSession)
}

private class TrainValidationSplitReader extends MLReader[TrainValidationSplit] {
@@ -228,7 +228,7 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val tvs = new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
@@ -368,7 +368,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
import org.json4s.JsonDSL._
val extraMetadata = ("validationMetrics" -> instance.validationMetrics.toImmutableArraySeq) ~
("persistSubModels" -> persistSubModels)
ValidatorParams.saveImpl(path, instance, sc, Some(extraMetadata))
ValidatorParams.saveImpl(path, instance, sparkSession, Some(extraMetadata))
val bestModelPath = new Path(path, "bestModel").toString
instance.bestModel.asInstanceOf[MLWritable].save(bestModelPath)
if (persistSubModels) {
@@ -393,9 +393,9 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
implicit val format = DefaultFormats

val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
ValidatorParams.loadImpl(path, sparkSession, className)
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sparkSession)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
val persistSubModels = (metadata.metadata \ "persistSubModels")
.extractOrElse[Boolean](false)
@@ -406,7 +406,7 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
for (paramIndex <- estimatorParamMaps.indices) {
val modelPath = new Path(subModelsPath, paramIndex.toString).toString
_subModels(paramIndex) =
DefaultParamsReader.loadParamsInstance(modelPath, sc)
DefaultParamsReader.loadParamsInstance(modelPath, sparkSession)
}
Some(_subModels)
} else None
mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -21,13 +21,13 @@ import org.apache.hadoop.fs.Path
import org.json4s._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.ArrayImplicits._

@@ -123,7 +123,7 @@ private[ml] object ValidatorParams {
def saveImpl(
path: String,
instance: ValidatorParams,
sc: SparkContext,
spark: SparkSession,
extraMetadata: Option[JObject] = None): Unit = {
import org.json4s.JsonDSL._

@@ -160,7 +160,7 @@
}.toList ++ List("estimatorParamMaps" -> parse(estimatorParamMapsJson))
)

DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
DefaultParamsWriter.saveMetadata(instance, path, spark, extraMetadata, Some(jsonParams))

val evaluatorPath = new Path(path, "evaluator").toString
instance.getEvaluator.asInstanceOf[MLWritable].save(evaluatorPath)
@@ -175,16 +175,16 @@
*/
def loadImpl[M <: Model[M]](
path: String,
sc: SparkContext,
spark: SparkSession,
expectedClassName: String): (Metadata, Estimator[M], Evaluator, Array[ParamMap]) = {

val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
val metadata = DefaultParamsReader.loadMetadata(path, spark, expectedClassName)

implicit val format = DefaultFormats
val evaluatorPath = new Path(path, "evaluator").toString
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, sc)
val evaluator = DefaultParamsReader.loadParamsInstance[Evaluator](evaluatorPath, spark)
val estimatorPath = new Path(path, "estimator").toString
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, sc)
val estimator = DefaultParamsReader.loadParamsInstance[Estimator[M]](estimatorPath, spark)

val uidToParams = Map(evaluator.uid -> evaluator) ++ MetaAlgorithmReadWrite.getUidMap(estimator)

@@ -202,7 +202,7 @@ private[ml] object ValidatorParams {
} else {
val relativePath = param.jsonDecode(pInfo("value")).toString
val value = DefaultParamsReader
.loadParamsInstance[MLWritable](new Path(path, relativePath).toString, sc)
.loadParamsInstance[MLWritable](new Path(path, relativePath).toString, spark)
param -> value
}
}
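Across all of the files above the recipe is the same: each MLWriter/MLReader implementation passes the sparkSession it already has to DefaultParamsWriter, DefaultParamsReader, and the shared helpers, instead of handing them a SparkContext. The public persistence API is untouched, so an ordinary save/load round trip exercises the new code paths unchanged. A small self-contained sketch using stock stages (stage choice, paths, and parameters below are arbitrary):

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.sql.SparkSession

object PipelinePersistenceCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("pipeline-io").getOrCreate()

    val training = spark.createDataFrame(Seq(
      (0L, "a b c d e spark", 1.0),
      (1L, "b d", 0.0),
      (2L, "spark f g h", 1.0),
      (3L, "hadoop mapreduce", 0.0)
    )).toDF("id", "text", "label")

    val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
    val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")
    val lr = new LogisticRegression().setMaxIter(5)
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

    // Saving and loading go through Pipeline.SharedReadWrite, which now
    // receives the active SparkSession rather than a SparkContext.
    pipeline.write.overwrite().save("/tmp/spark-pipeline")
    val model: PipelineModel = pipeline.fit(training)
    model.write.overwrite().save("/tmp/spark-pipeline-model")

    val reloaded = PipelineModel.load("/tmp/spark-pipeline-model")
    reloaded.transform(training).select("id", "prediction").show()

    spark.stop()
  }
}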