Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
Cache instance variable
Browse files Browse the repository at this point in the history
  • Loading branch information
nzw0301 committed Sep 21, 2017
1 parent 8319861 commit ad2b291
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ private static Int2FloatOpenHashTable initSigmoidTable(final int maxSigmoid,
return sigmoidTable;
}

protected void initWordWeights(final int wordId){
protected void initWordWeights(final int wordId) {
for (int i = 0; i < dim; i++) {
inputWeights.put(wordId * dim + i, ((float) _rnd.nextDouble() - 0.5f) / dim);
}
Expand Down
36 changes: 21 additions & 15 deletions core/src/main/java/hivemall/unsupervised/SkipGramModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,44 @@ protected SkipGramModel(final int dim, final float startingLR, final long numTra
super(dim, startingLR, numTrainWords);
}

protected void onlineTrain(final int inWord, final int posWord,
@Nonnull final int[] negWords) {
protected void onlineTrain(final int inWord, final int posWord, @Nonnull final int[] negWords) {

final int vecDim = dim;

updateLearningRate();

if (!inputWeights.containsKey(inWord * dim)) {
if (!inputWeights.containsKey(inWord * vecDim)) {
initWordWeights(inWord);
}
}

float[] gradVec = new float[dim];
float[] gradVec = new float[vecDim];

// positive words
float gradient = grad(1.f, inWord, posWord) * lr;
for (int i = 0; i < dim; i++) {
gradVec[i] += gradient * contextWeights.get(posWord * dim + i);
contextWeights.put(posWord * dim + i, gradient * inputWeights.get(inWord * dim + i)
+ contextWeights.get(posWord * dim + i));
for (int i = 0; i < vecDim; i++) {
gradVec[i] += gradient * contextWeights.get(posWord * vecDim + i);
contextWeights.put(
posWord * vecDim + i,
gradient * inputWeights.get(inWord * vecDim + i)
+ contextWeights.get(posWord * vecDim + i));
}

// negative words
for (int negWord : negWords) {
gradient = grad(0.f, inWord, negWord) * lr;
for (int i = 0; i < dim; i++) {
gradVec[i] += gradient * contextWeights.get(negWord * dim + i);
contextWeights.put(negWord * dim + i, gradient * inputWeights.get(inWord * dim + i)
+ contextWeights.get(negWord * dim + i));
for (int i = 0; i < vecDim; i++) {
gradVec[i] += gradient * contextWeights.get(negWord * vecDim + i);
contextWeights.put(
negWord * vecDim + i,
gradient * inputWeights.get(inWord * vecDim + i)
+ contextWeights.get(negWord * vecDim + i));
}
}

// update inWord vector
for (int i = 0; i < dim; i++) {
inputWeights.put(inWord * dim + i, gradVec[i] + inputWeights.get(inWord * dim + i));
for (int i = 0; i < vecDim; i++) {
inputWeights.put(inWord * vecDim + i,
gradVec[i] + inputWeights.get(inWord * vecDim + i));
}

wordCount++;
Expand Down
57 changes: 31 additions & 26 deletions core/src/main/java/hivemall/unsupervised/Word2vecFeatureUDTF.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import javax.annotation.Nonnull;

public class Word2vecFeatureUDTF extends UDTFWithOptions {
private PRNG _rnd;
private PRNG rnd;

// skip-gram with negative sampling parameters
private int win;
Expand Down Expand Up @@ -101,7 +101,7 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(PrimitiveObjectInspectorFactory.writableStringObjectInspector));

this.previousNegativeSamplerId = -1;
this._rnd = RandomNumberGeneratorFactory.createPRNG(1001);
this.rnd = RandomNumberGeneratorFactory.createPRNG(1001);

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}
Expand All @@ -111,9 +111,9 @@ protected Options getOptions() {
Options opts = new Options();
opts.addOption("win", "window", true, "Range for context word [default: 5]");
opts.addOption("neg", "negative", true,
"The number of negative sampled words per word [default: 5]");
"The number of negative sampled words per word [default: 5]");
opts.addOption("iter", "iteration", true,
"The number of skip-gram per word. It is equivalent to the epoch of word2vec [default: 5]");
"The number of skip-gram per word. It is equivalent to the epoch of word2vec [default: 5]");
return opts;
}

Expand Down Expand Up @@ -164,10 +164,17 @@ public void process(Object[] args) throws HiveException {
}

private void forwardSample(@Nonnull final List<?> doc) throws HiveException {
final int numNegative = neg;
final PRNG _rnd = rnd;
final PrimitiveObjectInspector _wordOI = wordOI;
final Int2FloatOpenHashTable S_ = S;
final String[] aliasIndex2Word_ = aliasIndex2Word;
final String[] aliasIndex2OtherWord_ = aliasIndex2OtherWord;

final Text inWord = new Text();
final Text posWord = new Text();
final Text[] negWords = new Text[neg];
for (int i = 0; i < neg; i++) {
final Text[] negWords = new Text[numNegative];
for (int i = 0; i < numNegative; i++) {
negWords[i] = new Text();
}

Expand All @@ -180,10 +187,10 @@ private void forwardSample(@Nonnull final List<?> doc) throws HiveException {
int docLength = doc.size();
for (int inputWordPosition = 0; inputWordPosition < docLength; inputWordPosition++) {
String inputWord = PrimitiveObjectInspectorUtils.getString(doc.get(inputWordPosition),
wordOI);
_wordOI);
inWord.set(inputWord);

for(int i = 0; i < iter; i++){
for (int i = 0; i < iter; i++) {
int windowSize = _rnd.nextInt(win) + 1;

for (int contextPosition = inputWordPosition - windowSize; contextPosition < inputWordPosition
Expand All @@ -196,12 +203,25 @@ private void forwardSample(@Nonnull final List<?> doc) throws HiveException {
continue;

String contextWord = PrimitiveObjectInspectorUtils.getString(
doc.get(contextPosition), wordOI);
doc.get(contextPosition), _wordOI);
posWord.set(contextWord);

for (int d = 0; d < neg; d++) {
negWords[d].set(negativeSample(contextWord));
// negative sampling
for (int d = 0; d < numNegative; d++) {
String sample;
do {
int k = _rnd.nextInt(S_.size());

if (S_.get(k) > _rnd.nextDouble()) {
sample = aliasIndex2Word_[k];
} else {
sample = aliasIndex2OtherWord_[k];
}
} while (sample.equals(contextWord));

negWords[d].set(sample);
}

forward(forwardObjs);
}
}
Expand Down Expand Up @@ -230,21 +250,6 @@ private void parseNegativeTable(Object listObj) {
this.aliasIndex2OtherWord = aliasIndex2OtherWord;
}


/**
 * Draws one word from the negative-sampling tables, re-drawing until the
 * result differs from {@code excludeWord}.
 *
 * <p>Reads the instance fields {@code S}, {@code aliasIndex2Word},
 * {@code aliasIndex2OtherWord} and the generator {@code _rnd}; it does not
 * write to any of them directly.
 *
 * @param excludeWord word that must not be returned (compared with
 *        {@code String#equals})
 * @return a sampled word that is not equal to {@code excludeWord}
 */
private String negativeSample(final String excludeWord) {
String sample;
// NOTE(review): the loop only exits once a draw differs from excludeWord; if the
// tables can only ever yield excludeWord it would not terminate — confirm the
// callers guarantee at least one other word.
do {
// Pick a table slot; presumably keys of S are 0..S.size()-1 — TODO confirm.
int k = _rnd.nextInt(S.size());

// Two-way choice for slot k: S.get(k) acts as the threshold between the
// primary word and the "other" word. NOTE(review): this looks like an
// alias-method table lookup — confirm against parseNegativeTable.
if (S.get(k) > _rnd.nextDouble()) {
sample = aliasIndex2Word[k];
} else {
sample = aliasIndex2OtherWord[k];
}
} while (sample.equals(excludeWord));
return sample;
}

/**
 * No-op override: the body is empty, so nothing is done when this is invoked.
 * {@code HiveException} is declared in the signature but never thrown here.
 */
@Override
public void close() throws HiveException {}
}
2 changes: 1 addition & 1 deletion docs/gitbook/unsupervised/word2vec.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ group by k
drop table skipgram_features;
create table skipgram_features as
select
skipgram(k, negative_table, words, "-win 5 -neg 15")
skipgram(k, negative_table, words, "-win 5 -neg 15 -iter 2")
from(
select
r.k,
Expand Down

0 comments on commit ad2b291

Please sign in to comment.