
Commit 8beae59

zhengruifeng authored and Nick Pentreath committed

[SPARK-15149][EXAMPLE][DOC] update kmeans example

## What changes were proposed in this pull request?

A Python example for ml.kmeans already exists, but it is not included in the user guide.

1. Small changes, such as the `example_on` / `example_off` markers
2. Add the example to the user guide
3. Update the examples to read the data file directly

## How was this patch tested?

Manual tests:

`./bin/spark-submit examples/src/main/python/ml/kmeans_example.py`

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12925 from zhengruifeng/km_pe.
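
A note on point 3: every example now reads `data/mllib/sample_kmeans_data.txt` through the `libsvm` data source instead of parsing a whitespace-separated text file by hand. A minimal sketch of what that load produces (the `label`/`features` schema is the standard output of the `libsvm` source; the app name here is made up):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("KMeansDataCheck").getOrCreate()

# The libsvm source yields a two-column DataFrame -- a double "label" and a
# vector "features" -- so no custom parser or explicit schema is required.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
dataset.printSchema()  # expect a double "label" column and a vector "features" column

spark.stop()
```

This is why the `ParsePoint` helper (Java) and `parseVector` (Python) in the diffs below could be deleted outright.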
1 parent cef73b5 · commit 8beae59

4 files changed: +50 −94 lines


docs/ml-clustering.md

Lines changed: 5 additions & 0 deletions

@@ -79,6 +79,11 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html
 {% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
 </div>

+<div data-lang="python" markdown="1">
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details.
+
+{% include_example python/ml/kmeans_example.py %}
+</div>
 </div>


examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java

Lines changed: 14 additions & 46 deletions

@@ -17,77 +17,45 @@

 package org.apache.spark.examples.ml;

-import java.util.regex.Pattern;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
 // $example on$
 import org.apache.spark.ml.clustering.KMeansModel;
 import org.apache.spark.ml.clustering.KMeans;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
 // $example off$
+import org.apache.spark.sql.SparkSession;


 /**
- * An example demonstrating a k-means clustering.
+ * An example demonstrating k-means clustering.
  * Run with
  * <pre>
- * bin/run-example ml.JavaKMeansExample <file> <k>
+ * bin/run-example ml.JavaKMeansExample
  * </pre>
  */
 public class JavaKMeansExample {

-  private static class ParsePoint implements Function<String, Row> {
-    private static final Pattern separator = Pattern.compile(" ");
-
-    @Override
-    public Row call(String line) {
-      String[] tok = separator.split(line);
-      double[] point = new double[tok.length];
-      for (int i = 0; i < tok.length; ++i) {
-        point[i] = Double.parseDouble(tok[i]);
-      }
-      Vector[] points = {Vectors.dense(point)};
-      return new GenericRow(points);
-    }
-  }
-
   public static void main(String[] args) {
-    if (args.length != 2) {
-      System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
-      System.exit(1);
-    }
-    String inputFile = args[0];
-    int k = Integer.parseInt(args[1]);
-
-    // Parses the arguments
+    // Create a SparkSession.
     SparkSession spark = SparkSession
       .builder()
       .appName("JavaKMeansExample")
       .getOrCreate();

     // $example on$
-    // Loads data
-    JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
-    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
-    StructType schema = new StructType(fields);
-    Dataset<Row> dataset = spark.createDataFrame(points, schema);
+    // Loads data.
+    Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

-    // Trains a k-means model
-    KMeans kmeans = new KMeans()
-      .setK(k);
+    // Trains a k-means model.
+    KMeans kmeans = new KMeans().setK(2).setSeed(1L);
     KMeansModel model = kmeans.fit(dataset);

-    // Shows the result
+    // Evaluate clustering by computing Within Set Sum of Squared Errors.
+    double WSSSE = model.computeCost(dataset);
+    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+    // Shows the result.
     Vector[] centers = model.clusterCenters();
     System.out.println("Cluster Centers: ");
     for (Vector center: centers) {
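
Each of the three examples gains the same evaluation step: `computeCost` returns the Within Set Sum of Squared Errors, i.e. the total squared distance from each point to its nearest cluster center. In symbols (a standard textbook definition, not something spelled out in the patch itself):

$$\mathrm{WSSSE} = \sum_{i=1}^{n} \min_{1 \le c \le k} \lVert x_i - \mu_c \rVert^2$$

where $x_1, \dots, x_n$ are the feature vectors and $\mu_1, \dots, \mu_k$ the learned centers; a lower value means tighter clusters.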

examples/src/main/python/ml/kmeans_example.py

Lines changed: 18 additions & 28 deletions

@@ -17,55 +17,45 @@

 from __future__ import print_function

-import sys
+# $example on$
+from pyspark.ml.clustering import KMeans
+# $example off$

-import numpy as np
-from pyspark.ml.clustering import KMeans, KMeansModel
-from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
 from pyspark.sql import SparkSession
-from pyspark.sql.types import Row, StructField, StructType

 """
-A simple example demonstrating a k-means clustering.
+An example demonstrating k-means clustering.
 Run with:
-  bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
+  bin/spark-submit examples/src/main/python/ml/kmeans_example.py

 This example requires NumPy (http://www.numpy.org/).
 """


-def parseVector(row):
-    array = np.array([float(x) for x in row.value.split(' ')])
-    return _convert_to_vector(array)
-
-
 if __name__ == "__main__":

-    FEATURES_COL = "features"
-
-    if len(sys.argv) != 3:
-        print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
-        exit(-1)
-    path = sys.argv[1]
-    k = sys.argv[2]
-
     spark = SparkSession\
         .builder\
         .appName("PythonKMeansExample")\
         .getOrCreate()

-    lines = spark.read.text(path).rdd
-    data = lines.map(parseVector)
-    row_rdd = data.map(lambda x: Row(x))
-    schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
-    df = spark.createDataFrame(row_rdd, schema)
+    # $example on$
+    # Loads data.
+    dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

-    kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
-    model = kmeans.fit(df)
-    centers = model.clusterCenters()
+    # Trains a k-means model.
+    kmeans = KMeans().setK(2).setSeed(1)
+    model = kmeans.fit(dataset)
+
+    # Evaluate clustering by computing Within Set Sum of Squared Errors.
+    wssse = model.computeCost(dataset)
+    print("Within Set Sum of Squared Errors = " + str(wssse))

+    # Shows the result.
+    centers = model.clusterCenters()
     print("Cluster Centers: ")
     for center in centers:
         print(center)
+    # $example off$

     spark.stop()
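
The committed example stops at printing the centers; a natural follow-up, not part of the patch, is assigning each row to a cluster with `transform`, which appends the model's default `prediction` column:

```python
from pyspark.ml.clustering import KMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("PythonKMeansPredict").getOrCreate()

# Same data and parameters as the committed example above.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
model = KMeans().setK(2).setSeed(1).fit(dataset)

# transform() adds an integer "prediction" column holding the cluster index.
model.transform(dataset).select("features", "prediction").show(truncate=False)

spark.stop()
```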

examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala

Lines changed: 13 additions & 20 deletions

@@ -21,12 +21,11 @@ package org.apache.spark.examples.ml

 // $example on$
 import org.apache.spark.ml.clustering.KMeans
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.{DataFrame, SparkSession}
 // $example off$
+import org.apache.spark.sql.SparkSession

 /**
- * An example demonstrating a k-means clustering.
+ * An example demonstrating k-means clustering.
  * Run with
  * {{{
  *   bin/run-example ml.KMeansExample

@@ -35,32 +34,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 object KMeansExample {

   def main(args: Array[String]): Unit = {
-    // Creates a Spark context and a SQL context
+    // Creates a SparkSession.
     val spark = SparkSession
       .builder
       .appName(s"${this.getClass.getSimpleName}")
       .getOrCreate()

     // $example on$
-    // Crates a DataFrame
-    val dataset: DataFrame = spark.createDataFrame(Seq(
-      (1, Vectors.dense(0.0, 0.0, 0.0)),
-      (2, Vectors.dense(0.1, 0.1, 0.1)),
-      (3, Vectors.dense(0.2, 0.2, 0.2)),
-      (4, Vectors.dense(9.0, 9.0, 9.0)),
-      (5, Vectors.dense(9.1, 9.1, 9.1)),
-      (6, Vectors.dense(9.2, 9.2, 9.2))
-    )).toDF("id", "features")
+    // Loads data.
+    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

-    // Trains a k-means model
-    val kmeans = new KMeans()
-      .setK(2)
-      .setFeaturesCol("features")
-      .setPredictionCol("prediction")
+    // Trains a k-means model.
+    val kmeans = new KMeans().setK(2).setSeed(1L)
     val model = kmeans.fit(dataset)

-    // Shows the result
-    println("Final Centers: ")
+    // Evaluate clustering by computing Within Set Sum of Squared Errors.
+    val WSSSE = model.computeCost(dataset)
+    println(s"Within Set Sum of Squared Errors = $WSSSE")
+
+    // Shows the result.
+    println("Cluster Centers: ")
     model.clusterCenters.foreach(println)
     // $example off$

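
All three examples now also fix the seed (`setSeed(1L)` in Java/Scala, `setSeed(1)` in Python), so repeated runs start from the same initialization. A quick sanity check, sketched in Python under the same assumptions as above (exact equality of the two costs assumes fully deterministic execution, which should hold for two fits in the same session):

```python
from pyspark.ml.clustering import KMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("KMeansSeedCheck").getOrCreate()
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# With a fixed seed the k-means|| initialization is reproducible, so two fits
# over the same data should report the same cost.
m1 = KMeans(k=2, seed=1).fit(dataset)
m2 = KMeans(k=2, seed=1).fit(dataset)
print(m1.computeCost(dataset) == m2.computeCost(dataset))  # expected: True

spark.stop()
```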