From ad2b2911b5a6ebb7b43b4981bf5ff4424425a292 Mon Sep 17 00:00:00 2001 From: Kento NOZAWA Date: Thu, 21 Sep 2017 22:05:45 +0900 Subject: [PATCH] Cache instance variable --- .../unsupervised/AbstractWord2vecModel.java | 2 +- .../hivemall/unsupervised/SkipGramModel.java | 36 +++++++----- .../unsupervised/Word2vecFeatureUDTF.java | 57 ++++++++++--------- docs/gitbook/unsupervised/word2vec.md | 2 +- 4 files changed, 54 insertions(+), 43 deletions(-) diff --git a/core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java b/core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java index 04a9951c2..a4c6e33ae 100644 --- a/core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java +++ b/core/src/main/java/hivemall/unsupervised/AbstractWord2vecModel.java @@ -73,7 +73,7 @@ private static Int2FloatOpenHashTable initSigmoidTable(final int maxSigmoid, return sigmoidTable; } - protected void initWordWeights(final int wordId){ + protected void initWordWeights(final int wordId) { for (int i = 0; i < dim; i++) { inputWeights.put(wordId * dim + i, ((float) _rnd.nextDouble() - 0.5f) / dim); } diff --git a/core/src/main/java/hivemall/unsupervised/SkipGramModel.java b/core/src/main/java/hivemall/unsupervised/SkipGramModel.java index 109fd878e..b0a385296 100644 --- a/core/src/main/java/hivemall/unsupervised/SkipGramModel.java +++ b/core/src/main/java/hivemall/unsupervised/SkipGramModel.java @@ -27,38 +27,44 @@ protected SkipGramModel(final int dim, final float startingLR, final long numTra super(dim, startingLR, numTrainWords); } - protected void onlineTrain(final int inWord, final int posWord, - @Nonnull final int[] negWords) { + protected void onlineTrain(final int inWord, final int posWord, @Nonnull final int[] negWords) { + + final int vecDim = dim; updateLearningRate(); - if (!inputWeights.containsKey(inWord * dim)) { + if (!inputWeights.containsKey(inWord * vecDim)) { initWordWeights(inWord); - } + } - float[] gradVec = new float[dim]; + float[] gradVec = new float[vecDim]; // positive words float gradient = grad(1.f, inWord, posWord) * lr; - for (int i = 0; i < dim; i++) { - gradVec[i] += gradient * contextWeights.get(posWord * dim + i); - contextWeights.put(posWord * dim + i, gradient * inputWeights.get(inWord * dim + i) - + contextWeights.get(posWord * dim + i)); + for (int i = 0; i < vecDim; i++) { + gradVec[i] += gradient * contextWeights.get(posWord * vecDim + i); + contextWeights.put( + posWord * vecDim + i, + gradient * inputWeights.get(inWord * vecDim + i) + + contextWeights.get(posWord * vecDim + i)); } // negative words for (int negWord : negWords) { gradient = grad(0.f, inWord, negWord) * lr; - for (int i = 0; i < dim; i++) { - gradVec[i] += gradient * contextWeights.get(negWord * dim + i); - contextWeights.put(negWord * dim + i, gradient * inputWeights.get(inWord * dim + i) - + contextWeights.get(negWord * dim + i)); + for (int i = 0; i < vecDim; i++) { + gradVec[i] += gradient * contextWeights.get(negWord * vecDim + i); + contextWeights.put( + negWord * vecDim + i, + gradient * inputWeights.get(inWord * vecDim + i) + + contextWeights.get(negWord * vecDim + i)); } } // update inWord vector - for (int i = 0; i < dim; i++) { - inputWeights.put(inWord * dim + i, gradVec[i] + inputWeights.get(inWord * dim + i)); + for (int i = 0; i < vecDim; i++) { + inputWeights.put(inWord * vecDim + i, + gradVec[i] + inputWeights.get(inWord * vecDim + i)); } wordCount++; diff --git a/core/src/main/java/hivemall/unsupervised/Word2vecFeatureUDTF.java b/core/src/main/java/hivemall/unsupervised/Word2vecFeatureUDTF.java index 5cc0e155c..b827230a2 100644 --- a/core/src/main/java/hivemall/unsupervised/Word2vecFeatureUDTF.java +++ b/core/src/main/java/hivemall/unsupervised/Word2vecFeatureUDTF.java @@ -46,7 +46,7 @@ import javax.annotation.Nonnull; public class Word2vecFeatureUDTF extends UDTFWithOptions { - private PRNG _rnd; + private PRNG rnd; // skip-gram with negative sampling parameters private int win; @@ -101,7 +101,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector)); this.previousNegativeSamplerId = -1; - this._rnd = RandomNumberGeneratorFactory.createPRNG(1001); + this.rnd = RandomNumberGeneratorFactory.createPRNG(1001); return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs); } @@ -111,9 +111,9 @@ protected Options getOptions() { Options opts = new Options(); opts.addOption("win", "window", true, "Range for context word [default: 5]"); opts.addOption("neg", "negative", true, - "The number of negative sampled words per word [default: 5]"); + "The number of negative sampled words per word [default: 5]"); opts.addOption("iter", "iteration", true, - "The number of skip-gram per word. It is equivalent to the epoch of word2vec [default: 5]"); + "The number of skip-gram per word. It is equivalent to the epoch of word2vec [default: 5]"); return opts; } @@ -164,10 +164,17 @@ public void process(Object[] args) throws HiveException { } private void forwardSample(@Nonnull final List doc) throws HiveException { + final int numNegative = neg; + final PRNG _rnd = rnd; + final PrimitiveObjectInspector _wordOI = wordOI; + final Int2FloatOpenHashTable S_ = S; + final String[] aliasIndex2Word_ = aliasIndex2Word; + final String[] aliasIndex2OtherWord_ = aliasIndex2OtherWord; + final Text inWord = new Text(); final Text posWord = new Text(); - final Text[] negWords = new Text[neg]; - for (int i = 0; i < neg; i++) { + final Text[] negWords = new Text[numNegative]; + for (int i = 0; i < numNegative; i++) { negWords[i] = new Text(); } @@ -180,10 +187,10 @@ private void forwardSample(@Nonnull final List doc) throws HiveException { int docLength = doc.size(); for (int inputWordPosition = 0; inputWordPosition < docLength; inputWordPosition++) { String inputWord = PrimitiveObjectInspectorUtils.getString(doc.get(inputWordPosition), - wordOI); + _wordOI); inWord.set(inputWord); - for(int i = 0; i < iter; i++){ + for (int i = 0; i < iter; i++) { int windowSize = _rnd.nextInt(win) + 1; for (int contextPosition = inputWordPosition - windowSize; contextPosition < inputWordPosition @@ -196,12 +203,25 @@ private void forwardSample(@Nonnull final List doc) throws HiveException { continue; String contextWord = PrimitiveObjectInspectorUtils.getString( - doc.get(contextPosition), wordOI); + doc.get(contextPosition), _wordOI); posWord.set(contextWord); - for (int d = 0; d < neg; d++) { - negWords[d].set(negativeSample(contextWord)); + // negative sampling + for (int d = 0; d < numNegative; d++) { + String sample; + do { + int k = _rnd.nextInt(S_.size()); + + if (S_.get(k) > _rnd.nextDouble()) { + sample = aliasIndex2Word_[k]; + } else { + sample = aliasIndex2OtherWord_[k]; + } + } while (sample.equals(contextWord)); + + negWords[d].set(sample); } + forward(forwardObjs); } } @@ -230,21 +250,6 @@ private void parseNegativeTable(Object listObj) { this.aliasIndex2OtherWord = aliasIndex2OtherWord; } - - private String negativeSample(final String excludeWord) { - String sample; - do { - int k = _rnd.nextInt(S.size()); - - if (S.get(k) > _rnd.nextDouble()) { - sample = aliasIndex2Word[k]; - } else { - sample = aliasIndex2OtherWord[k]; - } - } while (sample.equals(excludeWord)); - return sample; - } - @Override public void close() throws HiveException {} } diff --git a/docs/gitbook/unsupervised/word2vec.md b/docs/gitbook/unsupervised/word2vec.md index 02285e2e8..8cfb8129b 100644 --- a/docs/gitbook/unsupervised/word2vec.md +++ b/docs/gitbook/unsupervised/word2vec.md @@ -200,7 +200,7 @@ group by k drop table skipgram_features; create table skipgram_features as select - skipgram(k, negative_table, words, "-win 5 -neg 15") + skipgram(k, negative_table, words, "-win 5 -neg 15 -iter 2") from( select r.k,