Skip to content

Commit 7bd2564

Browse files
hhbyyhjkbradley
authored andcommitted
[SPARK-12685][MLLIB][BACKPORT TO 1.4] word2vec trainWordsCount gets overflow
jira: https://issues.apache.org/jira/browse/SPARK-12685 master PR: #10627 the log of word2vec reports trainWordsCount = -785727483 during computation over a large dataset. Update the priority as it will affect the computation process. alpha = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) Author: Yuhao Yang <hhbyyh@gmail.com> Closes #10721 from hhbyyh/branch-1.4.
1 parent 0832530 commit 7bd2564

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,21 +139,21 @@ class Word2Vec extends Serializable with Logging {
139139
/** context words from [-window, window] */
140140
private val window = 5
141141

142-
private var trainWordsCount = 0
142+
private var trainWordsCount = 0L
143143
private var vocabSize = 0
144144
@transient private var vocab: Array[VocabWord] = null
145145
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
146146

147147
private def learnVocab(words: RDD[String]): Unit = {
148148
vocab = words.map(w => (w, 1))
149149
.reduceByKey(_ + _)
150+
.filter(_._2 >= minCount)
150151
.map(x => VocabWord(
151152
x._1,
152153
x._2,
153154
new Array[Int](MAX_CODE_LENGTH),
154155
new Array[Int](MAX_CODE_LENGTH),
155156
0))
156-
.filter(_.cn >= minCount)
157157
.collect()
158158
.sortWith((a, b) => a.cn > b.cn)
159159

@@ -164,7 +164,7 @@ class Word2Vec extends Serializable with Logging {
164164
trainWordsCount += vocab(a).cn
165165
a += 1
166166
}
167-
logInfo("trainWordsCount = " + trainWordsCount)
167+
logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
168168
}
169169

170170
private def createExpTable(): Array[Float] = {
@@ -313,7 +313,7 @@ class Word2Vec extends Serializable with Logging {
313313
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
314314
val syn0Modify = new Array[Int](vocabSize)
315315
val syn1Modify = new Array[Int](vocabSize)
316-
val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
316+
val model = iter.foldLeft((syn0Global, syn1Global, 0L, 0L)) {
317317
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
318318
var lwc = lastWordCount
319319
var wc = wordCount

0 commit comments

Comments
 (0)