From ddba6577a6ece67f77b3757da859c88fa9065c04 Mon Sep 17 00:00:00 2001 From: William Benton Date: Mon, 19 Sep 2016 09:35:57 -0500 Subject: [PATCH 1/4] Use a bounded priority queue to find synonyms in Word2VecModel --- .../org/apache/spark/mllib/feature/Word2Vec.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 42ca9665e5843..8d6ef8f1a69ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -35,6 +35,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ import org.apache.spark.sql.SparkSession +import org.apache.spark.util.BoundedPriorityQueue import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -555,7 +556,7 @@ class Word2VecModel private[spark] ( num: Int, wordOpt: Option[String]): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - // TODO: optimize top-k + val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 @@ -580,7 +581,15 @@ class Word2VecModel private[spark] ( ind += 1 } - val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2) + val ord = new Ordering[(String, Double)] { + override def compare(x: (String, Double), y: (String, Double)): Int = x._2.compareTo(y._2) + } + + val pq = new BoundedPriorityQueue(num + 1)(ord) + + wordList.zip(cosVec).foreach(tup => pq += tup) + + val scored = pq.toSeq.sortBy(-_._2) val filtered = wordOpt match { case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) From 93ebb94a7e4daef0c9582397fda9df3198192b71 Mon Sep 17 00:00:00 2001 From: William Benton Date: Mon, 19 Sep 2016 12:32:10 -0500 Subject: [PATCH 2/4] Stylistic cleanups from review --- .../scala/org/apache/spark/mllib/feature/Word2Vec.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 8d6ef8f1a69ab..4d9d3fe0fe65d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -581,13 +581,9 @@ class Word2VecModel private[spark] ( ind += 1 } - val ord = new Ordering[(String, Double)] { - override def compare(x: (String, Double), y: (String, Double)): Int = x._2.compareTo(y._2) - } - - val pq = new BoundedPriorityQueue(num + 1)(ord) + val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) - wordList.zip(cosVec).foreach(tup => pq += tup) + pq ++= wordList.zip(cosVec) val scored = pq.toSeq.sortBy(-_._2) From f7311a22d78b1875446e86aa53ad9f15892df7e2 Mon Sep 17 00:00:00 2001 From: William Benton Date: Mon, 19 Sep 2016 23:05:59 -0500 Subject: [PATCH 3/4] Removed redundant take from Word2VecModel.findSynonyms --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 4d9d3fe0fe65d..bf59e4d13d4cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -588,7 +588,7 @@ class Word2VecModel private[spark] ( val scored = pq.toSeq.sortBy(-_._2) val filtered = wordOpt match { - case Some(w) => scored.take(num + 1).filter(tup => w != tup._1) + case Some(w) => scored.filter(tup => w != tup._1) case None => scored } From 4b235dc1eaf84c7f24cc6090153bd7b4f78da35d Mon Sep 17 00:00:00 2001 From: William Benton Date: Tue, 20 Sep 2016 10:40:54 -0500 Subject: [PATCH 4/4] Prefer explicit iteration over indices to zip; avoid allocating an array copy --- .../main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index bf59e4d13d4cc..2364d43aaa0e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -583,7 +583,9 @@ class Word2VecModel private[spark] ( val pq = new BoundedPriorityQueue[(String, Double)](num + 1)(Ordering.by(_._2)) - pq ++= wordList.zip(cosVec) + for(i <- cosVec.indices) { + pq += Tuple2(wordList(i), cosVec(i)) + } val scored = pq.toSeq.sortBy(-_._2)