Skip to content

Commit ba52582

Browse files
committed
Add another overload, fix train args style, remove internal call to deprecated method
1 parent 84fb22f commit ba52582

File tree

1 file changed

+30
-6
lines changed
  • mllib/src/main/scala/org/apache/spark/mllib/clustering

1 file changed

+30
-6
lines changed

mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala

Lines changed: 30 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -420,18 +420,40 @@ object KMeans {
420420
* on system time.
421421
*/
422422
@Since("2.1.0")
423-
def train(data: RDD[Vector],
424-
k: Int,
425-
maxIterations: Int,
426-
initializationMode: String,
427-
seed: Long): KMeansModel = {
423+
def train(
424+
data: RDD[Vector],
425+
k: Int,
426+
maxIterations: Int,
427+
initializationMode: String,
428+
seed: Long): KMeansModel = {
428429
new KMeans().setK(k)
429430
.setMaxIterations(maxIterations)
430431
.setInitializationMode(initializationMode)
431432
.setSeed(seed)
432433
.run(data)
433434
}
434435

436+
/**
437+
* Trains a k-means model using the given set of parameters.
438+
*
439+
* @param data Training points as an `RDD` of `Vector` types.
440+
* @param k Number of clusters to create.
441+
* @param maxIterations Maximum number of iterations allowed.
442+
* @param initializationMode The initialization algorithm. This can either be "random" or
443+
* "k-means||". (default: "k-means||")
444+
*/
445+
@Since("2.1.0")
446+
def train(
447+
data: RDD[Vector],
448+
k: Int,
449+
maxIterations: Int,
450+
initializationMode: String): KMeansModel = {
451+
new KMeans().setK(k)
452+
.setMaxIterations(maxIterations)
453+
.setInitializationMode(initializationMode)
454+
.run(data)
455+
}
456+
435457
/**
436458
* Trains a k-means model using the given set of parameters.
437459
*
@@ -492,7 +514,9 @@ object KMeans {
492514
data: RDD[Vector],
493515
k: Int,
494516
maxIterations: Int): KMeansModel = {
495-
train(data, k, maxIterations, 1, K_MEANS_PARALLEL)
517+
new KMeans().setK(k)
518+
.setMaxIterations(maxIterations)
519+
.run(data)
496520
}
497521

498522
/**

0 commit comments

Comments
 (0)