Commit d91cbe9

use datafile

1 parent 6e268b9

4 files changed: +46 -91 lines

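This commit switches the k-means examples (Scala, Java, and Python) from ad-hoc argument parsing and hand-built input to a shared sample data file read through Spark's "libsvm" data source, and adds the Python example to the clustering docs. In libsvm format each row is a label followed by sparse index:value feature pairs. A plausible sketch of data/mllib/sample_kmeans_data.txt, assuming it encodes the six points the old Scala example built inline (the labels here are illustrative; the actual file may differ):

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2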

docs/ml-clustering.md

Lines changed: 5 additions & 0 deletions

@@ -79,6 +79,11 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html
 {% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
 </div>
 
+<div data-lang="python" markdown="1">
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details.
+
+{% include_example python/ml/kmeans_example.py %}
+</div>
 </div>
 

examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java

Lines changed: 13 additions & 45 deletions

@@ -17,77 +17,45 @@
 
 package org.apache.spark.examples.ml;
 
-import java.util.regex.Pattern;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
 // $example on$
 import org.apache.spark.ml.clustering.KMeansModel;
 import org.apache.spark.ml.clustering.KMeans;
 import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
 // $example off$
+import org.apache.spark.sql.SparkSession;
 
 
 /**
  * An example demonstrating a k-means clustering.
  * Run with
  * <pre>
- * bin/run-example ml.JavaKMeansExample <file> <k>
+ * bin/run-example ml.JavaKMeansExample
  * </pre>
  */
 public class JavaKMeansExample {
 
-  private static class ParsePoint implements Function<String, Row> {
-    private static final Pattern separator = Pattern.compile(" ");
-
-    @Override
-    public Row call(String line) {
-      String[] tok = separator.split(line);
-      double[] point = new double[tok.length];
-      for (int i = 0; i < tok.length; ++i) {
-        point[i] = Double.parseDouble(tok[i]);
-      }
-      Vector[] points = {Vectors.dense(point)};
-      return new GenericRow(points);
-    }
-  }
-
   public static void main(String[] args) {
-    if (args.length != 2) {
-      System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
-      System.exit(1);
-    }
-    String inputFile = args[0];
-    int k = Integer.parseInt(args[1]);
-
-    // Parses the arguments
+    // Create a SparkSession.
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaKMeansExample")
      .getOrCreate();
 
    // $example on$
-    // Loads data
-    JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
-    StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
-    StructType schema = new StructType(fields);
-    Dataset<Row> dataset = spark.createDataFrame(points, schema);
+    // Loads data.
+    Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
 
-    // Trains a k-means model
-    KMeans kmeans = new KMeans()
-      .setK(k);
+    // Trains a k-means model.
+    KMeans kmeans = new KMeans().setK(2).setSeed(1L);
    KMeansModel model = kmeans.fit(dataset);
 
-    // Shows the result
+    // Evaluate clustering by computing Within Set Sum of Squared Errors.
+    double WSSSE = model.computeCost(dataset);
+    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+    // Shows the result.
    Vector[] centers = model.clusterCenters();
    System.out.println("Cluster Centers: ");
    for (Vector center: centers) {
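All three updated examples report the Within Set Sum of Squared Errors (WSSSE) before printing the centers. For learned centers c_1, …, c_k this is the k-means cost WSSSE = Σ_x min_i ‖x − c_i‖², the sum over all points of the squared distance to the nearest center, which is what KMeansModel.computeCost returns.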

examples/src/main/python/ml/kmeans_example.py

Lines changed: 17 additions & 29 deletions

@@ -17,55 +17,43 @@
 
 from __future__ import print_function
 
-import sys
-
-import numpy as np
+# $example on$
 from pyspark.ml.clustering import KMeans, KMeansModel
-from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
+# $example off$
+
 from pyspark.sql import SparkSession
-from pyspark.sql.types import Row, StructField, StructType
 
 """
 A simple example demonstrating a k-means clustering.
 Run with:
-  bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
-
-This example requires NumPy (http://www.numpy.org/).
+  bin/spark-submit examples/src/main/python/ml/kmeans_example.py
 """
 
 
-def parseVector(row):
-    array = np.array([float(x) for x in row.value.split(' ')])
-    return _convert_to_vector(array)
-
-
 if __name__ == "__main__":
 
-    FEATURES_COL = "features"
-
-    if len(sys.argv) != 3:
-        print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
-        exit(-1)
-    path = sys.argv[1]
-    k = sys.argv[2]
-
    spark = SparkSession\
        .builder\
        .appName("PythonKMeansExample")\
        .getOrCreate()
 
-    lines = spark.read.text(path).rdd
-    data = lines.map(parseVector)
-    row_rdd = data.map(lambda x: Row(x))
-    schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
-    df = spark.createDataFrame(row_rdd, schema)
+    # $example on$
+    # Loads data.
+    dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
 
-    kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
-    model = kmeans.fit(df)
-    centers = model.clusterCenters()
+    # Trains a k-means model.
+    kmeans = KMeans().setK(2).setSeed(1)
+    model = kmeans.fit(dataset)
 
+    # Evaluate clustering by computing Within Set Sum of Squared Errors.
+    wssse = model.computeCost(dataset)
+    print("Within Set Sum of Squared Errors = " + str(wssse))
+
+    # Shows the result.
+    centers = model.clusterCenters()
    print("Cluster Centers: ")
    for center in centers:
        print(center)
+    # $example off$
 
    spark.stop()
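The updated examples stop at the cost and the cluster centers; none of them shows per-row cluster assignments. A minimal sketch of inspecting those with the same data file, assuming the Spark 2.x pyspark.ml API used above:

from pyspark.ml.clustering import KMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("KMeansAssignments").getOrCreate()

# Same libsvm-formatted sample data the updated examples load.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Fit with the same parameters as the updated examples.
model = KMeans().setK(2).setSeed(1).fit(dataset)

# transform() appends a "prediction" column holding each row's cluster index.
model.transform(dataset).select("features", "prediction").show(truncate=False)

spark.stop()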

examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala

Lines changed: 11 additions & 17 deletions

@@ -35,32 +35,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 object KMeansExample {
 
  def main(args: Array[String]): Unit = {
-    // Creates a Spark context and a SQL context
+    // Creates a SparkSession.
    val spark = SparkSession
      .builder
      .appName(s"${this.getClass.getSimpleName}")
      .getOrCreate()
 
    // $example on$
-    // Crates a DataFrame
-    val dataset: DataFrame = spark.createDataFrame(Seq(
-      (1, Vectors.dense(0.0, 0.0, 0.0)),
-      (2, Vectors.dense(0.1, 0.1, 0.1)),
-      (3, Vectors.dense(0.2, 0.2, 0.2)),
-      (4, Vectors.dense(9.0, 9.0, 9.0)),
-      (5, Vectors.dense(9.1, 9.1, 9.1)),
-      (6, Vectors.dense(9.2, 9.2, 9.2))
-    )).toDF("id", "features")
+    // Loads data.
+    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
 
-    // Trains a k-means model
-    val kmeans = new KMeans()
-      .setK(2)
-      .setFeaturesCol("features")
-      .setPredictionCol("prediction")
+    // Trains a k-means model.
+    val kmeans = new KMeans().setK(2).setSeed(1L)
    val model = kmeans.fit(dataset)
 
-    // Shows the result
-    println("Final Centers: ")
+    // Evaluate clustering by computing Within Set Sum of Squared Errors.
+    val WSSSE = model.computeCost(dataset)
+    println(s"Within Set Sum of Squared Errors = ${WSSSE}")
+
+    // Shows the result.
+    println("Cluster Centers: ")
    model.clusterCenters.foreach(println)
    // $example off$
 
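A quick way to exercise the three updated examples from a Spark checkout. The Java and Python commands come from the usage strings in the diffs above; the Scala invocation is the analogous run-example call and is assumed here:

bin/run-example ml.KMeansExample
bin/run-example ml.JavaKMeansExample
bin/spark-submit examples/src/main/python/ml/kmeans_example.py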
