Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,24 +221,26 @@ class Word2VecModel private[ml] (
}

/**
* Find "num" number of words closest in similarity to the given word.
* Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word.
* Find "num" number of words closest in similarity to the given word, not
* including the word itself. Returns a dataframe with the words and the
* cosine similarities between the synonyms and the given word.
*/
@Since("1.5.0")
def findSynonyms(word: String, num: Int): DataFrame = {
findSynonyms(wordVectors.transform(word), num)
val spark = SparkSession.builder().getOrCreate()
spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
}

/**
* Find "num" number of words closest to similarity to the given vector representation
* of the word. Returns a dataframe with the words and the cosine similarities between the
* synonyms and the given word vector.
* Find "num" number of words whose vector representation most similar to the supplied vector.
* If the supplied vector is the vector representation of a word in the model's vocabulary,
* that word will be in the results. Returns a dataframe with the words and the cosine
* similarities between the synonyms and the given word vector.
*/
@Since("2.0.0")
def findSynonyms(word: Vector, num: Int): DataFrame = {
def findSynonyms(vec: Vector, num: Int): DataFrame = {
val spark = SparkSession.builder().getOrCreate()
spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity")
spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity")
}

/** @group setParam */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,34 @@ private[python] class Word2VecModelWrapper(model: Word2VecModel) {
rdd.rdd.map(model.transform)
}

/**
* Finds synonyms of a word; do not include the word itself in results.
* @param word a word
* @param num number of synonyms to find
* @return a list consisting of a list of words and a vector of cosine similarities
*/
def findSynonyms(word: String, num: Int): JList[Object] = {
val vec = transform(word)
findSynonyms(vec, num)
prepareResult(model.findSynonyms(word, num))
}

/**
* Finds words similar to the the vector representation of a word without
* filtering results.
* @param vector a vector
* @param num number of synonyms to find
* @return a list consisting of a list of words and a vector of cosine similarities
*/
def findSynonyms(vector: Vector, num: Int): JList[Object] = {
val result = model.findSynonyms(vector, num)
prepareResult(model.findSynonyms(vector, num))
}

private def prepareResult(result: Array[(String, Double)]) = {
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
List(words, similarity).map(_.asInstanceOf[Object]).asJava
}


def getVectors: JMap[String, JList[Float]] = {
model.getVectors.map { case (k, v) =>
(k, v.toList.asJava)
Expand Down
37 changes: 28 additions & 9 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,25 +518,42 @@ class Word2VecModel private[spark] (
}

/**
* Find synonyms of a word
* Find synonyms of a word; do not include the word itself in results.
* @param word a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
@Since("1.1.0")
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
val vector = transform(word)
findSynonyms(vector, num)
findSynonyms(vector, num, Some(word))
}

/**
* Find synonyms of the vector representation of a word
* Find synonyms of the vector representation of a word, possibly
* including any words in the model vocabulary whose vector respresentation
* is the supplied vector.
* @param vector vector representation of a word
* @param num number of synonyms to find
* @return array of (word, cosineSimilarity)
*/
@Since("1.1.0")
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
findSynonyms(vector, num, None)
}

/**
* Find synonyms of the vector representation of a word, rejecting
* words identical to the value of wordOpt, if one is supplied.
* @param vector vector representation of a word
* @param num number of synonyms to find
* @param wordOpt optionally, a word to reject from the results list
* @return array of (word, cosineSimilarity)
*/
private def findSynonyms(
vector: Vector,
num: Int,
wordOpt: Option[String]): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
// TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
Expand All @@ -563,12 +580,14 @@ class Word2VecModel private[spark] (
ind += 1
}

wordList.zip(cosVec)
.toSeq
.sortBy(-_._2)
.take(num + 1)
.tail
.toArray
val scored = wordList.zip(cosVec).toSeq.sortBy(-_._2)

val filtered = wordOpt match {
case Some(w) => scored.take(num + 1).filter(tup => w != tup._1)
case None => scored
}

filtered.take(num).toArray
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.mllib.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -68,6 +69,21 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(syms(1)._1 == "japan")
}

test("findSynonyms doesn't reject similar word vectors when called with a vector") {
val num = 2
val word2VecMap = Map(
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
)
val model = new Word2VecModel(word2VecMap)
val syms = model.findSynonyms(Vectors.dense(Array(0.52, 0.5, 0.5, 0.5)), num)
assert(syms.length == num)
assert(syms(0)._1 == "china")
assert(syms(1)._1 == "taiwan")
}

test("model load / save") {

val word2VecMap = Map(
Expand Down
12 changes: 9 additions & 3 deletions python/pyspark/mllib/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,7 @@ def load(cls, sc, path):

@ignore_unicode_prefix
class Word2Vec(object):
"""
Word2Vec creates vector representation of words in a text corpus.
"""Word2Vec creates vector representation of words in a text corpus.
The algorithm first constructs a vocabulary from the corpus
and then learns vector representation of words in the vocabulary.
The vector representation can be used as features in
Expand All @@ -567,13 +566,19 @@ class Word2Vec(object):
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
>>> model = Word2Vec().setVectorSize(10).setSeed(42).fit(doc)

Querying for synonyms of a word will not return that word:

>>> syms = model.findSynonyms("a", 2)
>>> [s[0] for s in syms]
[u'b', u'c']

But querying for synonyms of a vector may return the word whose
representation is that vector:

>>> vec = model.transform("a")
>>> syms = model.findSynonyms(vec, 2)
>>> [s[0] for s in syms]
[u'b', u'c']
[u'a', u'b']

>>> import os, tempfile
>>> path = tempfile.mkdtemp()
Expand All @@ -591,6 +596,7 @@ class Word2Vec(object):
... pass

.. versionadded:: 1.2.0

"""
def __init__(self):
"""
Expand Down