@@ -28,18 +28,21 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.clustering.{BisectingKMeans => MLlibBisectingKMeans,
BisectingKMeansModel => MLlibBisectingKMeansModel}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{IntegerType, StructType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
import org.apache.spark.storage.StorageLevel


/**
* Common params for BisectingKMeans and BisectingKMeansModel
*/
private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure {
with HasFeaturesCol with HasSeed with HasPredictionCol with HasDistanceMeasure
with HasWeightCol {

/**
* The desired number of leaf clusters. Must be > 1. Default: 4.
@@ -261,31 +264,50 @@ class BisectingKMeans @Since("2.0.0") (
@Since("2.4.0")
def setDistanceMeasure(value: String): this.type = set(distanceMeasure, value)

/**
* Sets the value of param [[weightCol]].
* If this is not set or empty, we treat all instance weights as 1.0.
* Default is not set, so all instances have weight one.
*
* @group setParam
*/
@Since("3.0.0")
def setWeightCol(value: String): this.type = set(weightCol, value)

@Since("2.0.0")
override def fit(dataset: Dataset[_]): BisectingKMeansModel = instrumented { instr =>
transformSchema(dataset.schema, logging = true)

val handlePersistence = dataset.storageLevel == StorageLevel.NONE
val rdd = DatasetUtils.columnToOldVector(dataset, getFeaturesCol)
val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
col($(weightCol)).cast(DoubleType)
} else {
lit(1.0)
}

val instances: RDD[(OldVector, Double)] = dataset
.select(DatasetUtils.columnToVector(dataset, getFeaturesCol), w).rdd.map {
case Row(point: Vector, weight: Double) => (OldVectors.fromML(point), weight)
}
if (handlePersistence) {
rdd.persist(StorageLevel.MEMORY_AND_DISK)
instances.persist(StorageLevel.MEMORY_AND_DISK)
}

instr.logPipelineStage(this)
instr.logDataset(dataset)
instr.logParams(this, featuresCol, predictionCol, k, maxIter, seed,
minDivisibleClusterSize, distanceMeasure)
minDivisibleClusterSize, distanceMeasure, weightCol)

val bkm = new MLlibBisectingKMeans()
.setK($(k))
.setMaxIterations($(maxIter))
.setMinDivisibleClusterSize($(minDivisibleClusterSize))
.setSeed($(seed))
.setDistanceMeasure($(distanceMeasure))
val parentModel = bkm.run(rdd, Some(instr))
val parentModel = bkm.runWithWeight(instances, Some(instr))
val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
if (handlePersistence) {
rdd.unpersist()
instances.unpersist()
}

val summary = new BisectingKMeansSummary(
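
For context, a minimal spark-shell style usage sketch of the weighted API added above to the ml BisectingKMeans estimator. The column names, the toy data, and the in-scope SparkSession `spark` are illustrative assumptions, not part of this change:

import org.apache.spark.ml.clustering.BisectingKMeans
import org.apache.spark.ml.linalg.Vectors

// Hypothetical weighted training data; "features"/"weight" column names are assumptions.
val df = spark.createDataFrame(Seq(
  (Vectors.dense(0.0, 0.0), 1.0),
  (Vectors.dense(0.1, 0.1), 1.0),
  (Vectors.dense(9.0, 9.0), 3.0),  // heavier point pulls its cluster center toward it
  (Vectors.dense(9.2, 9.2), 1.0)
)).toDF("features", "weight")

val bkm = new BisectingKMeans()
  .setK(2)
  .setSeed(1L)
  .setWeightCol("weight")  // new setter from this patch; unset or empty means weight 1.0

val model = bkm.fit(df)
model.clusterCenters.foreach(println)

With all weights equal to 1.0 this reduces to the previous unweighted behavior, matching the default documented on the new setter.
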
@@ -27,6 +27,7 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.axpy
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -152,24 +153,34 @@ class BisectingKMeans private (
this
}


private[spark] def run(
input: RDD[Vector],
instr: Option[Instrumentation]): BisectingKMeansModel = {
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if"
+ " its parent RDDs are also not cached.")
val instances: RDD[(Vector, Double)] = input.map {
case (point) => (point, 1.0)
}
val d = input.map(_.size).first()
runWithWeight(instances, None)
}

private[spark] def runWithWeight(
input: RDD[(Vector, Double)],
instr: Option[Instrumentation]): BisectingKMeansModel = {
val d = input.map(_._1.size).first
logInfo(s"Feature dimension: $d.")

val dMeasure: DistanceMeasure = DistanceMeasure.decodeFromString(this.distanceMeasure)
// Compute and cache vector norms for fast distance computation.
val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK)
val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) }
val norms = input.map(d => Vectors.norm(d._1, 2.0))
val vectors = input.zip(norms).map {
case ((x, weight), norm) => new VectorWithNorm(x, norm, weight)
}
if (input.getStorageLevel == StorageLevel.NONE) {
vectors.persist(StorageLevel.MEMORY_AND_DISK)
}
var assignments = vectors.map(v => (ROOT_INDEX, v))
var activeClusters = summarize(d, assignments, dMeasure)
instr.foreach(_.logNumExamples(activeClusters.values.map(_.size).sum))
instr.foreach(_.logSumOfWeights(activeClusters.values.map(_.weightSum).sum))
val rootSummary = activeClusters(ROOT_INDEX)
val n = rootSummary.size
logInfo(s"Number of points: $n.")
@@ -239,7 +250,7 @@ class BisectingKMeans private (
if (indices != null) {
indices.unpersist()
}
norms.unpersist()
vectors.unpersist()
val clusters = activeClusters ++ inactiveClusters
val root = buildTree(clusters, dMeasure)
val totalCost = root.leafNodes.map(_.cost).sum
@@ -312,31 +323,35 @@ private object BisectingKMeans extends Serializable {
private class ClusterSummaryAggregator(val d: Int, val distanceMeasure: DistanceMeasure)
extends Serializable {
private var n: Long = 0L
private var weightSum: Double = 0.0
private val sum: Vector = Vectors.zeros(d)
private var sumSq: Double = 0.0

/** Adds a point. */
def add(v: VectorWithNorm): this.type = {
n += 1L
weightSum += v.weight
// TODO: use a numerically stable approach to estimate cost
sumSq += v.norm * v.norm
sumSq += v.norm * v.norm * v.weight
distanceMeasure.updateClusterSum(v, sum)
this
}

/** Merges another aggregator. */
def merge(other: ClusterSummaryAggregator): this.type = {
n += other.n
weightSum += other.weightSum
sumSq += other.sumSq
distanceMeasure.updateClusterSum(new VectorWithNorm(other.sum), sum)
axpy(1.0, other.sum, sum)
this
}

/** Returns the summary. */
def summary: ClusterSummary = {
val center = distanceMeasure.centroid(sum.copy, n)
val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), n, sumSq)
ClusterSummary(n, center, cost)
val center = distanceMeasure.centroid(sum.copy, weightSum)
val cost = distanceMeasure.clusterCost(center, new VectorWithNorm(sum), weightSum,
sumSq)
ClusterSummary(n, weightSum, center, cost)
}
}

@@ -437,10 +452,15 @@ private object BisectingKMeans extends Serializable {
* Summary of a cluster.
*
* @param size the number of points within this cluster
* @param weightSum the weightSum within this cluster
* @param center the center of the points within this cluster
* @param cost the sum of squared distances to the center
*/
private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double)
private case class ClusterSummary(
size: Long,
weightSum: Double,
center: VectorWithNorm,
cost: Double)
}

/**
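
The aggregator above now carries (n, weightSum, sum, sumSq) and merges partial results with axpy, which is what lets treeAggregate combine per-partition summaries in any order. A REPL-style local sketch of that add/merge contract, using a made-up WeightedSummary class rather than the real ClusterSummaryAggregator (all names and values here are illustrative):

final class WeightedSummary(dim: Int) {
  var weightSum = 0.0                    // sum of instance weights
  val sum = Array.fill(dim)(0.0)         // weighted vector sum
  var sumSq = 0.0                        // weighted sum of squared norms

  def add(point: Array[Double], weight: Double): this.type = {
    weightSum += weight
    var j = 0
    while (j < dim) { sum(j) += weight * point(j); j += 1 }  // like axpy(weight, point, sum)
    sumSq += weight * point.map(x => x * x).sum
    this
  }

  def merge(other: WeightedSummary): this.type = {
    weightSum += other.weightSum
    var j = 0
    while (j < dim) { sum(j) += other.sum(j); j += 1 }       // like axpy(1.0, other.sum, sum)
    sumSq += other.sumSq
    this
  }
}

// Aggregating everything in one pass and aggregating two halves then merging
// must agree, otherwise distributed summarization would depend on partitioning.
val pts = Seq((Array(1.0, 0.0), 2.0), (Array(0.0, 1.0), 1.0), (Array(2.0, 2.0), 0.5))
val whole = pts.foldLeft(new WeightedSummary(2)) { case (agg, (p, w)) => agg.add(p, w) }
val (left, right) = pts.splitAt(1)
val merged = left.foldLeft(new WeightedSummary(2)) { case (agg, (p, w)) => agg.add(p, w) }
  .merge(right.foldLeft(new WeightedSummary(2)) { case (agg, (p, w)) => agg.add(p, w) })
assert(whole.weightSum == merged.weightSum && math.abs(whole.sumSq - merged.sumSq) < 1e-12)
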
@@ -76,28 +76,16 @@ private[spark] abstract class DistanceMeasure extends Serializable {
def clusterCost(
centroid: VectorWithNorm,
pointsSum: VectorWithNorm,
numberOfPoints: Long,
weightSum: Double,
pointsSquaredNorm: Double): Double

/**
* Updates the value of `sum` adding the `point` vector.
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
axpy(weight, point.vector, sum)
}

/**
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
*
* @param sum the `sum` for a cluster
* @param count the number of points in the cluster
* @return the centroid of the cluster
*/
def centroid(sum: Vector, count: Long): VectorWithNorm = {
scal(1.0 / count, sum)
new VectorWithNorm(sum)
def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
axpy(point.weight, point.vector, sum)
}

/**
@@ -217,9 +205,9 @@ private[spark] class EuclideanDistanceMeasure extends DistanceMeasure {
override def clusterCost(
centroid: VectorWithNorm,
pointsSum: VectorWithNorm,
numberOfPoints: Long,
weightSum: Double,
pointsSquaredNorm: Double): Double = {
math.max(pointsSquaredNorm - numberOfPoints * centroid.norm * centroid.norm, 0.0)
math.max(pointsSquaredNorm - weightSum * centroid.norm * centroid.norm, 0.0)
}

/**
@@ -261,20 +249,20 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
* @param point a `VectorWithNorm` to be added to `sum` of a cluster
* @param sum the `sum` for a cluster to be updated
*/
override def updateClusterSum(point: VectorWithNorm, sum: Vector, weight: Double = 1.0): Unit = {
override def updateClusterSum(point: VectorWithNorm, sum: Vector): Unit = {
assert(point.norm > 0, "Cosine distance is not defined for zero-length vectors.")
axpy(weight / point.norm, point.vector, sum)
axpy(point.weight / point.norm, point.vector, sum)
}

/**
* Returns a centroid for a cluster given its `sum` vector and its `count` of points.
*
* @param sum the `sum` for a cluster
* @param count the number of points in the cluster
* @param weightSum the sum of weight in the cluster
* @return the centroid of the cluster
*/
override def centroid(sum: Vector, count: Long): VectorWithNorm = {
scal(1.0 / count, sum)
override def centroid(sum: Vector, weightSum: Double): VectorWithNorm = {
scal(1.0 / weightSum, sum)
val norm = Vectors.norm(sum, 2)
scal(1.0 / norm, sum)
new VectorWithNorm(sum, 1)
@@ -286,10 +274,10 @@ private[spark] class CosineDistanceMeasure extends DistanceMeasure {
override def clusterCost(
centroid: VectorWithNorm,
pointsSum: VectorWithNorm,
numberOfPoints: Long,
weightSum: Double,
pointsSquaredNorm: Double): Double = {
val costVector = pointsSum.vector.copy
math.max(numberOfPoints - dot(centroid.vector, costVector) / centroid.norm, 0.0)
math.max(weightSum - dot(centroid.vector, costVector) / centroid.norm, 0.0)
}

/**
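
The switch from numberOfPoints to weightSum in the clusterCost signatures above relies on the weighted form of the usual Euclidean shortcut: when the centroid c is the weighted mean, sum_i w_i*|x_i - c|^2 equals sum_i w_i*|x_i|^2 - W*|c|^2 with W = sum_i w_i, which is also why the aggregator now accumulates weighted squared norms. A small standalone check of that identity, with illustrative values only:

object WeightedCostIdentity {
  def main(args: Array[String]): Unit = {
    val points  = Array(Array(0.0, 0.0), Array(1.0, 2.0), Array(3.0, 1.0))
    val weights = Array(1.0, 2.0, 0.5)
    val w = weights.sum
    // weighted mean c = (sum_i w_i * x_i) / W
    val c = Array.tabulate(points.head.length) { j =>
      points.indices.map(i => weights(i) * points(i)(j)).sum / w
    }
    def sqNorm(a: Array[Double]): Double = a.map(x => x * x).sum
    def sqDist(a: Array[Double], b: Array[Double]): Double =
      a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum
    val direct   = points.indices.map(i => weights(i) * sqDist(points(i), c)).sum
    val shortcut = points.indices.map(i => weights(i) * sqNorm(points(i))).sum - w * sqNorm(c)
    assert(math.abs(direct - shortcut) < 1e-9)
    println(s"direct=$direct shortcut=$shortcut")  // both equal the weighted within-cluster cost
  }
}
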
28 changes: 14 additions & 14 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -225,7 +225,7 @@ class KMeans private (
}

val zippedData = data.zip(norms).map { case ((v, w), norm) =>
(new VectorWithNorm(v, norm), w)
new VectorWithNorm(v, norm, w)
}

if (data.getStorageLevel == StorageLevel.NONE) {
@@ -241,7 +241,7 @@
* Implementation of K-Means algorithm.
*/
private def runAlgorithmWithWeight(
data: RDD[(VectorWithNorm, Double)],
data: RDD[VectorWithNorm],
instr: Option[Instrumentation]): KMeansModel = {

val sc = data.sparkContext
@@ -250,16 +250,14 @@

val distanceMeasureInstance = DistanceMeasure.decodeFromString(this.distanceMeasure)

val dataVectorWithNorm = data.map(d => d._1)

val centers = initialModel match {
case Some(kMeansCenters) =>
kMeansCenters.clusterCenters.map(new VectorWithNorm(_))
case None =>
if (initializationMode == KMeans.RANDOM) {
initRandom(dataVectorWithNorm)
initRandom(data)
} else {
initKMeansParallel(dataVectorWithNorm, distanceMeasureInstance)
initKMeansParallel(data, distanceMeasureInstance)
}
}
val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9
@@ -279,7 +277,7 @@
val bcCenters = sc.broadcast(centers)

// Find the new centers
val collected = data.mapPartitions { pointsAndWeights =>
val collected = data.mapPartitions { points =>
val thisCenters = bcCenters.value
val dims = thisCenters.head.vector.size

@@ -290,11 +288,11 @@
// sample1 * weight1/clusterWeightSum + sample2 * weight2/clusterWeightSum + ...
val clusterWeightSum = Array.ofDim[Double](thisCenters.length)

pointsAndWeights.foreach { case (point, weight) =>
points.foreach { point =>
val (bestCenter, cost) = distanceMeasureInstance.findClosest(thisCenters, point)
costAccum.add(cost * weight)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter), weight)
clusterWeightSum(bestCenter) += weight
costAccum.add(cost * point.weight)
distanceMeasureInstance.updateClusterSum(point, sums(bestCenter))
clusterWeightSum(bestCenter) += point.weight
}

clusterWeightSum.indices.filter(clusterWeightSum(_) > 0)
@@ -511,13 +509,15 @@ object KMeans {
/**
* A vector with its norm for fast distance computation.
*/
private[clustering] class VectorWithNorm(val vector: Vector, val norm: Double)
extends Serializable {
private[clustering] class VectorWithNorm(
val vector: Vector,
val norm: Double,
val weight: Double = 1.0) extends Serializable {

def this(vector: Vector) = this(vector, Vectors.norm(vector, 2.0))

def this(array: Array[Double]) = this(Vectors.dense(array))

/** Converts the vector to a dense vector. */
def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm)
def toDense: VectorWithNorm = new VectorWithNorm(Vectors.dense(vector.toArray), norm, weight)
}
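
As a closing note on the center-update loop above (the "sample1 * weight1/clusterWeightSum + ..." comment), a REPL-style sketch of what the per-cluster accumulation computes; the points, weights, and variable names here are made up for illustration:

// Each new center is the weighted mean of the points assigned to it:
//   c = (w1*x1 + w2*x2 + ...) / (w1 + w2 + ...)
val assigned = Seq(
  (Array(1.0, 1.0), 1.0),
  (Array(3.0, 1.0), 2.0),   // weight 2.0: counts twice as much as the others
  (Array(1.0, 3.0), 1.0)
)
val dim = assigned.head._1.length
val sums = Array.fill(dim)(0.0)
var clusterWeightSum = 0.0
assigned.foreach { case (point, weight) =>
  var j = 0
  while (j < dim) { sums(j) += weight * point(j); j += 1 }   // updateClusterSum with the point's weight
  clusterWeightSum += weight
}
val center = sums.map(_ / clusterWeightSum)                  // analogous to centroid(sum, weightSum)
println(center.mkString("[", ", ", "]"))                     // [2.0, 1.5]
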