Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

[WIP][HIVEMALL-118] word2vec #116

Open
wants to merge 44 commits into
base: master
Choose a base branch
from
Open

[WIP][HIVEMALL-118] word2vec #116

wants to merge 44 commits into from

Conversation

nzw0301
Copy link
Member

@nzw0301 nzw0301 commented Sep 21, 2017

What changes were proposed in this pull request?

Add new algorithm: skip-gram with negative sampling (a.k.a word2vec)

What type of PR is it?

Feature

What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-118

How was this patch tested?

manual tests on EMR

To train word2vec, I used wikipedia dataset, preprocessed by this perl script.

I evaluated word vector by https://github.com/kudkudak/word-embeddings-benchmarks .

from six import iteritems
from web.datasets.similarity import fetch_MEN, fetch_WS353, fetch_SimLex999, fetch_RW, fetch_RG65, fetch_MTurk
from web.datasets.analogy import fetch_msr_analogy, fetch_google_analogy, fetch_semeval_2012_2, fetch_wordrep
import gensim
from gensim.models.word2vec import Word2Vec, LineSentence

from web.embeddings import load_embedding
from web.evaluate import evaluate_similarity, evaluate_analogy

sim_tasks = {
    "MEN      ": fetch_MEN(),
    "WS353    ": fetch_WS353(),
    "SIMLEX999": fetch_SimLex999(),
    "RW       ": fetch_RW(),
    "RG       ": fetch_RG65(),
    "MTurk    ": fetch_MTurk()
}
analogy_tasks = {
    "google": fetch_google_analogy(),
    "msr   ": fetch_msr_analogy()
}

docs = LineSentence('PATH/TO/PREPROCESSED_DATA')
model = Word2Vec(docs, size=100, window=5, min_count=15, workers=8, negative=15, hs=0, sg=1, iter=1)
model.wv.save_word2vec_format('./gensim_sg.txt')

gensim = load_embedding('./gensim_sg.txt', 'word2vec')

for name, data in iteritems(sim_tasks):
    print("Spearman correlation of scores on {} {}".format(name, evaluate_similarity(gensim, data.X, data.y)))
Spearman correlation of scores on MEN       0.6483416401993833
Spearman correlation of scores on WS353     0.6169418277184877
Spearman correlation of scores on SIMLEX999 0.3070155939988943
Spearman correlation of scores on RW        0.28548732030155277
Spearman correlation of scores on RG        0.6762247194247315
Spearman correlation of scores on MTurk     0.6471504497920156

CBoW model of hivemall

Spearman correlation of scores on MEN       0.6247965194705783
Spearman correlation of scores on WS353     0.6225747519511903
Spearman correlation of scores on SIMLEX999 0.2985588069793148
Spearman correlation of scores on RW        0.27686018664704454
Spearman correlation of scores on RG        0.6528832630934683
Spearman correlation of scores on MTurk     0.6307218892934624

Skip-gram of hivemall

Spearman correlation of scores on MEN       0.6202415358400425
Spearman correlation of scores on WS353     0.6235303875587551
Spearman correlation of scores on SIMLEX999 0.2983910352562464
Spearman correlation of scores on RW        0.2939926699533969
Spearman correlation of scores on RG        0.6782791172666216
Spearman correlation of scores on MTurk     0.6344295663665642

CBoW of hivemall when the number of reducer for training is 4

Spearman correlation of scores on MEN       0.5392243963768429
Spearman correlation of scores on WS353     0.546436543545682
Spearman correlation of scores on SIMLEX999 0.2529742987287988
Spearman correlation of scores on RW        0.28611116074856136
Spearman correlation of scores on RG        0.5040049854449996
Spearman correlation of scores on MTurk     0.5953150581437371

How to use this feature?

please see word2vec.md

Checklist

  • Did you apply source code formatter, i.e., mvn formatter:format, for your commit?

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.7%) to 40.18% when pulling 7014e85 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.7%) to 40.185% when pulling 7a5fd54 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.7%) to 40.179% when pulling c224912 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.165% when pulling 8319861 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.156% when pulling ad2b291 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 21, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.149% when pulling 39d1123 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 22, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.138% when pulling e507561 on nzw0301:skipgram into c2b9578 on apache:master.

* specific language governing permissions and limitations
* under the License.
*/
package hivemall.unsupervised;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move package from hivemall.unsupervised to hivemall.embedding.

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

public abstract class AbstractWord2vecModel {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename Word2vec to Word2Vec as seen in spark.

@coveralls
Copy link

coveralls commented Sep 22, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.137% when pulling bf5d927 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 22, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.144% when pulling a3ccaa8 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 22, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.147% when pulling c7cba82 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 22, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.077% when pulling bbdb561 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 25, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.057% when pulling 7a2f4db on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 25, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.029% when pulling e094552 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 25, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.025% when pulling aede5ec on nzw0301:skipgram into c2b9578 on apache:master.

}
}

protected static float sigmoid(final float v, final int MAX_SIGMOID,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to use Constants for argument: final int MAX_SIGMOID, final int SIGMOID_TABLE_SIZE


public abstract class AbstractWord2VecModel {
// cached sigmoid function parameters
protected final int MAX_SIGMOID = 6;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Constants should be static final.

@coveralls
Copy link

coveralls commented Sep 25, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.028% when pulling 4abdb8f on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 25, 2017

Coverage Status

Coverage decreased (-0.9%) to 39.993% when pulling 4abdb8f on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.05% when pulling 8a42adf on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.05% when pulling d1b4270 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.049% when pulling f19d732 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.149% when pulling 2b66e5e on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.8%) to 40.108% when pulling f0abd4f on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 26, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.068% when pulling c340038 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 27, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.065% when pulling af5b5be on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 27, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.065% when pulling d12ba32 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 27, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.069% when pulling 2415589 on nzw0301:skipgram into c2b9578 on apache:master.

@coveralls
Copy link

coveralls commented Sep 27, 2017

Coverage Status

Coverage decreased (-0.9%) to 40.056% when pulling da564b8 on nzw0301:skipgram into c2b9578 on apache:master.

protected static final int SIGMOID_TABLE_SIZE = 1000;
protected float[] sigmoidTable;


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove unnecessary blank line


@Nonnegative
protected int dim;
protected int win;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nonnegative for each variable (win, neg, iter).

protected Int2FloatOpenHashTable S;
protected int[] aliasWordId;

protected AbstractWord2VecModel(final int dim, final int win, final int neg, final int iter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @Nonnegative for each constructor argument and caller methods.

}
}

protected static float sigmoid(final float v, final float[] sigmoidTable) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nonnull for sigmoidTable

}

protected void updateLearningRate() {
// TODO: valid lr?
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this TODO comment and blank lines.

import java.util.List;

public final class CBoWModel extends AbstractWord2VecModel {
protected CBoWModel(final int dim, final int win, final int neg, final int iter,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a blank line before constructor.


updateLearningRate();

int docLength = doc.length;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

final int docLength

opts.addOption("win", "window", true, "Context window size [default: 5]");
opts.addOption("neg", "negative", true,
"The number of negative sampled words per word [default: 5]");
opts.addOption("iter", "iteration", true, "The number of iterations [default: 5]");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consistent naming "iters", "iterations" as seen in SLIM.

opts.addOption("model", "modelName", true,
"The model name of word2vec: skipgram or cbow [default: skipgram]");
opts.addOption(
"lr",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consistent naming eta0, learningRate for the initial learning rate.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see.
Does longOpt remain learningRate or remove this field?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remain learningRate for longOpt and use eta0 for initialLearningRate.

}

modelName = cl.getOptionValue("model", modelName);
if (!(modelName.equals("skipgram") || modelName.equals("cbow"))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"skipgram".equals(modelName) is null safe.

@myui
Copy link
Member

myui commented Sep 28, 2017

What type of PR is it? => Improvement should be Feature.

@nzw0301
Copy link
Member Author

nzw0301 commented Sep 28, 2017

@myui I resolved conflicts.

@coveralls
Copy link

Coverage Status

Coverage decreased (-0.6%) to 40.505% when pulling 8696f5f on nzw0301:skipgram into 1e42387 on apache:master.

@coveralls
Copy link

coveralls commented Sep 28, 2017

Coverage Status

Coverage decreased (-0.6%) to 40.508% when pulling 0b163fa on nzw0301:skipgram into 1e42387 on apache:master.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants