Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,21 +139,21 @@ class Word2Vec extends Serializable with Logging {
/** context words from [-window, window] */
private val window = 5

- private var trainWordsCount = 0
+ private var trainWordsCount = 0L
private var vocabSize = 0
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]

private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
.filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)

Expand All @@ -164,7 +164,7 @@ class Word2Vec extends Serializable with Logging {
trainWordsCount += vocab(a).cn
a += 1
}
- logInfo("trainWordsCount = " + trainWordsCount)
+ logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
}

private def createExpTable(): Array[Float] = {
Expand Down Expand Up @@ -313,7 +313,7 @@ class Word2Vec extends Serializable with Logging {
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
- val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
+ val model = iter.foldLeft((syn0Global, syn1Global, 0L, 0L)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount
Expand Down