
Commit f0df4fd

Added seed parameter to GMM. Updated test suite to use seed to prevent flakiness
1 parent e9ca16e commit f0df4fd

File tree: 2 files changed (+27 lines, -13 lines)

mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala

Lines changed: 19 additions & 7 deletions
@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
 import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.Utils
 
 /**
  * This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
 class GaussianMixtureEM private (
     private var k: Int,
     private var convergenceTol: Double,
-    private var maxIterations: Int) extends Serializable {
+    private var maxIterations: Int,
+    private var seed: Long) extends Serializable {
 
   /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
-  def this() = this(2, 0.01, 100)
+  def this() = this(2, 0.01, 100, Utils.random.nextLong())
 
   // number of samples per cluster to use when initializing Gaussians
   private val nSamples = 5
@@ -100,11 +102,21 @@ class GaussianMixtureEM private (
     this
   }
 
-  /** Return the largest change in log-likelihood at which convergence is
-   * considered to have occurred.
+  /**
+   * Return the largest change in log-likelihood at which convergence is
+   * considered to have occurred.
    */
   def getConvergenceTol: Double = convergenceTol
-  
+
+  /** Set the random seed */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
+  /** Return the random seed */
+  def getSeed: Long = seed
+
   /** Perform expectation maximization */
   def run(data: RDD[Vector]): GaussianMixtureModel = {
     val sc = data.sparkContext
@@ -113,7 +125,7 @@ class GaussianMixtureEM private (
     val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
 
     // Get length of the input vectors
-    val d = breezeData.first.length
+    val d = breezeData.first().length
 
     // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
@@ -126,7 +138,7 @@ class GaussianMixtureEM private (
       })
 
       case None => {
-        val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+        val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
        (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
          val slice = samples.view(i * nSamples, (i + 1) * nSamples)
          new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
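For context, a minimal usage sketch of the new seed setter (not part of the patch; the dataset, k, and seed values below are illustrative, and an existing SparkContext `sc` is assumed). Fixing the seed makes the takeSample-based initialization, and therefore the fitted model, reproducible across runs:

import org.apache.spark.mllib.clustering.GaussianMixtureEM
import org.apache.spark.mllib.linalg.Vectors

// Illustrative toy dataset; `sc` is an existing SparkContext.
val data = sc.parallelize(Seq(
  Vectors.dense(1.0, 2.0), Vectors.dense(1.1, 2.1),
  Vectors.dense(8.0, 9.0), Vectors.dense(8.2, 9.1)))

// Same seed => same random initialization => same fitted mixture.
val model = new GaussianMixtureEM()
  .setK(2)
  .setSeed(42L)
  .run(data)

// weight and mu are the model fields exercised by the test suite below.
model.weight.zip(model.mu).foreach { case (w, mu) => println(s"weight=$w, mean=$mu") }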

mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala

Lines changed: 8 additions & 6 deletions
@@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex
     val Ew = 1.0
     val Emu = Vectors.dense(5.0, 10.0)
     val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))
-
-    val gmm = new GaussianMixtureEM().setK(1).run(data)
-
-    assert(gmm.weight(0) ~== Ew absTol 1E-5)
-    assert(gmm.mu(0) ~== Emu absTol 1E-5)
-    assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+
+    val seeds = Array(314589, 29032897, 50181, 494821, 4660)
+    seeds.foreach { seed =>
+      val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
+      assert(gmm.weight(0) ~== Ew absTol 1E-5)
+      assert(gmm.mu(0) ~== Emu absTol 1E-5)
+      assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+    }
   }
 
   test("two clusters") {
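A note on why the fixed seeds remove the flakiness: RDD.takeSample is deterministic for an explicit seed, so each seed in the loop above produces the same initial sample, and hence the same EM result, on every run. A small sketch of that property (illustrative values, `sc` assumed in scope):

// Two draws with the same explicit seed return identical samples,
// which is what makes the seeded GMM initialization repeatable.
val rdd = sc.parallelize(1 to 1000)
val s1 = rdd.takeSample(withReplacement = true, num = 10, seed = 4660L)
val s2 = rdd.takeSample(withReplacement = true, num = 10, seed = 4660L)
assert(s1.sameElements(s2))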
