[WIP] Adding unsupervised FastText to Gensim #1525
@@ -0,0 +1,244 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from types import GeneratorType
from copy import deepcopy
from six import string_types
import numpy as np
from numpy import dot, zeros, ones, vstack, outer, random, sum as np_sum, empty, float32 as REAL
from scipy.special import expit

from gensim.utils import call_on_class_only
from gensim.models.word2vec import Word2Vec
from gensim.models.wrappers.fasttext import FastTextKeyedVectors
from gensim.models.wrappers.fasttext import FastText as Ft_Wrapper

logger = logging.getLogger(__name__)

MAX_WORDS_IN_BATCH = 10000


def train_batch_cbow(model, sentences, alpha, work=None, neu1=None):
    result = 0
    for sentence in sentences:
        word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab and
                       model.wv.vocab[w].sample_int > model.random.rand() * 2**32]
        for pos, word in enumerate(word_vocabs):
            reduced_window = model.random.randint(model.window)  # `b` in the original word2vec code
            start = max(0, pos - model.window + reduced_window)
            window_pos = enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start)
            word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]

            word2_subwords = []

            for indices in word2_indices:
                word2_subwords += ['<' + model.wv.index2word[indices] + '>']
                word2_subwords += Ft_Wrapper.compute_ngrams(model.wv.index2word[indices], model.min_n, model.max_n)
            word2_subwords = list(set(word2_subwords))
Review comment: I thought we changed this to no longer be a set.
Reply: That's correct. I have pushed those changes now.

            subwords_indices = []
            for subword in word2_subwords:
                subwords_indices.append(model.wv.ngrams[subword])

            l1 = np_sum(model.wv.syn0_all[subwords_indices], axis=0)  # 1 x vector_size
            if subwords_indices and model.cbow_mean:
                l1 /= len(subwords_indices)

            train_cbow_pair(model, word, subwords_indices, l1, alpha)  # train on the sliding window for target word
        result += len(word_vocabs)
    return result


def train_cbow_pair(model, word, input_subword_indices, l1, alpha, learn_vectors=True, learn_hidden=True):
    neu1e = zeros(l1.shape)

    if model.hs:
        l2a = model.syn1[word.point]  # 2d matrix, codelen x layer1_size
        fa = expit(dot(l1, l2a.T))  # propagate hidden -> output
        ga = (1. - word.code - fa) * alpha  # vector of error gradients multiplied by the learning rate
        if learn_hidden:
            model.syn1[word.point] += outer(ga, l1)  # learn hidden -> output
        neu1e += dot(ga, l2a)  # save error

    if model.negative:
        # use this word (label = 1) + `negative` other random words not from this sentence (label = 0)
        word_indices = [word.index]  # through word index get all subwords indices (need to make the changes in code)
        while len(word_indices) < model.negative + 1:
            w = model.cum_table.searchsorted(model.random.randint(model.cum_table[-1]))
            if w != word.index:
                word_indices.append(w)
        l2b = model.syn1neg[word_indices]  # 2d matrix, k+1 x layer1_size
        fb = expit(dot(l1, l2b.T))  # propagate hidden -> output
        gb = (model.neg_labels - fb) * alpha  # vector of error gradients multiplied by the learning rate
        if learn_hidden:
            model.syn1neg[word_indices] += outer(gb, l1)  # learn hidden -> output
        neu1e += dot(gb, l2b)  # save error

    if learn_vectors:
        # learn input -> hidden, here for all words in the window separately
        if not model.cbow_mean and input_subword_indices:
            neu1e /= len(input_subword_indices)
        for i in input_subword_indices:
            model.wv.syn0_all[i] += neu1e * model.syn0_all_lockf[i]

    return neu1e

class FastText(Word2Vec):
    def __init__(
            self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, min_count=5,
            max_vocab_size=None, word_ngrams=1, loss='ns', sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
            negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6, sorted_vocab=1, bucket=2000000,
            trim_rule=None, batch_words=MAX_WORDS_IN_BATCH):

Review comment (on __init__): Should definitely be using …
Reply: Done

        self.load = call_on_class_only

        self.initialize_ngram_vectors()

Review comment (on initialize_ngram_vectors): Let's just use …
Reply: Done

        self.sg = int(sg)
        self.cum_table = None  # for negative sampling
        self.vector_size = int(size)
        self.layer1_size = int(size)
        if size % 4 != 0:
            logger.warning("consider setting layer size to a multiple of 4 for greater performance")
        self.alpha = float(alpha)
        self.min_alpha_yet_reached = float(alpha)  # To warn user if alpha increases
        self.window = int(window)
        self.max_vocab_size = max_vocab_size
        self.seed = seed
        self.random = random.RandomState(seed)
        self.min_count = min_count
        self.sample = sample
        self.workers = int(workers)
        self.min_alpha = float(min_alpha)
        self.hs = hs
        self.negative = negative
        self.cbow_mean = int(cbow_mean)
        self.hashfxn = hashfxn
        self.iter = iter
        self.null_word = null_word
        self.train_count = 0
        self.total_train_time = 0
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words
        self.model_trimmed_post_training = False

        self.bucket = bucket
        self.loss = loss  # should we keep this? -> we already have `hs`, `negative` -> although we don't have a mode for only `softmax`
        self.word_ngrams = word_ngrams
        self.min_n = min_n
        self.max_n = max_n
        if self.word_ngrams <= 1 and self.max_n == 0:
            self.bucket = 0

        self.wv.min_n = min_n
        self.wv.max_n = max_n

        if sentences is not None:
            if isinstance(sentences, GeneratorType):
                raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
            self.build_vocab(sentences, trim_rule=trim_rule)
            self.train(sentences, total_examples=self.corpus_count, epochs=self.iter,
                       start_alpha=self.alpha, end_alpha=self.min_alpha)
        else:
            if trim_rule is not None:
                logger.warning("The rule, if given, is only used to prune vocabulary during build_vocab() and is not stored as part of the model.")
                logger.warning("Model initialized without sentences. trim_rule provided, if any, will be ignored.")

    def train(self, sentences, total_examples=None, total_words=None,
              epochs=None, start_alpha=None, end_alpha=None,
              word_count=0, queue_factor=2, report_delay=1.0):
        self.neg_labels = []
        if self.negative > 0:
            # precompute negative labels optimization for pure-python training
            self.neg_labels = zeros(self.negative + 1)
            self.neg_labels[0] = 1.

        Word2Vec.train(self, sentences, total_examples=self.corpus_count, epochs=self.iter,
                       start_alpha=self.alpha, end_alpha=self.min_alpha)
        self.get_vocab_word_vecs()

    def initialize_ngram_vectors(self):
        self.wv = FastTextKeyedVectors()

    def __getitem__(self, word):
        return self.word_vec(word)

    def get_vocab_word_vecs(self):
        for w, v in self.wv.vocab.items():
            word_vec = np.zeros(self.wv.syn0_all.shape[1])
            ngrams = ['<' + w + '>']
            ngrams += Ft_Wrapper.compute_ngrams(w, self.min_n, self.max_n)
            ngrams = list(set(ngrams))
            ngram_weights = self.wv.syn0_all
            for ngram in ngrams:
                word_vec += ngram_weights[self.wv.ngrams[ngram]]
            word_vec /= len(ngrams)

            self.wv.syn0[v.index] = word_vec

    def word_vec(self, word, use_norm=False):
        if word in self.wv.vocab:
            if use_norm:
                return self.wv.syn0norm[self.wv.vocab[word].index]
            else:
                return self.wv.syn0[self.wv.vocab[word].index]
        else:
            logger.info("out of vocab")
            word_vec = np.zeros(self.wv.syn0_all.shape[1])
            ngrams = Ft_Wrapper.compute_ngrams(word, self.min_n, self.max_n)
            ngrams = [ng for ng in ngrams if ng in self.wv.ngrams]
            if use_norm:
                ngram_weights = self.wv.syn0_all_norm
            else:
                ngram_weights = self.wv.syn0_all
            for ngram in ngrams:
                word_vec += ngram_weights[self.wv.ngrams[ngram]]
            if word_vec.any():
                return word_vec / len(ngrams)
            else:  # No ngrams of the word are present in self.ngrams
                raise KeyError('all ngrams for word %s absent from model' % word)

Review discussion (on word_vec):
- Any reason for not reusing …
- Done
- I get this error if I try to access an OOV word: … Looks like …
- Here's the minimal example to reproduce my issue https://gist.github.com/rilut/31f41d5cf3da075d43cf7e4f2c993b76 Thanks 🙏
- Btw, since …
- @rilut Seems like I missed setting …
- @chinmayapancholi13 oh ok that's good. Thanks for your fix and hard work! I really appreciated it.
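To illustrate the intent of the OOV branch discussed above, here is a minimal, self-contained sketch (the n-grams, vectors, and helper name are hypothetical, not part of the PR): only the n-grams present in the model's ngram table are kept, and their vectors are averaged.

import numpy as np

# Hypothetical n-gram table standing in for model.wv.syn0_all / model.wv.ngrams
# (3 known n-grams, vector_size = 4); the values are made up for illustration.
ngram_vectors = np.array([[0.1, 0.2, 0.3, 0.4],
                          [0.4, 0.3, 0.2, 0.1],
                          [0.0, 1.0, 0.0, 1.0]])
ngram_index = {'<ni': 0, 'nig': 1, 'ght': 2}

def oov_vector(word_ngrams):
    # Keep only the n-grams the model has seen, then average their vectors,
    # mirroring the else branch of word_vec() above.
    known = [ng for ng in word_ngrams if ng in ngram_index]
    if not known:
        raise KeyError('all ngrams for word absent from model')
    return np.mean([ngram_vectors[ngram_index[ng]] for ng in known], axis=0)

print(oov_vector(['<ni', 'nig', 'igh', 'ght']))  # 'igh' is unknown and gets ignored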

    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, progress_per=10000, update=False):
        self.scan_vocab(sentences, progress_per=progress_per, trim_rule=trim_rule)  # initial survey
        self.scale_vocab(keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, update=update)  # trim by min_count & precalculate downsampling
        self.finalize_vocab(update=update)  # build tables & arrays
        # super(build_vocab, self, sentences, keep_raw_vocab=False, trim_rule=None, progress_per=10000, update=False)
        self.init_ngrams()

    def reset_ngram_weights(self):
        for ngram in self.wv.ngrams:
            self.wv.syn0_all[self.wv.ngrams[ngram]] = self.seeded_vector(ngram + str(self.seed))

    def init_ngrams(self):
        self.wv.ngrams = {}
        self.wv.syn0_all = empty((self.bucket + len(self.wv.vocab), self.vector_size), dtype=REAL)
        self.syn0_all_lockf = ones((self.bucket + len(self.wv.vocab), self.vector_size), dtype=REAL)

        all_ngrams = []
        for w, v in self.wv.vocab.items():
            all_ngrams += ['<' + w + '>']
            all_ngrams += Ft_Wrapper.compute_ngrams(w, self.min_n, self.max_n)
        all_ngrams = list(set(all_ngrams))
        self.num_ngram_vectors = len(all_ngrams)
        logger.info("Total number of ngrams in the vocab is %d", self.num_ngram_vectors)

        ngram_indices = []
        for i, ngram in enumerate(all_ngrams):
            ngram_hash = Ft_Wrapper.ft_hash(ngram)
            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
            self.wv.ngrams[ngram] = i

        self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)
        self.reset_ngram_weights()

    def _do_train_job(self, sentences, alpha, inits):
        work, neu1 = inits
        tally = 0
        # if self.sg:
        #     tally += train_batch_sg(self, sentences, alpha, work)
        # else:
        tally += train_batch_cbow(self, sentences, alpha, work, neu1)

        return tally, self._raw_word_count(sentences)
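For orientation, here is a hedged usage sketch of the class above, mirroring how the tests below exercise it (the toy sentences come from the test file; treat this as an illustration, not part of the diff):

from gensim.models.fasttext import FastText

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
    ['graph', 'minors', 'trees'],
]

# CBOW with hierarchical softmax, as in testTraining below
model = FastText(size=10, min_count=1, hs=1, negative=0)
model.build_vocab(sentences)
model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)

print(model['computer'])     # in-vocabulary word: precomputed vector from syn0
print(model['computation'])  # OOV word: averaged from matching character n-grams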
@@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import unittest
import os

import numpy as np

from gensim import utils
from gensim.models.fasttext import FastText as FT_gensim
from gensim.models.wrappers.fasttext import FastText as FT_wrapper

module_path = os.path.dirname(__file__)  # needed because sample data files are located in the same folder
datapath = lambda fname: os.path.join(module_path, 'test_data', fname)


class LeeCorpus(object):
    def __iter__(self):
        with open(datapath('lee_background.cor')) as f:
            for line in f:
                yield utils.simple_preprocess(line)

list_corpus = list(LeeCorpus())

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
    ['eps', 'user', 'interface', 'system'],
    ['system', 'human', 'system', 'eps'],
    ['user', 'response', 'time'],
    ['trees'],
    ['graph', 'trees'],
    ['graph', 'minors', 'trees'],
    ['graph', 'minors', 'survey']
]

class TestFastTextModel(unittest.TestCase):

    def models_equal(self, model, model2):
        self.assertEqual(len(model.wv.vocab), len(model2.wv.vocab))
        self.assertTrue(np.allclose(model.wv.syn0, model2.wv.syn0))
        self.assertTrue(np.allclose(model.wv.syn0_all, model2.wv.syn0_all))
        if model.hs:
            self.assertTrue(np.allclose(model.syn1, model2.syn1))
        if model.negative:
            self.assertTrue(np.allclose(model.syn1neg, model2.syn1neg))
        most_common_word = max(model.wv.vocab.items(), key=lambda item: item[1].count)[0]
        self.assertTrue(np.allclose(model[most_common_word], model2[most_common_word]))

    def testTraining(self):
        model = FT_gensim(size=2, min_count=1, hs=1, negative=0)
        model.build_vocab(sentences)

        self.assertTrue(model.wv.syn0_all.shape == (len(model.wv.ngrams), 2))
        self.assertTrue(model.syn1.shape == (len(model.wv.vocab), 2))

        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
        sims = model.most_similar('graph', topn=10)

        # test querying for "most similar" by vector
        graph_vector = model.wv.syn0norm[model.wv.vocab['graph'].index]
        sims2 = model.most_similar(positive=[graph_vector], topn=11)
        sims2 = [(w, sim) for w, sim in sims2 if w != 'graph']  # ignore 'graph' itself
        self.assertEqual(sims, sims2)

        # build vocab and train in one step; must be the same as above
        model2 = FT_gensim(sentences, size=2, min_count=1, hs=1, negative=0)
        self.models_equal(model, model2)

    def test_against_fasttext_wrapper(self, model_gensim, model_wrapper):
        sims_gensim = model_gensim.most_similar('war', topn=50)
        sims_wrapper = model_wrapper.most_similar('war', topn=50)
        self.assertEqual(sims_gensim, sims_wrapper)

    def test_cbow_hs(self):
        model_wrapper = FT_wrapper.train(ft_path='/home/chinmaya/GSOC/Gensim/fastText/fasttext',
            corpus_file=datapath('lee_background.cor'), output_file='/home/chinmaya/GSOC/Gensim/fasttext_out1', model='cbow', size=50,
            alpha=0.05, window=8, min_count=5, word_ngrams=1, loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1, threads=12)

        model_gensim = FT_gensim(size=50, sg=0, cbow_mean=1, alpha=0.05, window=8, hs=1, negative=0,
            min_count=5, iter=5, batch_words=1000, word_ngrams=1, sample=1e-3, min_n=3, max_n=6,
            sorted_vocab=1, workers=12, min_alpha=0.0001)

        model_gensim.build_vocab(list_corpus)
        orig0 = np.copy(model_gensim.wv.syn0[0])
        model_gensim.train(list_corpus, total_examples=model_gensim.corpus_count, epochs=model_gensim.iter)
        self.assertFalse((orig0 == model_gensim.wv.syn0[1]).all())  # vector should vary after training

        self.test_against_fasttext_wrapper(model_gensim, model_wrapper)

    def test_cbow_neg(self):
        model_wrapper = FT_wrapper.train(ft_path='/home/chinmaya/GSOC/Gensim/fastText/fasttext',
            corpus_file=datapath('lee_background.cor'), output_file='/home/chinmaya/GSOC/Gensim/fasttext_out1', model='cbow', size=50,
            alpha=0.05, window=8, min_count=5, word_ngrams=1, loss='ns', sample=1e-3, negative=15, iter=5, min_n=3, max_n=6, sorted_vocab=1, threads=12)

        model_gensim = FT_gensim(size=50, sg=0, cbow_mean=1, alpha=0.05, window=8, hs=0, negative=15,
            min_count=5, iter=5, batch_words=1000, word_ngrams=1, sample=1e-3, min_n=3, max_n=6,
            sorted_vocab=1, workers=12, min_alpha=0.0001)

        model_gensim.build_vocab(list_corpus)
        orig0 = np.copy(model_gensim.wv.syn0[0])
        model_gensim.train(list_corpus, total_examples=model_gensim.corpus_count, epochs=model_gensim.iter)
        self.assertFalse((orig0 == model_gensim.wv.syn0[1]).all())  # vector should vary after training

        self.test_against_fasttext_wrapper(model_gensim, model_wrapper)


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()
Review comment: This works for now, but ideally we'd like a cleaner solution to this later on. In general, I think the FastText wrapper (to load .bin files) and the FastText training code implemented here share a lot of common ground, both conceptually and code-wise. Once we have the correctness of the models verified, we'd be looking to refactor it somehow (maybe just inheriting from the wrapper? Completely removing train functionality from the wrapper and replacing it with native train functionality?). Any thoughts on this?

Reply: Agreed. The current implementation shares code with Gensim's FastText wrapper, so inheriting from the wrapper seems to be a good way of avoiding this redundancy. I think it would also be helpful to refactor the current Word2Vec implementation, since apart from using ngram vectors rather than word vectors at the time of backpropagation in fastText, the logic and code of the two models overlap significantly. Having one common parent class with the two models as its children could be a useful way to tackle this.

Reply: I agree that refactoring to avoid redundancy would be good. I'm not sure a common parent class is the way to go, though, since most of the redundant code lives in train_batch_cbow and train_batch_skipgram, which are both independently defined functions, not methods of the Word2Vec class.

Reply: Apart from these training functions, there is overlap between the two models in some other tasks as well. For instance, in our fastText implementation we first construct the vocabulary in the same way as Word2Vec does (i.e. by calling the scan_vocab, scale_vocab and finalize_vocab functions) and then handle all the "fastText-specific" things (like constructing the dictionary of ngrams and precomputing & storing the ngrams for each word in the vocab). These "fastText-specific" things could be handled at an earlier stage (e.g. within the scale_vocab or finalize_vocab functions), which would also let us optimize, e.g. by not iterating over the vocabulary several times unnecessarily.

Reply: The super method is useful in such situations, where the parent class implementation of a method needs to run along with whatever code is specific to the child class.
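To make that last point concrete, here is a minimal sketch (illustrative only, not the PR's final code) of how super could let build_vocab reuse the unchanged Word2Vec vocabulary pipeline and then add only the fastText-specific n-gram setup:

from gensim.models.word2vec import Word2Vec

class FastText(Word2Vec):
    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None,
                    progress_per=10000, update=False):
        # Reuse the parent's vocabulary pipeline
        # (scan_vocab -> scale_vocab -> finalize_vocab) ...
        super(FastText, self).build_vocab(
            sentences, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule,
            progress_per=progress_per, update=update)
        # ... then do only the fastText-specific work on top of it
        # (init_ngrams is the method added in this PR).
        self.init_ngrams()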