Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/mllib-clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ to the algorithm. We then output the parameters of the mixture model.

{% highlight scala %}
import org.apache.spark.mllib.clustering.GaussianMixture
import org.apache.spark.mllib.clustering.GaussianMixtureModel
import org.apache.spark.mllib.linalg.Vectors

// Load and parse the data
Expand All @@ -182,6 +183,10 @@ val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
// Cluster the data into two classes using GaussianMixture
val gmm = new GaussianMixture().setK(2).run(parsedData)

// Save and load model
gmm.save(sc, "myGMMModel")
val sameModel = GaussianMixtureModel.load(sc, "myGMMModel")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please also update the Java example.


// output parameters of max-likelihood model
for (i <- 0 until gmm.k) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
Expand Down Expand Up @@ -231,6 +236,9 @@ public class GaussianMixtureExample {
// Cluster the data into two classes using GaussianMixture
GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());

// Save and load GaussianMixtureModel
gmm.save(sc, "myGMMModel");
GaussianMixtureModel sameModel = GaussianMixtureModel.load(sc, "myGMMModel");
// Output the parameters of the mixture model
for(int j=0; j<gmm.k(); j++) {
System.out.println("weight=%f\nmu=%s\nsigma=\n%s\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,17 @@ package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseVector => BreezeVector}

import org.json4s.DefaultFormats
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, Row}

/**
* :: Experimental ::
Expand All @@ -41,10 +47,16 @@ import org.apache.spark.rdd.RDD
@Experimental
class GaussianMixtureModel(
val weights: Array[Double],
val gaussians: Array[MultivariateGaussian]) extends Serializable {
val gaussians: Array[MultivariateGaussian]) extends Serializable with Saveable{

require(weights.length == gaussians.length, "Length of weight and Gaussian arrays must match")


/** Format version string for this model; "1.0" corresponds to the SaveLoadV1_0 layout. */
override protected def formatVersion: String = "1.0"

/**
 * Saves this model's weights and gaussians under `path` by delegating to the
 * version 1.0 writer defined in the companion object.
 *
 * @param sc   Spark context used to write the output
 * @param path destination path for the saved model
 */
override def save(sc: SparkContext, path: String): Unit =
  GaussianMixtureModel.SaveLoadV1_0.save(sc, path, weights, gaussians)

/** Number of mixture components, i.e. the number of entries in `weights`. */
def k: Int = weights.size

Expand Down Expand Up @@ -83,5 +95,79 @@ class GaussianMixtureModel(
p(i) /= pSum
}
p
}
}
}

/**
 * Companion object of [[GaussianMixtureModel]] that implements model loading and holds the
 * version 1.0 save/load format.
 */
@Experimental
object GaussianMixtureModel extends Loader[GaussianMixtureModel] {

  /** Reader and writer for format version 1.0: JSON metadata plus one Parquet row per component. */
  private object SaveLoadV1_0 {

    /** One saved row: the weight, mean and covariance of a single mixture component. */
    case class Data(weight: Double, mu: Vector, sigma: Matrix)

    val formatVersionV1_0 = "1.0"

    val classNameV1_0 = "org.apache.spark.mllib.clustering.GaussianMixtureModel"

    /**
     * Writes the metadata (class name, format version, number of components) as JSON text and
     * the per-component parameters as a Parquet file.
     *
     * @param sc        Spark context used to write both outputs
     * @param path      base path of the saved model
     * @param weights   mixture weights, one per component
     * @param gaussians component distributions, index-aligned with `weights`
     */
    def save(
        sc: SparkContext,
        path: String,
        weights: Array[Double],
        gaussians: Array[MultivariateGaussian]): Unit = {

      val sqlContext = new SQLContext(sc)
      import sqlContext.implicits._

      // Create JSON metadata.
      val metadata = compact(render(
        ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~ ("k" -> weights.length)))
      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))

      // Create Parquet data.
      val dataArray = Array.tabulate(weights.length) { i =>
        Data(weights(i), gaussians(i).mu, gaussians(i).sigma)
      }
      sc.parallelize(dataArray, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
    }

    /**
     * Reads the Parquet data written by `save` and rebuilds the model from it.
     *
     * @param sc   Spark context used to read the data
     * @param path base path of the saved model
     * @return the model reconstructed from the stored weights, means and covariances
     */
    def load(sc: SparkContext, path: String): GaussianMixtureModel = {
      val dataPath = Loader.dataPath(path)
      val sqlContext = new SQLContext(sc)
      val dataFrame = sqlContext.parquetFile(dataPath)

      // Check schema explicitly since erasure makes it hard to use match-case for checking.
      // Done before selecting/collecting so a schema mismatch is reported by the check itself
      // rather than by a failing query.
      Loader.checkSchema[Data](dataFrame.schema)
      val dataArray = dataFrame.select("weight", "mu", "sigma").collect()

      val (weights, gaussians) = dataArray.map {
        case Row(weight: Double, mu: Vector, sigma: Matrix) =>
          (weight, new MultivariateGaussian(mu, sigma))
      }.unzip

      new GaussianMixtureModel(weights.toArray, gaussians.toArray)
    }
  }

  /**
   * Loads a model saved under `path`, dispatching on the class name and format version found in
   * the metadata, and verifies that the loaded model has the number of components recorded there.
   *
   * @param sc   Spark context used to read the model
   * @param path base path of the saved model
   * @return the loaded model
   * @throws Exception if the stored (class name, version) pair is not a supported combination
   */
  override def load(sc: SparkContext, path: String): GaussianMixtureModel = {
    val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
    implicit val formats = DefaultFormats
    val classNameV1_0 = SaveLoadV1_0.classNameV1_0
    (loadedClassName, version) match {
      // Backticks make this a stable-identifier pattern that compares against the expected class
      // name; a bare lowercase identifier would bind a fresh variable and match any class name.
      case (`classNameV1_0`, "1.0") =>
        // Read "k" only for a recognised format, so unrecognised models reach the message below.
        val k = (metadata \ "k").extract[Int]
        val model = SaveLoadV1_0.load(sc, path)
        require(model.weights.length == k,
          s"GaussianMixtureModel requires weights of length $k " +
          s"got weights of length ${model.weights.length}")
        require(model.gaussians.length == k,
          s"GaussianMixtureModel requires gaussians of length $k " +
          s"got gaussians of length ${model.gaussians.length}")
        model
      case _ => throw new Exception(
        s"GaussianMixtureModel.load did not recognize model with (className, format version): " +
        s"($loadedClassName, $version). Supported:\n" +
        s" ($classNameV1_0, 1.0)")
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.mllib.linalg.{Vectors, Matrices}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils

class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
test("single cluster") {
Expand All @@ -48,13 +49,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}

test("two clusters") {
val data = sc.parallelize(Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
val data = sc.parallelize(GaussianTestData.data)

// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Expand Down Expand Up @@ -105,14 +100,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
}

test("two clusters with sparse data") {
val data = sc.parallelize(Array(
Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))

val data = sc.parallelize(GaussianTestData.data)
val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Expand All @@ -138,4 +126,38 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}

test("model save / load") {
val data = sc.parallelize(GaussianTestData.data)

val gmm = new GaussianMixture().setK(2).setSeed(0).run(data)
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString

try {
gmm.save(sc, path)

// TODO: GaussianMixtureModel should implement equals/hashcode directly.
val sameModel = GaussianMixtureModel.load(sc, path)
assert(sameModel.k === gmm.k)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please leave a TODO here for GMM's equals/hashCode.

(0 until sameModel.k).foreach { i =>
assert(sameModel.gaussians(i).mu === gmm.gaussians(i).mu)
assert(sameModel.gaussians(i).sigma === gmm.gaussians(i).sigma)
}
} finally {
Utils.deleteRecursively(tempDir)
}
}

/** One-dimensional sample points used as input data by the tests in this suite. */
object GaussianTestData {

  // Raw coordinates; each value becomes a single-element dense vector.
  val data = Array(
    -5.1971, -2.5359, -3.8220,
    -5.2211, -5.0602, 4.7118,
    6.8989, 3.4592, 4.6322,
    5.7048, 4.6567, 5.5026,
    4.5605, 5.2043, 6.2734
  ).map(x => Vectors.dense(x))

}
}