mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -26,7 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, VectorUDT}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
+import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
+  BisectingKMeansModel => MLlibBisectingKMeansModel}
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.rdd.RDD
@@ -38,8 +39,8 @@ import org.apache.spark.sql.types.{IntegerType, StructType}
 /**
  * Common params for BisectingKMeans and BisectingKMeansModel
  */
-private[clustering] trait BisectingKMeansParams extends Params
-  with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
+  with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure {
 
   /**
    * The desired number of leaf clusters. Must be > 1. Default: 4.
@@ -104,6 +105,10 @@ class BisectingKMeansModel private[ml] (
   @Since("2.1.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
@@ -248,6 +253,10 @@ class BisectingKMeans @Since("2.0.0") (
   @Since("2.0.0")
   def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
 
+  /** @group expertSetParam */
+  @Since("2.4.0")
+  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
     transformSchema(dataset.schema, logging = true)
@@ -263,6 +272,7 @@
       .setMaxIterations($(maxIter))
       .setMinDivisibleClusterSize($(minDivisibleClusterSize))
       .setSeed($(seed))
+      .setDistanceMeasure($(distanceMeasure))
     val parentModel = bkm.run(rdd)
     val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
     val summary = new BisectingKMeansSummary(
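For reviewers who want to see the new expert param in context, here is a minimal usage sketch against the ML API; the SparkSession `spark` and the toy data are assumptions for illustration, not part of this change:

```scala
import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors

// Toy input: a DataFrame with a Vector column named "features"
// ("spark" is an assumed SparkSession).
val dataset = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(1.0, 0.1)),
  Tuple1(Vectors.dense(1.0, 0.2)),
  Tuple1(Vectors.dense(0.1, 1.0)),
  Tuple1(Vectors.dense(0.2, 1.0))
)).toDF("features")

val bkm = new BisectingKMeans()
  .setK(2)
  .setMaxIter(20)
  .setDistanceMeasure("cosine")  // new expert param; "euclidean" remains the default
val model = bkm.fit(dataset)
model.transform(dataset).show()
```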
11 changes: 1 addition & 10 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -40,7 +40,7 @@ import org.apache.spark.util.VersionUtils.majorVersion
  * Common params for KMeans and KMeansModel
  */
 private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFeaturesCol
-  with HasSeed with HasPredictionCol with HasTol {
+  with HasSeed with HasPredictionCol with HasTol with HasDistanceMeasure {
 
   /**
    * The number of clusters to create (k). Must be > 1. Note that it is possible for fewer than
@@ -71,15 +71,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
   @Since("1.5.0")
   def getInitMode: String = $(initMode)
 
-  @Since("2.4.0")
-  final val distanceMeasure = new Param[String](this, "distanceMeasure", "The distance measure. " +
-    "Supported options: 'euclidean' and 'cosine'.",
-    (value: String) => MLlibKMeans.validateDistanceMeasure(value))
-
-  /** @group expertGetParam */
-  @Since("2.4.0")
-  def getDistanceMeasure: String = $(distanceMeasure)
-
   /**
    * Param for the number of steps for the k-means|| initialization mode. This is an advanced
    * setting -- the default of 2 is almost always enough. Must be > 0. Default: 2.
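Because KMeansParams now mixes in the shared HasDistanceMeasure trait instead of declaring its own param, the public accessors on KMeans are unchanged. A hedged sketch of that behavior (values are illustrative):

```scala
import org.apache.spark.ml.clustering.KMeans

val kmeans = new KMeans()
  .setK(3)
  .setDistanceMeasure("cosine")                  // validated against 'euclidean' / 'cosine'
println(kmeans.getDistanceMeasure)               // "cosine"
println(kmeans.explainParam(kmeans.distanceMeasure))
// kmeans.setDistanceMeasure("manhattan") would fail the Param's validation check
```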
mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -91,7 +91,11 @@ private[shared] object SharedParamsCodeGen {
         "after fitting. If set to true, then all sub-models will be available. Warning: For " +
         "large models, collecting all sub-models can cause OOMs on the Spark driver",
         Some("false"), isExpertParam = true),
-      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false)
+      ParamDesc[String]("loss", "the loss function to be optimized", finalFields = false),
+      ParamDesc[String]("distanceMeasure", "The distance measure. Supported options: 'euclidean'" +
+        " and 'cosine'", Some("org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN"),
+        isValid = "(value: String) => " +
+          "org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value)")
     )
 
     val code = genSharedParams(params)
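One point worth keeping in mind when reviewing the next file: sharedParams.scala is generated output rather than hand-written code. It is regenerated by running SharedParamsCodeGen's main method (the class header documents the command `build/sbt "mllib/runMain org.apache.spark.ml.param.shared.SharedParamsCodeGen"`), so the ParamDesc entry added here and the HasDistanceMeasure trait below are expected to stay in sync.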
mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -504,4 +504,23 @@ trait HasLoss extends Params {
   /** @group getParam */
   final def getLoss: String = $(loss)
 }
+
+/**
+ * Trait for shared param distanceMeasure (default: org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN). This trait may be changed or
+ * removed between minor versions.
+ */
+@DeveloperApi
+trait HasDistanceMeasure extends Params {
+
+  /**
+   * Param for The distance measure. Supported options: 'euclidean' and 'cosine'.
+   * @group param
+   */
+  final val distanceMeasure: Param[String] = new Param[String](this, "distanceMeasure", "The distance measure. Supported options: 'euclidean' and 'cosine'", (value: String) => org.apache.spark.mllib.clustering.DistanceMeasure.validateDistanceMeasure(value))
+
+  setDefault(distanceMeasure, org.apache.spark.mllib.clustering.DistanceMeasure.EUCLIDEAN)
+
+  /** @group getParam */
+  final def getDistanceMeasure: String = $(distanceMeasure)
+}
 // scalastyle:on
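To make the generated trait concrete, here is a minimal sketch of a hypothetical Params class that mixes it in, assuming the trait stays a public @DeveloperApi as shown above; the class name and UID prefix are illustrative only, and the setter mirrors what BisectingKMeans adds in this change:

```scala
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.ml.param.shared.HasDistanceMeasure
import org.apache.spark.ml.util.Identifiable

// Hypothetical Params holder, used only to show what HasDistanceMeasure provides.
class MyClusteringParams(override val uid: String) extends Params with HasDistanceMeasure {
  def this() = this(Identifiable.randomUID("myClusteringParams"))

  // Concrete classes still declare their own setter, as BisectingKMeans does above.
  def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)

  override def copy(extra: ParamMap): MyClusteringParams = defaultCopy(extra)
}

val p = new MyClusteringParams()
p.getDistanceMeasure                  // "euclidean", the shared default
p.setDistanceMeasure("cosine")        // accepted by the isValid check
// p.setDistanceMeasure("manhattan")  // would throw IllegalArgumentException at set time
```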