Skip to content

Commit 25cbbe6

Browse files
willbsrowen
authored andcommitted
[SPARK-17548][MLLIB] Word2VecModel.findSynonyms no longer spuriously rejects the best match when invoked with a vector
## What changes were proposed in this pull request? This pull request changes the behavior of `Word2VecModel.findSynonyms` so that it will not spuriously reject the best match when invoked with a vector that does not correspond to a word in the model's vocabulary. Instead of blindly discarding the best match, the changed implementation discards a match that corresponds to the query word (in cases where `findSynonyms` is invoked with a word) or that has an identical angle to the query vector. ## How was this patch tested? I added a test to `Word2VecSuite` to ensure that the word with the most similar vector from a supplied vector would not be spuriously rejected. Author: William Benton <willb@redhat.com> Closes #15105 from willb/fix/findSynonyms.
1 parent f15d41b commit 25cbbe6

File tree

5 files changed

+83
-24
lines changed

5 files changed

+83
-24
lines changed

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,24 +221,26 @@ class Word2VecModel private[ml] (
221221
}
222222

223223
/**
224-
* Find "num" number of words closest in similarity to the given word.
225-
* Returns a dataframe with the words and the cosine similarities between the
226-
* synonyms and the given word.
224+
* Find "num" number of words closest in similarity to the given word, not
225+
* including the word itself. Returns a dataframe with the words and the
226+
* cosine similarities between the synonyms and the given word.
227227
*/
228228
@Since("1.5.0")
229229
def findSynonyms(word: String, num: Int): DataFrame = {
230-
findSynonyms(wordVectors.transform(word), num)
230+
val spark = SparkSession.builder().getOrCreate()
231+
spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
231232
}
232233

233234
/**
234-
* Find "num" number of words closest to similarity to the given vector representation
235-
* of the word. Returns a dataframe with the words and the cosine similarities between the
236-
* synonyms and the given word vector.
235+
* Find "num" number of words whose vector representation most similar to the supplied vector.
236+
* If the supplied vector is the vector representation of a word in the model's vocabulary,
237+
* that word will be in the results. Returns a dataframe with the words and the cosine
238+
* similarities between the synonyms and the given word vector.
237239
*/
238240
@Since("2.0.0")
239-
def findSynonyms(word: Vector, num: Int): DataFrame = {
241+
def findSynonyms(vec: Vector, num: Int): DataFrame = {
240242
val spark = SparkSession.builder().getOrCreate()
241-
spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
243+
spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
242244
}
243245

244246
/** @group setParam */

mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
4343
rdd.rdd.map(model.transform)
4444
}
4545

46+
/**
47+
* Finds synonyms of a word; do not include the word itself in results.
48+
* @param word a word
49+
* @param num number of synonyms to find
50+
* @return a list consisting of a list of words and a vector of cosine similarities
51+
*/
4652
def findSynonyms(word: String, num: Int): JList[Object] = {
47-
val vec = transform(word)
48-
findSynonyms(vec, num)
53+
prepareResult(model.findSynonyms(word, num))
4954
}
5055

56+
/**
57+
* Finds words similar to the the vector representation of a word without
58+
* filtering results.
59+
* @param vector a vector
60+
* @param num number of synonyms to find
61+
* @return a list consisting of a list of words and a vector of cosine similarities
62+
*/
5163
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
52-
val result = model.findSynonyms(vector, num)
64+
prepareResult(model.findSynonyms(vector, num))
65+
}
66+
67+
private def prepareResult(result: Array[(String, Double)]) = {
5368
val similarity = Vectors.dense(result.map(_._2))
5469
val words = result.map(_._1)
5570
List(words, similarity).map(_.asInstanceOf[Object]).asJava
5671
}
5772

73+
5874
def getVectors: JMap[String, JList[Float]] = {
5975
model.getVectors.map { case (k, v) =>
6076
(k, v.toList.asJava)

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -518,25 +518,42 @@ class Word2VecModel private[spark] (
518518
}
519519

520520
/**
521-
* Find synonyms of a word
521+
* Find synonyms of a word; do not include the word itself in results.
522522
* @param word a word
523523
* @param num number of synonyms to find
524524
* @return array of (word, cosineSimilarity)
525525
*/
526526
@Since("1.1.0")
527527
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
528528
val vector = transform(word)
529-
findSynonyms(vector, num)
529+
findSynonyms(vector, num, Some(word))
530530
}
531531

532532
/**
533-
* Find synonyms of the vector representation of a word
533+
* Find synonyms of the vector representation of a word, possibly
534+
* including any words in the model vocabulary whose vector respresentation
535+
* is the supplied vector.
534536
* @param vector vector representation of a word
535537
* @param num number of synonyms to find
536538
* @return array of (word, cosineSimilarity)
537539
*/
538540
@Since("1.1.0")
539541
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
542+
findSynonyms(vector, num, None)
543+
}
544+
545+
/**
546+
* Find synonyms of the vector representation of a word, rejecting
547+
* words identical to the value of wordOpt, if one is supplied.
548+
* @param vector vector representation of a word
549+
* @param num number of synonyms to find
550+
* @param wordOpt optionally, a word to reject from the results list
551+
* @return array of (word, cosineSimilarity)
552+
*/
553+
private def findSynonyms(
554+
vector: Vector,
555+
num: Int,
556+
wordOpt: Option[String]): Array[(String, Double)] = {
540557
require(num > 0, "Number of similar words should > 0")
541558
// TODO: optimize top-k
542559
val fVector = vector.toArray.map(_.toFloat)
@@ -563,12 +580,14 @@ class Word2VecModel private[spark] (
563580
ind += 1
564581
}
565582

566-
wordList.zip(cosVec)
567-
.toSeq
568-
.sortBy(-_._2)
569-
.take(num + 1)
570-
.tail
571-
.toArray
583+
val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)
584+
585+
val filtered = wordOpt match {
586+
case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
587+
case None => scored
588+
}
589+
590+
filtered.take(num).toArray
572591
}
573592

574593
/**

mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.mllib.feature
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.mllib.linalg.Vectors
2122
import org.apache.spark.mllib.util.MLlibTestSparkContext
2223
import org.apache.spark.util.Utils
2324

@@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
6869
assert(syms(1)._1 == "japan")
6970
}
7071

72+
test("findSynonyms doesn't reject similar word vectors when called with a vector") {
73+
val num = 2
74+
val word2VecMap = Map(
75+
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
76+
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
77+
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
78+
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
79+
)
80+
val model = new Word2VecModel(word2VecMap)
81+
val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num)
82+
assert(syms.length == num)
83+
assert(syms(0)._1 == "china")
84+
assert(syms(1)._1 == "taiwan")
85+
}
86+
7187
test("model load / save") {
7288

7389
val word2VecMap = Map(

python/pyspark/mllib/feature.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,7 @@ def load(cls, sc, path):
544544

545545
@ignore_unicode_prefix
546546
class Word2Vec(object):
547-
"""
548-
Word2Vec creates vector representation of words in a text corpus.
547+
"""Word2Vec creates vector representation of words in a text corpus.
549548
The algorithm first constructs a vocabulary from the corpus
550549
and then learns vector representation of words in the vocabulary.
551550
The vector representation can be used as features in
@@ -567,13 +566,19 @@ class Word2Vec(object):
567566
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
568567
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)
569568
569+
Querying for synonyms of a word will not return that word:
570+
570571
>>> syms = model.findSynonyms("a", 2)
571572
>>> [s[0] for s in syms]
572573
[u'b', u'c']
574+
575+
But querying for synonyms of a vector may return the word whose
576+
representation is that vector:
577+
573578
>>> vec = model.transform("a")
574579
>>> syms = model.findSynonyms(vec, 2)
575580
>>> [s[0] for s in syms]
576-
[u'b', u'c']
581+
[u'a', u'b']
577582
578583
>>> import os, tempfile
579584
>>> path = tempfile.mkdtemp()
@@ -591,6 +596,7 @@ class Word2Vec(object):
591596
... pass
592597
593598
.. versionadded:: 1.2.0
599+
594600
"""
595601
def __init__(self):
596602
"""

0 commit comments

Comments
 (0)