38 changes: 3 additions & 35 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -34,7 +34,6 @@ import org.apache.spark.ml.util.MLWriter
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils

 /**
  * :: DeveloperApi ::
@@ -228,20 +227,9 @@ object Pipeline extends MLReadable[Pipeline] {
         stages: Array[PipelineStage],
         sc: SparkContext,
         path: String): Unit = {
-      // Copied and edited from DefaultParamsWriter.saveMetadata
-      // TODO: modify DefaultParamsWriter.saveMetadata to avoid duplication
-      val uid = instance.uid
-      val cls = instance.getClass.getName
       val stageUids = stages.map(_.uid)
       val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
-      val metadata = ("class" -> cls) ~
-        ("timestamp" -> System.currentTimeMillis()) ~
-        ("sparkVersion" -> sc.version) ~
-        ("uid" -> uid) ~
-        ("paramMap" -> jsonParams)
-      val metadataPath = new Path(path, "metadata").toString
-      val metadataJson = compact(render(metadata))
-      sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
+      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))

       // Save stages
       val stagesDir = new Path(path, "stages").toString
@@ -262,30 +250,10 @@ object Pipeline extends MLReadable[Pipeline] {

       implicit val format = DefaultFormats
       val stagesDir = new Path(path, "stages").toString
-      val stageUids: Array[String] = metadata.params match {
-        case JObject(pairs) =>
-          if (pairs.length != 1) {
-            // Should not happen unless file is corrupted or we have a bug.
-            throw new RuntimeException(
-              s"Pipeline read expected 1 Param (stageUids), but found ${pairs.length}.")
-          }
-          pairs.head match {
-            case ("stageUids", jsonValue) =>
-              jsonValue.extract[Seq[String]].toArray
-            case (paramName, jsonValue) =>
-              // Should not happen unless file is corrupted or we have a bug.
-              throw new RuntimeException(s"Pipeline read encountered unexpected Param $paramName" +
-                s" in metadata: ${metadata.metadataStr}")
-          }
-        case _ =>
-          throw new IllegalArgumentException(
-            s"Cannot recognize JSON metadata: ${metadata.metadataStr}.")
-      }
+      val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
       val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
         val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
-        val stageMetadata = DefaultParamsReader.loadMetadata(stagePath, sc)
-        val cls = Utils.classForName(stageMetadata.className)
-        cls.getMethod("read").invoke(null).asInstanceOf[MLReader[PipelineStage]].load(stagePath)
+        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
       }
       (metadata.uid, stages)
     }
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.2.0")
 @Experimental
 class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasRawPredictionCol with HasLabelCol {
+  extends Evaluator with HasRawPredictionCol with HasLabelCol with DefaultParamsWritable {

   @Since("1.2.0")
   def this() = this(Identifiable.randomUID("binEval"))
@@ -105,3 +105,10 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
   @Since("1.4.1")
   override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object BinaryClassificationEvaluator extends DefaultParamsReadable[BinaryClassificationEvaluator] {
+
+  @Since("1.6.0")
+  override def load(path: String): BinaryClassificationEvaluator = super.load(path)
+}
12 changes: 10 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable}
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.sql.{Row, DataFrame}
 import org.apache.spark.sql.types.DoubleType
@@ -32,7 +32,7 @@ import org.apache.spark.sql.types.DoubleType
 @Since("1.5.0")
 @Experimental
 class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol {
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {

   @Since("1.5.0")
   def this() = this(Identifiable.randomUID("mcEval"))
@@ -101,3 +101,11 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
   @Since("1.5.0")
   override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object MulticlassClassificationEvaluator
+  extends DefaultParamsReadable[MulticlassClassificationEvaluator] {
+
+  @Since("1.6.0")
+  override def load(path: String): MulticlassClassificationEvaluator = super.load(path)
+}
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.evaluation
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
 import org.apache.spark.mllib.evaluation.RegressionMetrics
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
@@ -33,7 +33,7 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
 @Since("1.4.0")
 @Experimental
 final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
-  extends Evaluator with HasPredictionCol with HasLabelCol {
+  extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {

   @Since("1.4.0")
   def this() = this(Identifiable.randomUID("regEval"))
@@ -104,3 +104,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
   @Since("1.5.0")
   override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
 }
+
+@Since("1.6.0")
+object RegressionEvaluator extends DefaultParamsReadable[RegressionEvaluator] {
+
+  @Since("1.6.0")
+  override def load(path: String): RegressionEvaluator = super.load(path)
+}
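All three evaluators follow the same recipe: mix `DefaultParamsWritable` into the class and add a companion object extending `DefaultParamsReadable` (the `load` override appears to exist mainly to expose a typed, `Since`-annotated `load` in the generated API docs). A minimal round-trip sketch, with a hypothetical save path:

```scala
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

// Configure an evaluator; metricName and rawPredictionCol are real Params of this class.
val evaluator = new BinaryClassificationEvaluator()
  .setMetricName("areaUnderPR")
  .setRawPredictionCol("rawPrediction")

// DefaultParamsWritable provides write/save; only the uid and params are persisted.
evaluator.write.overwrite().save("/tmp/binary-evaluator")

// DefaultParamsReadable provides the companion's load, which restores the params.
val restored = BinaryClassificationEvaluator.load("/tmp/binary-evaluator")
assert(restored.getMetricName == "areaUnderPR")
```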
14 changes: 3 additions & 11 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,9 +27,8 @@ import scala.util.hashing.byteswap64

 import com.github.fommil.netlib.BLAS.{getInstance => blas}
 import org.apache.hadoop.fs.{FileSystem, Path}
-import org.json4s.{DefaultFormats, JValue}
+import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods._

 import org.apache.spark.{Logging, Partitioner}
 import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
@@ -240,7 +239,7 @@ object ALSModel extends MLReadable[ALSModel] {
   private[recommendation] class ALSModelWriter(instance: ALSModel) extends MLWriter {

     override protected def saveImpl(path: String): Unit = {
-      val extraMetadata = render("rank" -> instance.rank)
+      val extraMetadata = "rank" -> instance.rank
       DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
       val userPath = new Path(path, "userFactors").toString
       instance.userFactors.write.format("parquet").save(userPath)
@@ -257,14 +256,7 @@ object ALSModel extends MLReadable[ALSModel] {
     override def load(path: String): ALSModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       implicit val format = DefaultFormats
-      val rank: Int = metadata.extraMetadata match {
-        case Some(m: JValue) =>
-          (m \ "rank").extract[Int]
-        case None =>
-          throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
-            s" ${metadata.metadataStr}")
-      }
-
+      val rank = (metadata.metadata \ "rank").extract[Int]
       val userPath = new Path(path, "userFactors").toString
       val userFactors = sqlContext.read.format("parquet").load(userPath)
       val itemPath = new Path(path, "itemFactors").toString
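On the read side, the match on `metadata.extraMetadata` collapses to a single json4s lookup against `metadata.metadata`, the parsed metadata document itself. A self-contained sketch of that extraction pattern; the JSON literal below is fabricated for illustration:

```scala
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

implicit val format = DefaultFormats

// A fabricated metadata document mirroring the shape saveMetadata produces.
val metadataJson =
  """{"class":"org.apache.spark.ml.recommendation.ALSModel",
    |"timestamp":0,"sparkVersion":"1.6.0","uid":"als_1234","paramMap":{},"rank":10}""".stripMargin

// The \ operator selects a top-level field; extract converts it to an Int.
val rank = (parse(metadataJson) \ "rank").extract[Int]
assert(rank == 10)
```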