Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions examples/src/main/python/mllib/gaussian_mixture_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,9 @@ def parseVector(line):
for i in range(args.k):
print(("weight = ", model.weights[i], "mu = ", model.gaussians[i].mu,
"sigma = ", model.gaussians[i].sigma.toArray()))
print("\n")
print(("The membership value of each vector to all mixture components (first 100): ",
model.predictSoft(data).take(100)))
print("\n")
print(("Cluster labels (first 100): ", model.predict(data).take(100)))
sc.stop()
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ object DenseGaussianMixture {
(clusters.weights(i), clusters.gaussians(i).mu, clusters.gaussians(i).sigma))
}

println("The membership value of each vector to all mixture components (first <= 100):")
val membership = clusters.predictSoft(data)
membership.take(100).foreach { x =>
print(" " + x.mkString(","))
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd add a println after this for spacing.

println()
println("Cluster labels (first <= 100):")
val clusterLabels = clusters.predict(data)
clusterLabels.take(100).foreach { x =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,9 @@ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) {
SerDe.dumps(JavaConverters.seqAsJavaListConverter(modelGaussians).asJava)
}

def predictSoft(point: Vector): Vector = {
Vectors.dense(model.predictSoft(point))
}

def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class GaussianMixtureModel @Since("1.3.0") (
*/
@Since("1.5.0")
def predict(point: Vector): Int = {
val r = computeSoftAssignments(point.toBreeze.toDenseVector, gaussians, weights, k)
val r = predictSoft(point)
r.indexOf(r.max)
}

Expand Down
35 changes: 22 additions & 13 deletions python/pyspark/mllib/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,25 @@ class GaussianMixtureModel(JavaModelWrapper, JavaSaveable, JavaLoader):

>>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
... 0.9,0.8,0.75,0.935,
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2), 2)
>>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
... maxIterations=50, seed=10)
>>> labels = model.predict(clusterdata_1).collect()
>>> labels[0]==labels[1]
False
>>> labels[1]==labels[2]
True
False
>>> labels[4]==labels[5]
True
>>> model.predict([-0.1,-0.05])
0
>>> softPredicted = model.predictSoft([-0.1,-0.05])
>>> abs(softPredicted[0] - 1.0) < 0.001
True
>>> abs(softPredicted[1] - 0.0) < 0.001
True
>>> abs(softPredicted[2] - 0.0) < 0.001
True

>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
Expand Down Expand Up @@ -277,35 +286,35 @@ def k(self):
@since('1.3.0')
def predict(self, x):
"""
Find the cluster to which the points in 'x' has maximum membership
in this model.
Find the cluster to which the point 'x' or each point in RDD 'x'
has maximum membership in this model.

:param x: RDD of data points.
:return: cluster_labels. RDD of cluster labels.
:param x: vector or RDD of vector represents data points.
:return: cluster label or RDD of cluster labels.
"""
if isinstance(x, RDD):
cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
return cluster_labels
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
z = self.predictSoft(x)
return z.argmax()

@since('1.3.0')
def predictSoft(self, x):
"""
Find the membership of each point in 'x' to all mixture components.
Find the membership of point 'x' or each point in RDD 'x' to all mixture components.

:param x: RDD of data points.
:return: membership_matrix. RDD of array of double values.
:param x: vector or RDD of vector represents data points.
:return: the membership value to all mixture components for vector 'x'
or each vector in RDD 'x'.
"""
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
_convert_to_vector(self.weights), means, sigmas)
return membership_matrix.map(lambda x: pyarray.array('d', x))
else:
raise TypeError("x should be represented by an RDD, "
"but got %s." % type(x))
return self.call("predictSoft", _convert_to_vector(x)).toArray()

@classmethod
@since('1.5.0')
Expand Down