@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
2424import org .apache .spark .mllib .linalg .{Matrices , Vector , Vectors , DenseVector , DenseMatrix , BLAS }
2525import org .apache .spark .mllib .stat .impl .MultivariateGaussian
2626import org .apache .spark .mllib .util .MLUtils
27+ import org .apache .spark .util .Utils
2728
2829/**
2930 * This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
4546class GaussianMixtureEM private (
4647 private var k : Int ,
4748 private var convergenceTol : Double ,
48- private var maxIterations : Int ) extends Serializable {
49+ private var maxIterations : Int ,
50+ private var seed : Long ) extends Serializable {
4951
5052 /** A default instance: 2 Gaussians, 0.01 log-likelihood threshold, 100 iterations, and a seed drawn from Utils.random */
51- def this () = this (2 , 0.01 , 100 )
53+ def this () = this (2 , 0.01 , 100 , Utils .random.nextLong() )
5254
5355 // number of samples per cluster to use when initializing Gaussians
5456 private val nSamples = 5
@@ -100,11 +102,21 @@ class GaussianMixtureEM private (
100102 this
101103 }
102104
103- /** Return the largest change in log-likelihood at which convergence is
104- * considered to have occurred.
105+ /**
106+ * Return the largest change in log-likelihood at which convergence is
107+ * considered to have occurred.
105108 */
106109 def getConvergenceTol : Double = convergenceTol
107-
110+
111+ /** Set the random seed that run passes to takeSample when sampling the data to initialize the Gaussians (only when no initial model was supplied); returns this instance so calls can be chained */
112+ def setSeed (seed : Long ): this .type = {
113+ this .seed = seed
114+ this
115+ }
116+
117+ /** Return the random seed: the value given at construction (the no-arg constructor draws one from Utils.random) or the last value passed to setSeed */
118+ def getSeed : Long = seed
119+
108120 /** Perform expectation maximization */
109121 def run (data : RDD [Vector ]): GaussianMixtureModel = {
110122 val sc = data.sparkContext
@@ -113,7 +125,7 @@ class GaussianMixtureEM private (
113125 val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
114126
115127 // Get length of the input vectors
116- val d = breezeData.first.length
128+ val d = breezeData.first() .length
117129
118130 // Determine initial weights and corresponding Gaussians.
119131 // If the user supplied an initial GMM, we use those values, otherwise
@@ -126,7 +138,7 @@ class GaussianMixtureEM private (
126138 })
127139
128140 case None => {
129- val samples = breezeData.takeSample(true , k * nSamples, scala.util. Random .nextInt )
141+ val samples = breezeData.takeSample(withReplacement = true , k * nSamples, seed )
130142 (Array .fill(k)(1.0 / k), Array .tabulate(k) { i =>
131143 val slice = samples.view(i * nSamples, (i + 1 ) * nSamples)
132144 new MultivariateGaussian (vectorMean(slice), initCovariance(slice))
0 commit comments