From 43ae30f08aed921178da07a5e982297b272c7c8f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 24 Nov 2017 06:00:16 -0800 Subject: [PATCH 01/16] Initial attempt at allowing Spark ML writers to be slightly more pluggable --- .../org/apache/spark/ml/util/ReadWrite.scala | 120 ++++++++++++++++-- 1 file changed, 112 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a616907800969..4c338421ce547 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -18,9 +18,11 @@ package org.apache.spark.ml.util import java.io.IOException -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.json4s._ @@ -28,8 +30,8 @@ import org.json4s.{DefaultFormats, JObject} import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.SparkContext -import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.{SparkContext, SparkException} +import org.apache.spark.annotation.{DeveloperApi, InterfaceStability, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.classification.{OneVsRest, OneVsRestModel} @@ -85,12 +87,52 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * ML formats for export should implement this trait so they can register an alias to their format. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLFormatRegister { + /** + * The string that represents the format that this data source provider uses. 
This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def shortName(): String = + * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. + * + * @since 2.3.0 + */ + def shortName(): String +} + +/** + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLWriterFormat{ + /** + * Function write the provided pipeline stage out. + */ + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage) +} + /** * Abstract class for utility classes that can save ML instances. */ +@deprecated("Use GeneralMLWriter instead. Will be removed in Spark 3.0.0", "2.3.0") @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { - protected var shouldOverwrite: Boolean = false /** @@ -98,6 +140,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[ClassNotFoundException]("If the requested format class can be loaded.") def save(path: String): Unit = { new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc) saveImpl(path) @@ -110,6 +153,15 @@ abstract class MLWriter extends BaseReadWrite with Logging { @Since("1.6.0") protected def saveImpl(path: String): Unit + /** + * Overwrites if the output path already exists. + */ + @Since("1.6.0") + def overwrite(): this.type = { + shouldOverwrite = true + this + } + /** * Map to store extra options for this writer. 
*/ @@ -126,15 +178,67 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { - shouldOverwrite = true + @Since("2.3.0") + def format(source: String): this.type = { + this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[ClassNotFoundException]("If the requested format class can be loaded.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String) = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) + val stageName = stage.getClass.getName + val targetName = s"${source}+${stageName}" + val writerCls = serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => + Try(loader.loadClass(source)) match { + case Success(writer) => + // Found the ML writer using the fully qualified path + writer + case Failure(error) => + throw new ClassNotFoundException( + s"Could not load requested format $source for $stageName", error) + } + case head :: Nil => + head.getClass + case _ => + // Multiple 
sources + throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") + } + if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] + writer.write(path, sparkSession, optionMap, stage) + } else { + throw new SparkException("ML source $source is not a valid MLWriterFormat") + } + } + // override for Java compatibility override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) From 9fec08fbd2dd1c980d5862f0b4521213e1e9349c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 25 Nov 2017 04:55:19 -0800 Subject: [PATCH 02/16] The LinearRegression suite passes --- .../ml/regression/LinearRegression.scala | 47 +++++++++++-------- .../org/apache/spark/ml/util/ReadWrite.scala | 27 ++++++++--- 2 files changed, 48 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index da6bcf07e4742..5ddff04cde0b9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging -import org.apache.spark.ml.PredictorParams +import org.apache.spark.ml.{PipelineStage, PredictorParams} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.linalg.BLAS._ @@ -42,7 +42,7 @@ import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import 
org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel @@ -482,7 +482,7 @@ class LinearRegressionModel private[ml] ( @Since("2.0.0") val coefficients: Vector, @Since("1.3.0") val intercept: Double) extends RegressionModel[Vector, LinearRegressionModel] - with LinearRegressionParams with MLWritable { + with LinearRegressionParams with GeneralMLWritable { private var trainingSummary: Option[LinearRegressionTrainingSummary] = None @@ -554,9 +554,32 @@ class LinearRegressionModel private[ml] ( * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) } +/** [[MLWriterFormat]] providing "internal" instance for [[LinearRegressionModel]] */ +class InternalLinearRegressionModelWriter() + extends MLWriterFormat with MLFormatRegister { + + override def shortName(): String = + "internal+org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val instance = stage.asInstanceOf[LinearRegressionModel] + val sc = sparkSession.sparkContext + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + // Save model data: intercept, coefficients + val data = Data(instance.intercept, instance.coefficients) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + + @Since("1.6.0") object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @@ -566,22 +589,6 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { @Since("1.6.0") override def 
load(path: String): LinearRegressionModel = super.load(path) - /** [[MLWriter]] instance for [[LinearRegressionModel]] */ - private[LinearRegressionModel] class LinearRegressionModelWriter(instance: LinearRegressionModel) - extends MLWriter with Logging { - - private case class Data(intercept: Double, coefficients: Vector) - - override protected def saveImpl(path: String): Unit = { - // Save metadata and Params - DefaultParamsWriter.saveMetadata(instance, path, sc) - // Save model data: intercept, coefficients - val data = Data(instance.intercept, instance.coefficients) - val dataPath = new Path(path, "data").toString - sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) - } - } - private class LinearRegressionModelReader extends MLReader[LinearRegressionModel] { /** Checked against metadata when loading model */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 4c338421ce547..f2f75681763be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -106,6 +106,8 @@ trait MLFormatRegister { * }}} * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. * @since 2.3.0 */ def shortName(): String @@ -119,7 +121,7 @@ trait MLFormatRegister { * @since 2.3.0 */ @InterfaceStability.Evolving -trait MLWriterFormat{ +trait MLWriterFormat { /** * Function write the provided pipeline stage out. 
*/ @@ -140,7 +142,6 @@ abstract class MLWriter extends BaseReadWrite with Logging { */ @Since("1.6.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - @throws[ClassNotFoundException]("If the requested format class can be loaded.") def save(path: String): Unit = { new FileSystemOverwrite().handleOverwrite(path, shouldOverwrite, sc) saveImpl(path) @@ -206,14 +207,15 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { */ @Since("2.3.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") - @throws[ClassNotFoundException]("If the requested format class can be loaded.") @throws[SparkException]("If multiple sources for a given short name format are found.") override protected def saveImpl(path: String) = { val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) val stageName = stage.getClass.getName val targetName = s"${source}+${stageName}" - val writerCls = serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(targetName)) match { + val formats = serviceLoader.asScala.toList + val shortNames = formats.map(_.shortName()) + val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { // requested name did not match any given registered alias case Nil => Try(loader.loadClass(source)) match { @@ -221,8 +223,9 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { // Found the ML writer using the fully qualified path writer case Failure(error) => - throw new ClassNotFoundException( - s"Could not load requested format $source for $stageName", error) + throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) } case head :: Nil => head.getClass @@ -266,6 +269,18 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } +/** + * Trait for classes that 
provide `GeneralMLWriter`. + */ +@Since("2.3.0") +trait GeneralMLWritable extends MLWritable { + /** + * Returns an `MLWriter` instance for this ML instance. + */ + @Since("2.3.0") + override def write: GeneralMLWriter +} + /** * :: DeveloperApi :: * From 0075bf4776ecffa7fcb24a6f74c0e96161d6221c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 25 Nov 2017 05:00:18 -0800 Subject: [PATCH 03/16] Add missing META-INFO for MLFormatRegister --- .../META-INF/services/org.apache.spark.ml.util.MLFormatRegister | 1 + 1 file changed, 1 insertion(+) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..869338a1fa454 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter \ No newline at end of file From c68880d6d982c56934f4b583263ed5cd4e8329d6 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sat, 25 Nov 2017 08:19:35 -0800 Subject: [PATCH 04/16] Add a (untested) PMMLLinearRegressionModelWriter --- .../org.apache.spark.ml.util.MLFormatRegister | 3 ++- .../ml/regression/LinearRegression.scala | 23 ++++++++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister index 869338a1fa454..5e5484fd8784d 100644 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -1 +1,2 @@ -org.apache.spark.ml.regression.InternalLinearRegressionModelWriter \ No newline at end of 
file +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter +org.apache.spark.ml.regression.PMMLLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 5ddff04cde0b9..f65f60eadab03 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -39,6 +39,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.evaluation.RegressionMetrics import org.apache.spark.mllib.linalg.VectorImplicits._ +import org.apache.spark.mllib.regression.{LinearRegressionModel => OldLinearRegressionModel} import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD @@ -558,7 +559,7 @@ class LinearRegressionModel private[ml] ( } /** [[MLWriterFormat]] providing "internal" instance for [[LinearRegressionModel]] */ -class InternalLinearRegressionModelWriter() +private class InternalLinearRegressionModelWriter() extends MLWriterFormat with MLFormatRegister { override def shortName(): String = @@ -579,6 +580,26 @@ class InternalLinearRegressionModelWriter() } } +/** [[MLWriterFormat]] providing "pmml" instance for [[LinearRegressionModel]] */ +private class PMMLLinearRegressionModelWriter() + extends MLWriterFormat with MLFormatRegister { + + override def shortName(): String = + "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector) + + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + val sc = sparkSession.sparkContext + // Construct the MLLib model which knows how to write to PMML. 
+ val instance = stage.asInstanceOf[LinearRegressionModel] + val oldModel = new OldLinearRegressionModel(instance.coefficients, instance.intercept) + // Save PMML + oldModel.toPMML(sc, path) + } +} + @Since("1.6.0") object LinearRegressionModel extends MLReadable[LinearRegressionModel] { From c2108df2b499bd45dff0e8add789f01d8c3c2c48 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 4 Dec 2017 02:00:56 -0800 Subject: [PATCH 05/16] Basic PMML export test --- .../ml/regression/LinearRegressionSuite.scala | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 0e0be58dbf022..be4db1220fd04 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -17,20 +17,24 @@ package org.apache.spark.ml.regression +import scala.collection.JavaConverters._ import scala.util.Random +import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} + import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils, PMMLReadWriteTest} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.{DataFrame, Row} class LinearRegressionSuite - extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest + with PMMLReadWriteTest { import testImplicits._ 
@@ -994,6 +998,24 @@ class LinearRegressionSuite LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) From de8619098eeb01ff86b54753f27c29729935bb94 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 4 Dec 2017 03:27:03 -0800 Subject: [PATCH 06/16] Add PMML testing utils for Spark ML that were accidently left out --- .../spark/ml/util/PMMLReadWriteTest.scala | 55 +++++++++++++++++++ .../org/apache/spark/ml/util/PMMLUtils.scala | 43 +++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala new file mode 100644 index 0000000000000..d2c4832b12bac --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLReadWriteTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.util + +import java.io.{File, IOException} + +import org.dmg.pmml.PMML +import org.scalatest.Suite + +import org.apache.spark.SparkContext +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + +trait PMMLReadWriteTest extends TempDirectory { self: Suite => + /** + * Test PMML export. Requires exported model is small enough to be loaded locally. + * Checks that the model can be exported and the result is valid PMML, but does not check + * the specific contents of the model. 
+ */ + def testPMMLWrite[T <: Params with GeneralMLWritable](sc: SparkContext, instance: T, + checkModelData: PMML => Unit): Unit = { + val uid = instance.uid + val subdirName = Identifiable.randomUID("pmml-") + + val subdir = new File(tempDir, subdirName) + val path = new File(subdir, uid).getPath + + instance.write.format("pmml").save(path) + intercept[IOException] { + instance.write.format("pmml").save(path) + } + instance.write.format("pmml").overwrite().save(path) + val pmmlStr = sc.textFile(path).collect.mkString("\n") + val pmmlModel = PMMLUtils.loadFromString(pmmlStr) + assert(pmmlModel.getHeader().getApplication().getName().startsWith("Apache Spark")) + checkModelData(pmmlModel) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala new file mode 100644 index 0000000000000..dbdc69f95d841 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/PMMLUtils.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.ml.util + +import java.io.StringReader +import javax.xml.bind.Unmarshaller +import javax.xml.transform.Source + +import org.dmg.pmml._ +import org.jpmml.model.{ImportFilter, JAXBUtil} +import org.xml.sax.InputSource + +/** + * Testing utils for working with PMML. + * Predictive Model Markup Language (PMML) is an XML-based file format + * developed by the Data Mining Group (www.dmg.org). + */ +private[spark] object PMMLUtils { + /** + * :: Experimental :: + * Load a PMML model from a string. Note: for testing only, PMML model evaluation is supported + * through external spark-packages. + */ + def loadFromString(input: String): PMML = { + val is = new StringReader(input) + val transformed = ImportFilter.apply(new InputSource(is)) + JAXBUtil.unmarshalPMML(transformed) + } +} From 8b1c7525cadd084686774b584ab376958ded2eb0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 5 Dec 2017 00:56:48 -0800 Subject: [PATCH 07/16] Minor wording/whitespace change --- .../org/apache/spark/ml/regression/LinearRegression.scala | 1 - mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f65f60eadab03..d4189a221ddad 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -600,7 +600,6 @@ private class PMMLLinearRegressionModelWriter() } } - @Since("1.6.0") object LinearRegressionModel extends MLReadable[LinearRegressionModel] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index f2f75681763be..fd142431988ae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -88,7 +88,8 @@ private[util] sealed trait BaseReadWrite { } /** - * ML formats for export should implement this trait so they can register an alias to their format. + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. * * A new instance of this class will be instantiated each time a DDL call is made. * From 72b509ff1919c7c82ab2909ffb5c9a7596e52e6c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 5 Dec 2017 01:29:02 -0800 Subject: [PATCH 08/16] Remove link causing doc issue --- .../org/apache/spark/ml/regression/LinearRegression.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index d4189a221ddad..c6504057aa9e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -558,7 +558,7 @@ class LinearRegressionModel private[ml] ( override def write: GeneralMLWriter = new GeneralMLWriter(this) } -/** [[MLWriterFormat]] providing "internal" instance for [[LinearRegressionModel]] */ +/** A writer for LinearRegression that handles the "internal" (or default) format */ private class InternalLinearRegressionModelWriter() extends MLWriterFormat with MLFormatRegister { @@ -580,7 +580,7 @@ private class InternalLinearRegressionModelWriter() } } -/** [[MLWriterFormat]] providing "pmml" instance for [[LinearRegressionModel]] */ +/** A writer for LinearRegression that handles the "pmml" format */ private class PMMLLinearRegressionModelWriter() extends MLWriterFormat with MLFormatRegister { From b8362a463c92f902c12a6efcb7d82e5324e64e24 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 5 Dec 2017 01:49:49 -0800 
Subject: [PATCH 09/16] Verify we throw on invalid export formats --- .../ml/regression/LinearRegressionSuite.scala | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index be4db1220fd04..079831e2307a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} @@ -1016,6 +1016,20 @@ class LinearRegressionSuite testPMMLWrite(sc, model, checkModel) } + test("unsupported export format") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + intercept[SparkException] { + model.write.format("boop").save("boop") + } + intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + } + test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) From b8844c75b7b0278f19cd340c3c036935c43feef4 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Wed, 20 Dec 2017 12:35:04 -0800 Subject: [PATCH 10/16] Updates on CR feedback --- .../ml/regression/LinearRegression.scala | 5 ++-- .../org/apache/spark/ml/util/ReadWrite.scala | 4 +-- .../ml/regression/LinearRegressionSuite.scala | 29 ++++++++++++++++--- 3 files changed, 29 
insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 670c94d327628..31f95afce19b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -711,7 +711,7 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. @@ -729,7 +729,7 @@ private class InternalLinearRegressionModelWriter() override def shortName(): String = "internal+org.apache.spark.ml.regression.LinearRegressionModel" - private case class Data(intercept: Double, coefficients: Vector) + private case class Data(intercept: Double, coefficients: Vector, scale: Double) override def write(path: String, sparkSession: SparkSession, optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { @@ -738,7 +738,6 @@ private class InternalLinearRegressionModelWriter() // Save metadata and Params DefaultParamsWriter.saveMetadata(instance, path, sc) // Save model data: intercept, coefficients, scale - private case class Data(intercept: Double, coefficients: Vector, scale: Double) val data = Data(instance.intercept, instance.coefficients, instance.scale) val dataPath = new Path(path, "data").toString sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index fd142431988ae..29eccf3b1497b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -127,7 +127,7 @@ trait MLWriterFormat { * Function write the provided pipeline stage out. */ def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], - stage: PipelineStage) + stage: PipelineStage): Unit } /** @@ -239,7 +239,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] writer.write(path, sparkSession, optionMap, stage) } else { - throw new SparkException("ML source $source is not a valid MLWriterFormat") + throw new SparkException(s"ML source $source is not a valid MLWriterFormat") } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 7139292988f54..542ef15a10e4f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,19 +18,28 @@ package org.apache.spark.ml.regression import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Random import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.PipelineStage import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils, PMMLReadWriteTest} +import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} + +class 
DummyLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Dummy writer doesn't write") + } +} class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { @@ -1074,8 +1083,20 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe intercept[SparkException] { model.write.format("com.holdenkarau.boop").save("boop") } - intercept[SparkException] { - model.write.format("org.apache.spark.SparkContext").save("boop2") + withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { + intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + } + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + withClue("Dummy writer doesn't write") { + intercept[Exception] { + model.write.format("org.apache.spark.ml.regression.DummyLinearRegressionWriter").save("") + } } } From 8fba2e58daa2b14957428e9436af4ba7cdba7a26 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 16 Jan 2018 18:39:42 -0800 Subject: [PATCH 11/16] Refactor a bit (especially in tests). 
--- .../org.apache.spark.ml.MLFormatRegister | 1 + .../ml/regression/LinearRegression.scala | 15 +- .../org/apache/spark/ml/util/ReadWrite.scala | 81 +++++++---- .../org.apache.spark.ml.util.MLFormatRegister | 3 + .../ml/regression/LinearRegressionSuite.scala | 35 ----- .../apache/spark/ml/util/ReadWriteSuite.scala | 134 ++++++++++++++++++ 6 files changed, 201 insertions(+), 68 deletions(-) create mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister create mode 100644 mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister create mode 100644 mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister new file mode 100644 index 0000000000000..869338a1fa454 --- /dev/null +++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister @@ -0,0 +1 @@ +org.apache.spark.ml.regression.InternalLinearRegressionModelWriter \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 31f95afce19b1..8e161f745c78e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -723,11 +723,11 @@ class LinearRegressionModel private[ml] ( } /** A writer for LinearRegression that handles the "internal" (or default) format */ -private class InternalLinearRegressionModelWriter() +private class InternalLinearRegressionModelWriter extends MLWriterFormat with MLFormatRegister { - override def shortName(): String = - "internal+org.apache.spark.ml.regression.LinearRegressionModel" + override def format(): String = "internal" + override def stageName(): String = 
"org.apache.spark.ml.regression.LinearRegressionModel" private case class Data(intercept: Double, coefficients: Vector, scale: Double) @@ -745,11 +745,12 @@ private class InternalLinearRegressionModelWriter() } /** A writer for LinearRegression that handles the "pmml" format */ -private class PMMLLinearRegressionModelWriter() - extends MLWriterFormat with MLFormatRegister { +private class PMMLLinearRegressionModelWriter + extends MLWriterFormat with MLFormatRegister { + + override def format(): String = "pmml" - override def shortName(): String = - "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" private case class Data(intercept: Double, coefficients: Vector) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 29eccf3b1497b..d779312cd0e21 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -87,53 +87,80 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. 
+ */ + @Since("2.3.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], + stage: PipelineStage): Unit +} + /** * ML export formats for should implement this trait so that users can specify a shortname rather * than the fully qualified class name of the exporter. * - * A new instance of this class will be instantiated each time a DDL call is made. + * A new instance of this class will be instantiated each time a save call is made. * * @since 2.3.0 */ @InterfaceStability.Evolving -trait MLFormatRegister { +trait MLFormatRegister extends MLWriterFormat { /** - * The string that represents the format that this data source provider uses. This is - * overridden by children to provide a nice alias for the data source. For example: + * The string that represents the format that this format provider uses. This is, along with + * stageName, is overridden by children to provide a nice alias for the writer. For example: * * {{{ - * override def shortName(): String = - * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + * override def format(): String = + * "pmml" * }}} - * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. + * Indicates that this format is capable of saving a pmml model. * * Format discovery is done using a ServiceLoader so make sure to list your format in * META-INF/services. * @since 2.3.0 */ - def shortName(): String -} + @Since("2.3.0") + def format(): String -/** - * Implemented by objects that provide ML exportability. - * - * A new instance of this class will be instantiated each time a DDL call is made. - * - * @since 2.3.0 - */ -@InterfaceStability.Evolving -trait MLWriterFormat { /** - * Function write the provided pipeline stage out. + * The string that represents the stage type that this writer supports. This is, along with + * format, is overridden by children to provide a nice alias for the writer. 
For example: + * + * {{{ + * override def stageName(): String = + * "org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own PMML model. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.3.0 */ - def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], - stage: PipelineStage): Unit + @Since("2.3.0") + def stageName(): String + + private[ml] def shortName(): String = s"${format()}+${stageName()}" } /** - * Abstract class for utility classes that can save ML instances. + * Abstract class for utility classes that can save ML instances in Spark's internal format. */ -@deprecated("Use GeneralMLWriter instead. Will be removed in Spark 3.0.0", "2.3.0") @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { protected var shouldOverwrite: Boolean = false @@ -181,9 +208,11 @@ abstract class MLWriter extends BaseReadWrite with Logging { } // override for Java compatibility + @Since("1.6.0") override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) // override for Java compatibility + @Since("1.6.0") override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) } @@ -194,7 +223,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" /** - * Specifies the format of ML export (e.g. PMML, internal, or + * Specifies the format of ML export (e.g. "pmml", "internal", or * the fully qualified class name for export). 
*/ @Since("2.3.0") @@ -209,11 +238,11 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { @Since("2.3.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") @throws[SparkException]("If multiple sources for a given short name format are found.") - override protected def saveImpl(path: String) = { + override protected def saveImpl(path: String): Unit = { val loader = Utils.getContextOrSparkClassLoader val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) val stageName = stage.getClass.getName - val targetName = s"${source}+${stageName}" + val targetName = s"$source+$stageName" val formats = serviceLoader.asScala.toList val shortNames = formats.map(_.shortName()) val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { diff --git a/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister new file mode 100644 index 0000000000000..100ef2545418f --- /dev/null +++ b/mllib/src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister @@ -0,0 +1,3 @@ +org.apache.spark.ml.util.DuplicateLinearRegressionWriter1 +org.apache.spark.ml.util.DuplicateLinearRegressionWriter2 +org.apache.spark.ml.util.FakeLinearRegressionWriterWithName diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 790963e41a7f4..e6c156c7fd0ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,8 +23,6 @@ import scala.util.Random import org.dmg.pmml.{OpType, PMML, RegressionModel => PMMLRegressionModel} -import org.apache.spark.SparkException -import org.apache.spark.ml.PipelineStage import 
org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} @@ -33,14 +31,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.LinearDataGenerator import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.{DataFrame, Row, SparkSession} -class DummyLinearRegressionWriter extends MLWriterFormat { - override def write(path: String, sparkSession: SparkSession, - optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { - throw new Exception(s"Dummy writer doesn't write") - } -} class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLReadWriteTest { @@ -1077,32 +1068,6 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest with PMMLRe testPMMLWrite(sc, model, checkModel) } - test("unsupported export format") { - val lr = new LinearRegression() - val model = lr.fit(datasetWithWeight) - intercept[SparkException] { - model.write.format("boop").save("boop") - } - intercept[SparkException] { - model.write.format("com.holdenkarau.boop").save("boop") - } - withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { - intercept[SparkException] { - model.write.format("org.apache.spark.SparkContext").save("boop2") - } - } - } - - test("dummy export format is called") { - val lr = new LinearRegression() - val model = lr.fit(datasetWithWeight) - withClue("Dummy writer doesn't write") { - intercept[Exception] { - model.write.format("org.apache.spark.ml.regression.DummyLinearRegressionWriter").save("") - } - } - } - test("should support all NumericType labels and weights, and not support other types") { for (solver <- Seq("auto", "l-bfgs", "normal")) { val lr = new LinearRegression().setMaxIter(1).setSolver(solver) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala new file mode 100644 index 0000000000000..34f9ca3f7d061 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.util + +import scala.collection.mutable + +import org.apache.spark.SparkException +import org.apache.spark.ml.PipelineStage +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.sql.{DataFrame, SparkSession} + +class FakeLinearRegressionWriter extends MLWriterFormat { + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + +class FakeLinearRegressionWriterWithName extends MLFormatRegister { + override def format(): String = "fakeWithName" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Fake writer doesn't writestart") + } +} + + +class DuplicateLinearRegressionWriter1 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class DuplicateLinearRegressionWriter2 extends MLFormatRegister { + override def format(): String = "dupe" + override def stageName(): String = "org.apache.spark.ml.regression.LinearRegressionModel" + override def write(path: String, sparkSession: SparkSession, + optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { + throw new Exception(s"Duplicate writer shouldn't have been called") + } +} + +class ReadWriteSuite extends MLTest { + + import testImplicits._ + + private val seed: Int = 42 + @transient var dataset: DataFrame = _ + + override def 
beforeAll(): Unit = { + super.beforeAll() + dataset = sc.parallelize(LinearDataGenerator.generateLinearInput( + intercept = 0.0, weights = Array(1.0, 2.0), xMean = Array(0.0, 1.0), + xVariance = Array(2.0, 1.0), nPoints = 10, seed, eps = 0.2)).map(_.asML).toDF() + } + + test("unsupported/non existent export formats") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + // Does not exist with a long class name + val thrownDNE = intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + assert(thrownDNE.getMessage(). + contains("Could not load requested format")) + + // Does not exist with a short name + val thrownDNEShort = intercept[SparkException] { + model.write.format("boop").save("boop") + } + assert(thrownDNEShort.getMessage(). + contains("Could not load requested format")) + + // Check with a valid class that is not a writer format. + val thrownInvalid = intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + assert(thrownInvalid.getMessage() + .contains("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat")) + } + + test("invalid paths fail") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("pmml").save("") + } + assert(thrown.getMessage().contains("Can not create a Path from an empty string")) + } + + test("dummy export format is called") { + val lr = new LinearRegression() + val model = lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("org.apache.spark.ml.util.FakeLinearRegressionWriter").save("name") + } + assert(thrown.getMessage().contains("Fake writer doesn't write")) + val thrownWithName = intercept[Exception] { + model.write.format("fakeWithName").save("name") + } + assert(thrownWithName.getMessage().contains("Fake writer doesn't write")) + } + + test("duplicate format raises error") { + val lr = new LinearRegression() + val model = 
lr.fit(dataset) + val thrown = intercept[Exception] { + model.write.format("dupe").save("dupepanda") + } + assert(thrown.getMessage().contains("Multiple writers found for")) + } + + // TODO (save wrong format model) +} From 6411054d12ea3d4b98bdc7ad409b392a8cc16a8f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 16 Jan 2018 18:40:34 -0800 Subject: [PATCH 12/16] eh the wrong format error is up to each implementation (e.g. someone could write a multi-format export class I suppose) --- .../test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala index 34f9ca3f7d061..f4c1f0bdb32cd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/ReadWriteSuite.scala @@ -129,6 +129,4 @@ class ReadWriteSuite extends MLTest { } assert(thrown.getMessage().contains("Multiple writers found for")) } - - // TODO (save wrong format model) } From 40472390fe2932900c4d788bf2e19bcfbff12c92 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 16 Jan 2018 18:57:00 -0800 Subject: [PATCH 13/16] remove old format register meta-inf file --- .../META-INF/services/org.apache.spark.ml.MLFormatRegister | 1 - 1 file changed, 1 deletion(-) delete mode 100644 mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister deleted file mode 100644 index 869338a1fa454..0000000000000 --- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.MLFormatRegister +++ /dev/null @@ -1 +0,0 @@ -org.apache.spark.ml.regression.InternalLinearRegressionModelWriter \ No newline at end of file From cd330f3f37796f61b8b62a24c5600cf2868b6470 Mon Sep 17 00:00:00 
2001 From: Holden Karau Date: Tue, 16 Jan 2018 19:00:00 -0800 Subject: [PATCH 14/16] Annotations and remove unnecessary whitespace change --- .../src/main/scala/org/apache/spark/ml/util/ReadWrite.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index d779312cd0e21..b074062a51818 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -97,6 +97,7 @@ private[util] sealed trait BaseReadWrite { * @since 2.3.0 */ @InterfaceStability.Evolving +@Since("2.3.0") trait MLWriterFormat { /** * Function to write the provided pipeline stage out. @@ -120,6 +121,7 @@ trait MLWriterFormat { * @since 2.3.0 */ @InterfaceStability.Evolving +@Since("2.3.0") trait MLFormatRegister extends MLWriterFormat { /** * The string that represents the format that this format provider uses. This is, along with @@ -163,6 +165,7 @@ trait MLFormatRegister extends MLWriterFormat { */ @Since("1.6.0") abstract class MLWriter extends BaseReadWrite with Logging { + protected var shouldOverwrite: Boolean = false /** @@ -219,6 +222,8 @@ abstract class MLWriter extends BaseReadWrite with Logging { /** * A ML Writer which delegates based on the requested format. */ +@InterfaceStability.Evolving +@Since("2.3.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -303,6 +308,7 @@ trait MLWritable { * Trait for classes that provide `GeneralMLWriter`. */ @Since("2.3.0") +@InterfaceStability.Evolving trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance.
From 41312e7fcd6d706323fecb828776fced5e5a769c Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Tue, 16 Jan 2018 19:03:05 -0800 Subject: [PATCH 15/16] Weaken promise to Unstable --- .../main/scala/org/apache/spark/ml/util/ReadWrite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index b074062a51818..a9c42715d7a72 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -96,7 +96,7 @@ private[util] sealed trait BaseReadWrite { * * @since 2.3.0 */ -@InterfaceStability.Evolving +@InterfaceStability.Unstable @Since("2.3.0") trait MLWriterFormat { /** @@ -120,7 +120,7 @@ trait MLWriterFormat { * * @since 2.3.0 */ -@InterfaceStability.Evolving +@InterfaceStability.Unstable @Since("2.3.0") trait MLFormatRegister extends MLWriterFormat { /** @@ -222,7 +222,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { /** * A ML Writer which delegates based on the requested format. */ -@InterfaceStability.Evolving +@InterfaceStability.Unstable @Since("2.3.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -308,7 +308,7 @@ trait MLWritable { * Trait for classes that provide `GeneralMLWriter`. */ @Since("2.3.0") -@InterfaceStability.Evolving +@InterfaceStability.Unstable trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance. 
From cb6fd70d0c61b6477f7514431ee2e1c097ec0aff Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 23 Mar 2018 11:26:28 -0700 Subject: [PATCH 16/16] Update since annotation to 2.4.0 and add comment re: zero arg constructor --- .../org/apache/spark/ml/util/ReadWrite.scala | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index a9c42715d7a72..7edcd498678cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -88,16 +88,16 @@ private[util] sealed trait BaseReadWrite { } /** - * Implemented by objects that provide ML exportability. + * Abstract class to be implemented by objects that provide ML exportability. * * A new instance of this class will be instantiated each time a save call is made. * * Must have a valid zero argument constructor which will be called to instantiate. * - * @since 2.3.0 + * @since 2.4.0 */ @InterfaceStability.Unstable -@Since("2.3.0") +@Since("2.4.0") trait MLWriterFormat { /** * Function to write the provided pipeline stage out. @@ -107,7 +107,7 @@ trait MLWriterFormat { * @param optionMap User provided options stored as strings. * @param stage The pipeline stage to be saved. */ - @Since("2.3.0") + @Since("2.4.0") def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], stage: PipelineStage): Unit } @@ -118,10 +118,10 @@ trait MLWriterFormat { * * A new instance of this class will be instantiated each time a save call is made. * - * @since 2.3.0 + * @since 2.4.0 */ @InterfaceStability.Unstable -@Since("2.3.0") +@Since("2.4.0") trait MLFormatRegister extends MLWriterFormat { /** * The string that represents the format that this format provider uses. 
This is, along with @@ -133,11 +133,13 @@ trait MLFormatRegister extends MLWriterFormat { * }}} * Indicates that this format is capable of saving a pmml model. * + * Must have a valid zero argument constructor which will be called to instantiate. + * * Format discovery is done using a ServiceLoader so make sure to list your format in * META-INF/services. - * @since 2.3.0 + * @since 2.4.0 */ - @Since("2.3.0") + @Since("2.4.0") def format(): String /** @@ -152,9 +154,9 @@ trait MLFormatRegister extends MLWriterFormat { * * Format discovery is done using a ServiceLoader so make sure to list your format in * META-INF/services. - * @since 2.3.0 + * @since 2.4.0 */ - @Since("2.3.0") + @Since("2.4.0") def stageName(): String private[ml] def shortName(): String = s"${format()}+${stageName()}" @@ -223,7 +225,7 @@ abstract class MLWriter extends BaseReadWrite with Logging { * A ML Writer which delegates based on the requested format. */ @InterfaceStability.Unstable -@Since("2.3.0") +@Since("2.4.0") class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { private var source: String = "internal" @@ -231,7 +233,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { * Specifies the format of ML export (e.g. "pmml", "internal", or * the fully qualified class name for export). */ - @Since("2.3.0") + @Since("2.4.0") def format(source: String): this.type = { this.source = source this @@ -240,7 +242,7 @@ class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { /** * Dispatches the save to the correct MLFormat. */ - @Since("2.3.0") + @Since("2.4.0") @throws[IOException]("If the input path already exists but overwrite is not enabled.") @throws[SparkException]("If multiple sources for a given short name format are found.") override protected def saveImpl(path: String): Unit = { @@ -307,13 +309,13 @@ trait MLWritable { /** * Trait for classes that provide `GeneralMLWriter`. 
*/ -@Since("2.3.0") +@Since("2.4.0") @InterfaceStability.Unstable trait GeneralMLWritable extends MLWritable { /** * Returns an `MLWriter` instance for this ML instance. */ - @Since("2.3.0") + @Since("2.4.0") override def write: GeneralMLWriter }