Skip to content
Closed
82 changes: 82 additions & 0 deletions docs/ml-clustering.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,85 @@ Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.
{% include_example python/ml/bisecting_k_means_example.py %}
</div>
</div>

## Gaussian Mixture Model (GMM)

A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model)
represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions,
each with its own probability. The `spark.ml` implementation uses the
[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
algorithm to induce the maximum-likelihood model given a set of samples.

`GaussianMixture` is implemented as an `Estimator` and generates a `GaussianMixtureModel` as the base
model.

### Input Columns

<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>

### Output Columns

<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Int</td>
<td>"prediction"</td>
<td>Predicted cluster center</td>
</tr>
<tr>
<td>probabilityCol</td>
<td>Vector</td>
<td>"probability"</td>
<td>Probability of each cluster</td>
</tr>
</tbody>
</table>

### Example

<div class="codetabs">

<div data-lang="scala" markdown="1">
Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.GaussianMixture) for more details.

{% include_example scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala %}
</div>

<div data-lang="java" markdown="1">
Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/GaussianMixture.html) for more details.

{% include_example java/org/apache/spark/examples/ml/JavaGaussianMixtureExample.java %}
</div>

<div data-lang="python" markdown="1">
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.GaussianMixture) for more details.

{% include_example python/ml/gaussian_mixture_example.py %}
</div>
</div>
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.ml.clustering.GaussianMixture;
import org.apache.spark.ml.clustering.GaussianMixtureModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SparkSession;


/**
* An example demonstrating Gaussian Mixture Model.
* Run with
* <pre>
* bin/run-example ml.JavaGaussianMixtureExample
* </pre>
*/
public class JavaGaussianMixtureExample {

public static void main(String[] args) {

// Creates a SparkSession
SparkSession spark = SparkSession
.builder()
.appName("JavaGaussianMixtureExample")
.getOrCreate();

// $example on$
// Loads data
Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

// Trains a GaussianMixture model
GaussianMixture gmm = new GaussianMixture()
.setK(2);
GaussianMixtureModel model = gmm.fit(dataset);

// Output the parameters of the mixture model
for (int i = 0; i < model.getK(); i++) {
System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
model.weights()[i], model.gaussians()[i].mean(), model.gaussians()[i].cov());
}
// $example off$

spark.stop();
}
}
48 changes: 48 additions & 0 deletions examples/src/main/python/ml/gaussian_mixture_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from __future__ import print_function

# $example on$
from pyspark.ml.clustering import GaussianMixture
# $example off$
from pyspark.sql import SparkSession

"""
A simple example demonstrating Gaussian Mixture Model (GMM).
Run with:
bin/spark-submit examples/src/main/python/ml/gaussian_mixture_example.py
"""

if __name__ == "__main__":
spark = SparkSession\
.builder\
.appName("PythonGuassianMixtureExample")\
.getOrCreate()

# $example on$
# loads data
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

gmm = GaussianMixture().setK(2)
model = gmm.fit(dataset)

print("Gaussians: ")
model.gaussiansDF.show()
# $example off$

spark.stop()
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.examples.ml

// scalastyle:off println

// $example on$
import org.apache.spark.ml.clustering.GaussianMixture
import org.apache.spark.sql.SparkSession
// $example off$

/**
* An example demonstrating Gaussian Mixture Model (GMM).
* Run with
* {{{
* bin/run-example ml.GaussianMixtureExample
* }}}
*/
object GaussianMixtureExample {
def main(args: Array[String]): Unit = {
// Creates a SparkSession
val spark = SparkSession.builder.appName(s"${this.getClass.getSimpleName}").getOrCreate()

// $example on$
// Loads data
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains Gaussian Mixture Model
val gmm = new GaussianMixture()
.setK(2)
val model = gmm.fit(dataset)

// output parameters of mixture model model
for (i <- 0 until model.getK) {
println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(model.weights(i), model.gaussians(i).mean, model.gaussians(i).cov))
}
// $example off$

spark.stop()
}
}
// scalastyle:on println