5 changes: 5 additions & 0 deletions docs/ml-clustering.md
@@ -79,6 +79,11 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html
{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
</div>

<div data-lang="python" markdown="1">
Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details.

{% include_example python/ml/kmeans_example.py %}
</div>
</div>


examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -17,77 +17,45 @@

package org.apache.spark.examples.ml;

import java.util.regex.Pattern;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.GenericRow;
// $example on$
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
// $example off$
import org.apache.spark.sql.SparkSession;


/**
* An example demonstrating a k-means clustering.
* An example demonstrating k-means clustering.
* Run with
* <pre>
* bin/run-example ml.JavaKMeansExample <file> <k>
* bin/run-example ml.JavaKMeansExample
* </pre>
*/
public class JavaKMeansExample {

private static class ParsePoint implements Function<String, Row> {
private static final Pattern separator = Pattern.compile(" ");

@Override
public Row call(String line) {
String[] tok = separator.split(line);
double[] point = new double[tok.length];
for (int i = 0; i < tok.length; ++i) {
point[i] = Double.parseDouble(tok[i]);
}
Vector[] points = {Vectors.dense(point)};
return new GenericRow(points);
}
}

public static void main(String[] args) {
if (args.length != 2) {
System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
System.exit(1);
}
String inputFile = args[0];
int k = Integer.parseInt(args[1]);

// Parses the arguments
// Create a SparkSession.
SparkSession spark = SparkSession
.builder()
.appName("JavaKMeansExample")
.getOrCreate();

// $example on$
// Loads data
JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
StructType schema = new StructType(fields);
Dataset<Row> dataset = spark.createDataFrame(points, schema);
// Loads data.
Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

// Trains a k-means model
KMeans kmeans = new KMeans()
.setK(k);
// Trains a k-means model.
KMeans kmeans = new KMeans().setK(2).setSeed(1L);
KMeansModel model = kmeans.fit(dataset);

// Shows the result
// Evaluate clustering by computing Within Set Sum of Squared Errors.
double WSSSE = model.computeCost(dataset);
System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

// Shows the result.
Vector[] centers = model.clusterCenters();
System.out.println("Cluster Centers: ");
for (Vector center: centers) {
46 changes: 18 additions & 28 deletions examples/src/main/python/ml/kmeans_example.py
@@ -17,55 +17,45 @@

from __future__ import print_function

import sys
# $example on$
from pyspark.ml.clustering import KMeans
# $example off$

import numpy as np
from pyspark.ml.clustering import KMeans, KMeansModel
from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
from pyspark.sql import SparkSession
from pyspark.sql.types import Row, StructField, StructType

"""
A simple example demonstrating a k-means clustering.
An example demonstrating k-means clustering.
Run with:
bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
bin/spark-submit examples/src/main/python/ml/kmeans_example.py

This example requires NumPy (http://www.numpy.org/).
"""

Review comment by holdenk (Contributor), May 9, 2016:

nit: So I believe this example still requires NumPy even though it isn't explicitly imported (see inside of def toArray, called inside of clusterCenters, which says it returns a NumPy array).

Reply from the PR author:

Thanks. I will revert this removal.
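
For reference, the dependency is easy to confirm: clusterCenters() converts each center via the Vector's toArray(), which returns a NumPy array, so NumPy must be installed at runtime even with no explicit import in the example. A minimal check, assuming a fitted KMeansModel named model as in the example below:

    # Assumes `model` is a fitted pyspark.ml.clustering.KMeansModel.
    centers = model.clusterCenters()
    print(type(centers[0]))  # <class 'numpy.ndarray'>, hence the NumPy requirement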


def parseVector(row):
array = np.array([float(x) for x in row.value.split(' ')])
return _convert_to_vector(array)


if __name__ == "__main__":

FEATURES_COL = "features"

if len(sys.argv) != 3:
print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
exit(-1)
path = sys.argv[1]
k = sys.argv[2]

spark = SparkSession\
.builder\
.appName("PythonKMeansExample")\
.getOrCreate()

lines = spark.read.text(path).rdd
data = lines.map(parseVector)
row_rdd = data.map(lambda x: Row(x))
schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
df = spark.createDataFrame(row_rdd, schema)
# $example on$
# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
model = kmeans.fit(df)
centers = model.clusterCenters()
# Trains a k-means model.
kmeans = KMeans().setK(2).setSeed(1)
model = kmeans.fit(dataset)

# Evaluate clustering by computing Within Set Sum of Squared Errors.
wssse = model.computeCost(dataset)
print("Within Set Sum of Squared Errors = " + str(wssse))

# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
print(center)
# $example off$

spark.stop()
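
A note on the new evaluation step: computeCost returns the Within Set Sum of Squared Errors, i.e. the sum over all input points of the squared Euclidean distance to each point's assigned cluster center. A rough sketch of the equivalent computation, assuming the default "features" and "prediction" columns and a fitted model as above (the helper name is ours, not part of Spark's API):

    import numpy as np

    def wssse_by_hand(model, dataset):
        # Recompute Within Set Sum of Squared Errors from the model's own
        # cluster assignments; a sanity check, not Spark's implementation.
        centers = model.clusterCenters()  # list of NumPy arrays
        total = 0.0
        rows = model.transform(dataset).select("features", "prediction").collect()
        for row in rows:
            diff = row.features.toArray() - centers[row.prediction]
            total += float(np.dot(diff, diff))  # squared Euclidean distance
        return total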
examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
@@ -21,12 +21,11 @@ package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
// $example off$
import org.apache.spark.sql.SparkSession

/**
* An example demonstrating a k-means clustering.
* An example demonstrating k-means clustering.
* Run with
* {{{
* bin/run-example ml.KMeansExample
@@ -35,32 +34,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
object KMeansExample {

def main(args: Array[String]): Unit = {
// Creates a Spark context and a SQL context
// Creates a SparkSession.
val spark = SparkSession
.builder
.appName(s"${this.getClass.getSimpleName}")
.getOrCreate()

// $example on$
// Crates a DataFrame
val dataset: DataFrame = spark.createDataFrame(Seq(
(1, Vectors.dense(0.0, 0.0, 0.0)),
(2, Vectors.dense(0.1, 0.1, 0.1)),
(3, Vectors.dense(0.2, 0.2, 0.2)),
(4, Vectors.dense(9.0, 9.0, 9.0)),
(5, Vectors.dense(9.1, 9.1, 9.1)),
(6, Vectors.dense(9.2, 9.2, 9.2))
)).toDF("id", "features")
// Loads data.
val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

// Trains a k-means model
val kmeans = new KMeans()
.setK(2)
.setFeaturesCol("features")
.setPredictionCol("prediction")
// Trains a k-means model.
val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)

// Shows the result
println("Final Centers: ")
// Evaluate clustering by computing Within Set Sum of Squared Errors.
val WSSSE = model.computeCost(dataset)
println(s"Within Set Sum of Squared Errors = $WSSSE")

// Shows the result.
println("Cluster Centers: ")
model.clusterCenters.foreach(println)
// $example off$
