Commit 1749aec

Try adding PMMLExportable to ML with KMeans
1 parent 2f6dd63 commit 1749aec

3 files changed, +104 -3 lines changed

mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala

Lines changed: 13 additions & 2 deletions
@@ -17,9 +17,12 @@
 
 package org.apache.spark.ml.clustering
 
+import javax.xml.transform.stream.StreamResult
+
 import org.apache.spark.annotation.{Since, Experimental}
 import org.apache.spark.ml.param.{Param, Params, IntParam, ParamMap}
 import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.pmml.PMMLExportable
 import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
@@ -94,7 +97,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
 @Experimental
 class KMeansModel private[ml] (
     @Since("1.5.0") override val uid: String,
-    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+    private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams
+    with PMMLExportable {
 
   @Since("1.5.0")
   override def copy(extra: ParamMap): KMeansModel = {
@@ -129,6 +133,14 @@ class KMeansModel private[ml] (
     val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
     parentModel.computeCost(data)
   }
+
+  /**
+   * Export the model to stream result in PMML format
+   */
+  @Since("1.6.0")
+  override def toPMML(streamResult: StreamResult): Unit = {
+    parentModel.toPMML(streamResult)
+  }
 }
 
 /**
@@ -209,4 +221,3 @@ class KMeans @Since("1.5.0") (
     validateAndTransformSchema(schema)
   }
 }
-
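
With this change a fitted spark.ml KMeansModel picks up the PMML export overloads from the new trait and delegates the actual serialization to its wrapped MLlib model. A minimal usage sketch, assuming an existing DataFrame named dataset with a "features" vector column; the output path is illustrative:

import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)       // dataset: DataFrame with a "features" column (assumed)

model.toPMML("/tmp/kmeans.pmml")      // write the PMML document to a local file
val pmmlXml: String = model.toPMML()  // or render it as a String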
mllib/src/main/scala/org/apache/spark/ml/pmml/PMMLExportable.scala

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.pmml
+
+import java.io.{File, OutputStream, StringWriter}
+import javax.xml.transform.stream.StreamResult
+
+import org.jpmml.model.JAXBUtil
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
+import org.apache.spark.mllib.pmml.export.PMMLModelExportFactory
+
+/**
+ * :: DeveloperApi ::
+ * Export model to the PMML format
+ * Predictive Model Markup Language (PMML) is an XML-based file format
+ * developed by the Data Mining Group (www.dmg.org).
+ * Based on [[org.apache.spark.mllib.pmml.Exportable]]
+ */
+@DeveloperApi
+@Since("1.6.0")
+trait PMMLExportable {
+
+  /**
+   * Export the model to the stream result in PMML format.
+   */
+  private[spark] def toPMML(streamResult: StreamResult): Unit
+
+  /**
+   * :: Experimental ::
+   * Export the model to a local file in PMML format
+   */
+  @Experimental
+  @Since("1.6.0")
+  def toPMML(localPath: String): Unit = {
+    toPMML(new StreamResult(new File(localPath)))
+  }
+
+  /**
+   * :: Experimental ::
+   * Export the model to a directory on a distributed file system in PMML format.
+   * Models should override if they may contain more data than
+   * is reasonable to store locally.
+   */
+  @Experimental
+  @Since("1.6.0")
+  def toPMML(sc: SparkContext, path: String): Unit = {
+    val pmml = toPMML()
+    sc.parallelize(Array(pmml), 1).saveAsTextFile(path)
+  }
+
+  /**
+   * :: Experimental ::
+   * Export the model to the OutputStream in PMML format
+   */
+  @Experimental
+  @Since("1.6.0")
+  def toPMML(outputStream: OutputStream): Unit = {
+    toPMML(new StreamResult(outputStream))
+  }
+
+  /**
+   * :: Experimental ::
+   * Export the model to a String in PMML format
+   */
+  @Experimental
+  @Since("1.6.0")
+  def toPMML(): String = {
+    val writer = new StringWriter
+    toPMML(new StreamResult(writer))
+    writer.toString
+  }
+
+}
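
Every public overload here (local file, OutputStream, String, and the SparkContext/path variant) funnels into the single abstract toPMML(StreamResult), so a concrete model only has to implement that one method. A hedged sketch of how another spark.ml wrapper could adopt the trait, mirroring the KMeansModel change above; MyModel and its package are hypothetical, and the class has to live under org.apache.spark because the abstract member is private[spark]:

package org.apache.spark.ml.example  // hypothetical package, kept under org.apache.spark

import javax.xml.transform.stream.StreamResult

import org.apache.spark.ml.pmml.PMMLExportable
import org.apache.spark.mllib.pmml.{PMMLExportable => MLlibPMMLExportable}

// Delegate PMML export to the wrapped MLlib model; the String, File and
// OutputStream overloads are then inherited from the trait.
class MyModel(private val parentModel: MLlibPMMLExportable) extends PMMLExportable {
  override def toPMML(streamResult: StreamResult): Unit =
    parentModel.toPMML(streamResult)
}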

mllib/src/main/scala/org/apache/spark/mllib/pmml/PMMLExportable.scala

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ trait PMMLExportable {
   /**
    * Export the model to the stream result in PMML format
    */
-  private def toPMML(streamResult: StreamResult): Unit = {
+  private[spark] def toPMML(streamResult: StreamResult): Unit = {
     val pmmlModelExport = PMMLModelExportFactory.createPMMLModelExport(this)
     JAXBUtil.marshalPMML(pmmlModelExport.getPmml, streamResult)
   }
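
Widening this method from private to private[spark] is what makes the delegation above possible: code elsewhere under org.apache.spark, such as ml.clustering.KMeansModel, can now call the MLlib model's StreamResult-based export directly, while it stays hidden from user code. A small hedged sketch of the call this change enables; mllibModel stands in for any MLlib model that mixes in this trait:

import java.io.StringWriter
import javax.xml.transform.stream.StreamResult

// Compiles only from code inside org.apache.spark (e.g. the ml wrappers above).
val writer = new StringWriter
mllibModel.toPMML(new StreamResult(writer))  // mllibModel: a PMML-exportable MLlib model (assumed)
val pmml: String = writer.toString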
