diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 278d61d916735..ac85fbc235c93 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -223,12 +223,12 @@ class KMeans private ( // Compute squared norms and cache them. val norms = data.map(Vectors.norm(_, 2.0)) - norms.persist() val zippedData = data.zip(norms).map { case (v, norm) => new VectorWithNorm(v, norm) } + zippedData.persist() val model = runAlgorithm(zippedData, instr) - norms.unpersist() + zippedData.unpersist() // Warn at the end of the run as well, for increased visibility. if (data.getStorageLevel == StorageLevel.NONE) {