[WIP] Adding unsupervised FastText to Gensim #1525

Merged: 35 commits, Sep 19, 2017

Commits
a815c84
added initial code for CBOW
chinmayapancholi13 Aug 8, 2017
102c14a
updated unit tests for fasttext
chinmayapancholi13 Aug 13, 2017
4c449df
corrected use of matrix and precomputed ngrams for vocab words
chinmayapancholi13 Aug 13, 2017
f49df54
added EOS token in 'LineSentence' class
chinmayapancholi13 Aug 15, 2017
1fcb8fa
added skipgram training code
chinmayapancholi13 Aug 16, 2017
82fda3c
updated unit tests for fasttext
chinmayapancholi13 Aug 16, 2017
cd59034
seeded 'np.random' with 'self.seed'
chinmayapancholi13 Aug 16, 2017
353f7a8
added test for persistence
chinmayapancholi13 Aug 17, 2017
569a026
updated seeding numpy obj
chinmayapancholi13 Aug 17, 2017
c228b8d
updated (unclean) fasttext code for review
chinmayapancholi13 Aug 23, 2017
29c627f
updated fasttext tutorial notebook
chinmayapancholi13 Aug 24, 2017
acbfdf2
added 'save' and 'load_fasttext_format' functions
chinmayapancholi13 Aug 24, 2017
cb7a2ad
updated unit tests for fasttext
chinmayapancholi13 Aug 24, 2017
5a18297
cleaned main fasttext code
chinmayapancholi13 Aug 25, 2017
4b98722
updated unittests
chinmayapancholi13 Aug 25, 2017
cf1f3e0
removed EOS token from LineSentence
chinmayapancholi13 Aug 25, 2017
d986242
fixed flake8 errors
chinmayapancholi13 Aug 25, 2017
bce17ff
[WIP] added online learning
chinmayapancholi13 Aug 25, 2017
cb84001
added tests for online learning
chinmayapancholi13 Aug 25, 2017
fbe8bdc
Merge branch 'develop' of https://github.com/RaRe-Technologies/gensim…
chinmayapancholi13 Aug 25, 2017
58c673a
flake8 fixes
chinmayapancholi13 Aug 25, 2017
893ef76
refactored code to remove redundancy
chinmayapancholi13 Aug 27, 2017
e12f6c0
reusing 'word_vec' from 'FastTextKeyedVectors'
chinmayapancholi13 Aug 27, 2017
39d14bd
flake8 fixes
chinmayapancholi13 Aug 27, 2017
d3ec5a8
split 'syn0_all' into 'syn0_vocab' and 'syn0_ngrams'
chinmayapancholi13 Aug 29, 2017
0854622
removed 'init_wv' param from Word2Vec
chinmayapancholi13 Aug 29, 2017
904882a
updated unittests
chinmayapancholi13 Aug 30, 2017
a9e7d03
flake8 errors fixed
chinmayapancholi13 Aug 30, 2017
ec58512
fixed oov word_vec
chinmayapancholi13 Sep 7, 2017
2ed7d31
removed merge conflicts
chinmayapancholi13 Sep 7, 2017
daace4a
updated test_training unittest
chinmayapancholi13 Sep 7, 2017
58c531b
Merge branch 'develop' into fasttext_gensim
menshikh-iv Sep 18, 2017
3ffa103
Fix broken merge
menshikh-iv Sep 18, 2017
2b0583b
useless change (need to re-run Appveyour)
menshikh-iv Sep 19, 2017
55d731a
Add skipIf for Appveyor x32 (avoid memory error)
menshikh-iv Sep 19, 2017

gensim/models/fasttext.py (new file: 244 additions, 0 deletions)
@@ -0,0 +1,244 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging

from types import GeneratorType
from copy import deepcopy
from six import string_types
import numpy as np
from numpy import dot, zeros, ones, vstack, outer, random, sum as np_sum, empty, float32 as REAL
from scipy.special import expit

from gensim.utils import call_on_class_only
from gensim.models.word2vec import Word2Vec
from gensim.models.wrappers.fasttext import FastTextKeyedVectors
from gensim.models.wrappers.fasttext import FastText as Ft_Wrapper

logger = logging.getLogger(__name__)

MAX_WORDS_IN_BATCH = 10000


def train_batch_cbow(model, sentences, alpha, work=None, neu1=None):
    result = 0
    for sentence in sentences:
        word_vocabs = [model.wv.vocab[w] for w in sentence if w in model.wv.vocab and
                       model.wv.vocab[w].sample_int > model.random.rand() * 2**32]
        for pos, word in enumerate(word_vocabs):
            reduced_window = model.random.randint(model.window)  # `b` in the original word2vec code
            start = max(0, pos - model.window + reduced_window)
            window_pos = enumerate(word_vocabs[start:(pos + model.window + 1 - reduced_window)], start)
            word2_indices = [word2.index for pos2, word2 in window_pos if (word2 is not None and pos2 != pos)]

            word2_subwords = []

            for indices in word2_indices:
                word2_subwords += ['<' + model.wv.index2word[indices] + '>']
                word2_subwords += Ft_Wrapper.compute_ngrams(model.wv.index2word[indices], model.min_n, model.max_n)
Review comment (jayantj, Contributor, Aug 11, 2017):

This works for now, but ideally we'd like a cleaner solution to this later on. In general, I think the FastText wrapper (to load .bin files) and the FastText training code implemented here share a lot of common ground, both conceptually and code-wise. Once we have verified the correctness of the models, we'd be looking to refactor it somehow (maybe just inheriting from the wrapper? Completely removing train functionality from the wrapper and replacing it with native train functionality?). Any thoughts on this?

Reply (chinmayapancholi13, Contributor, Author):

Agreed. The current implementation shares code with Gensim's FastText wrapper, so inheriting from the wrapper seems to be a good way to avoid this redundancy.
I think it would also be helpful to refactor the current Word2Vec implementation since, apart from FastText using ngram vectors rather than word vectors during backpropagation, the logic and code of the two models overlap significantly. Having one common parent class with the two models as its children could be a useful way to tackle this.

Reply (Contributor):

I agree that refactoring to avoid redundancy would be good. I'm not sure a common parent class is the way to go though, since most of the redundant code is in train_batch_cbow and train_batch_skipgram, which are both independently defined functions rather than methods of the Word2Vec class.

Reply (chinmayapancholi13, Contributor, Author):

Apart from these training functions, there is overlap between the two models in some other tasks as well. For instance, in our FastText implementation we first construct the vocabulary the same way as Word2Vec (i.e. by calling the scan_vocab, scale_vocab and finalize_vocab functions) and then handle all the "fasttext-specific" things (like constructing the dictionary of ngrams and precomputing and storing the ngrams for each word in the vocab). These "fasttext-specific" steps could be handled at an earlier stage (e.g. within scale_vocab or finalize_vocab), which would also help us optimize things, e.g. by avoiding several unnecessary passes over the vocabulary.

Reply (Contributor):

The super method is useful in such situations - where the parent class implementation of the method needs to be run along with whatever code is specific to the child class.
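
A minimal sketch of the pattern being discussed, using names from this PR's code (treating build_vocab as the override point is an illustrative assumption, not part of the patch):

    class FastText(Word2Vec):
        def build_vocab(self, sentences, **kwargs):
            # run the shared Word2Vec vocabulary-building steps first
            super(FastText, self).build_vocab(sentences, **kwargs)
            # then do the FastText-specific work: the ngram dict and ngram vectors
            self.init_ngrams()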

            word2_subwords = list(set(word2_subwords))
Review comment (Contributor):

I thought we changed this to no longer be a set.

Reply (chinmayapancholi13, Contributor, Author):

That's correct. I have pushed those changes now.


            subwords_indices = []
            for subword in word2_subwords:
                subwords_indices.append(model.wv.ngrams[subword])

            l1 = np_sum(model.wv.syn0_all[subwords_indices], axis=0)  # 1 x vector_size
            if subwords_indices and model.cbow_mean:
                l1 /= len(subwords_indices)

            train_cbow_pair(model, word, subwords_indices, l1, alpha)  # train on the sliding window for target word
        result += len(word_vocabs)
    return result
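
A small numeric sketch of the input composition performed above (toy arrays, not the model's real matrices): with cbow_mean=1, l1 is the element-wise mean of the syn0_all rows selected by the subword indices of the context words.

    import numpy as np

    syn0_all = np.array([[1., 0.], [0., 1.], [2., 2.]])  # toy subword vectors
    subwords_indices = [0, 2]                            # subwords of the context words
    l1 = np.sum(syn0_all[subwords_indices], axis=0)      # -> array([3., 2.])
    l1 /= len(subwords_indices)                          # cbow_mean: -> array([1.5, 1.])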

def train_cbow_pair(model, word, input_subword_indices, l1, alpha, learn_vectors=True, learn_hidden=True):
    neu1e = zeros(l1.shape)

    if model.hs:
        l2a = model.syn1[word.point]  # 2d matrix, codelen x layer1_size
        fa = expit(dot(l1, l2a.T))  # propagate hidden -> output
        ga = (1. - word.code - fa) * alpha  # vector of error gradients multiplied by the learning rate
        if learn_hidden:
            model.syn1[word.point] += outer(ga, l1)  # learn hidden -> output
        neu1e += dot(ga, l2a)  # save error

    if model.negative:
        # use this word (label = 1) + `negative` other random words not from this sentence (label = 0)
        word_indices = [word.index]  # through word index get all subwords indices (need to make the changes in code)
        while len(word_indices) < model.negative + 1:
            w = model.cum_table.searchsorted(model.random.randint(model.cum_table[-1]))
            if w != word.index:
                word_indices.append(w)
        l2b = model.syn1neg[word_indices]  # 2d matrix, k+1 x layer1_size
        fb = expit(dot(l1, l2b.T))  # propagate hidden -> output
        gb = (model.neg_labels - fb) * alpha  # vector of error gradients multiplied by the learning rate
        if learn_hidden:
            model.syn1neg[word_indices] += outer(gb, l1)  # learn hidden -> output
        neu1e += dot(gb, l2b)  # save error

    if learn_vectors:
        # learn input -> hidden, here for all words in the window separately
        if not model.cbow_mean and input_subword_indices:
            neu1e /= len(input_subword_indices)
        for i in input_subword_indices:
            model.wv.syn0_all[i] += neu1e * model.syn0_all_lockf[i]

    return neu1e


class FastText(Word2Vec):
    def __init__(
Review comment (Contributor):

Should definitely be using super here.

Reply (chinmayapancholi13, Contributor, Author):

Done

            self, sentences=None, sg=0, hs=0, size=100, alpha=0.025, window=5, min_count=5,
            max_vocab_size=None, word_ngrams=1, loss='ns', sample=1e-3, seed=1, workers=3, min_alpha=0.0001,
            negative=5, cbow_mean=1, hashfxn=hash, iter=5, null_word=0, min_n=3, max_n=6, sorted_vocab=1, bucket=2000000,
            trim_rule=None, batch_words=MAX_WORDS_IN_BATCH):

        self.load = call_on_class_only

        self.initialize_ngram_vectors()
Review comment (Contributor):

Let's just use initialize_word_vectors and get rid of the init_wv argument.

Reply (chinmayapancholi13, Contributor, Author):

Done


        self.sg = int(sg)
        self.cum_table = None  # for negative sampling
        self.vector_size = int(size)
        self.layer1_size = int(size)
        if size % 4 != 0:
            logger.warning("consider setting layer size to a multiple of 4 for greater performance")
        self.alpha = float(alpha)
        self.min_alpha_yet_reached = float(alpha)  # To warn user if alpha increases
        self.window = int(window)
        self.max_vocab_size = max_vocab_size
        self.seed = seed
        self.random = random.RandomState(seed)
        self.min_count = min_count
        self.sample = sample
        self.workers = int(workers)
        self.min_alpha = float(min_alpha)
        self.hs = hs
        self.negative = negative
        self.cbow_mean = int(cbow_mean)
        self.hashfxn = hashfxn
        self.iter = iter
        self.null_word = null_word
        self.train_count = 0
        self.total_train_time = 0
        self.sorted_vocab = sorted_vocab
        self.batch_words = batch_words
        self.model_trimmed_post_training = False

        self.bucket = bucket
        self.loss = loss  # should we keep this? -> we already have `hs`, `negative` -> although we don't have a mode for only `softmax`
        self.word_ngrams = word_ngrams
        self.min_n = min_n
        self.max_n = max_n
        if self.word_ngrams <= 1 and self.max_n == 0:
            self.bucket = 0

        self.wv.min_n = min_n
        self.wv.max_n = max_n

        if sentences is not None:
            if isinstance(sentences, GeneratorType):
                raise TypeError("You can't pass a generator as the sentences argument. Try an iterator.")
            self.build_vocab(sentences, trim_rule=trim_rule)
            self.train(sentences, total_examples=self.corpus_count, epochs=self.iter,
                       start_alpha=self.alpha, end_alpha=self.min_alpha)
        else:
            if trim_rule is not None:
                logger.warning("The rule, if given, is only used to prune vocabulary during build_vocab() and is not stored as part of the model. ")
                logger.warning("Model initialized without sentences. trim_rule provided, if any, will be ignored.")

    def train(self, sentences, total_examples=None, total_words=None,
              epochs=None, start_alpha=None, end_alpha=None,
              word_count=0, queue_factor=2, report_delay=1.0):
        self.neg_labels = []
        if self.negative > 0:
            # precompute negative labels optimization for pure-python training
            self.neg_labels = zeros(self.negative + 1)
            self.neg_labels[0] = 1.

        Word2Vec.train(self, sentences, total_examples=self.corpus_count, epochs=self.iter,
                       start_alpha=self.alpha, end_alpha=self.min_alpha)
        self.get_vocab_word_vecs()

    def initialize_ngram_vectors(self):
        self.wv = FastTextKeyedVectors()

    def __getitem__(self, word):
        return self.word_vec(word)

    def get_vocab_word_vecs(self):
        for w, v in self.wv.vocab.items():
            word_vec = np.zeros(self.wv.syn0_all.shape[1])
            ngrams = ['<' + w + '>']
            ngrams += Ft_Wrapper.compute_ngrams(w, self.min_n, self.max_n)
            ngrams = list(set(ngrams))
            ngram_weights = self.wv.syn0_all
            for ngram in ngrams:
                word_vec += ngram_weights[self.wv.ngrams[ngram]]
            word_vec /= len(ngrams)

            self.wv.syn0[v.index] = word_vec

    def word_vec(self, word, use_norm=False):
Review comment (Contributor):

Any reason for not reusing word_vec from FastTextKeyedVectors here?

Reply (chinmayapancholi13, Contributor, Author):

Done

Comment (@luthfianto):

I get this error if I try to access an OOV word:

print model["use"] # only "user" is available in the vocabulary

     86         else:
     87             word_vec = np.zeros(self.syn0_ngrams.shape[1])
---> 88             ngrams = compute_ngrams(word, self.min_n, self.max_n)
     89             ngrams = [ng for ng in ngrams if ng in self.ngrams]
     90             if use_norm:

AttributeError: 'FastTextKeyedVectors' object has no attribute 'min_n'

Looks like the FastTextKeyedVectors class has no min_n attribute?

Comment (@luthfianto, Sep 7, 2017):

Here's the minimal example to reproduce my issue https://gist.github.com/rilut/31f41d5cf3da075d43cf7e4f2c993b76

Thanks 🙏

Comment (@luthfianto, Sep 7, 2017):

Btw, since self.wv = FastTextKeyedVectors() can we just do self.wv.word_vec(word, word_norm)? (just asking)

Reply (chinmayapancholi13, Contributor, Author):

@rilut Seems like I missed setting min_n and max_n in FastTextKeyedVectors after I had refactored the code between the wrapper and the main models.FastText class.
Anyway, thanks a lot for pointing this out! I have made the necessary changes and also added a unittest to ensure that this error doesn't go unnoticed in the future.
And yes, FastTextKeyedVectors.word_vec(self.wv, word, use_norm=use_norm) and self.wv.word_vec(word, word_norm) are equivalent. I preferred the former form at the time since it made the borrowed usage of the function more explicit.

Reply (@luthfianto):

@chinmayapancholi13 oh ok that's good. Thanks for your fix and hard work! I really appreciated it.

        if word in self.wv.vocab:
            if use_norm:
                return self.wv.syn0norm[self.wv.vocab[word].index]
            else:
                return self.wv.syn0[self.wv.vocab[word].index]
        else:
            logger.info("out of vocab")
            word_vec = np.zeros(self.wv.syn0_all.shape[1])
            ngrams = Ft_Wrapper.compute_ngrams(word, self.min_n, self.max_n)
            ngrams = [ng for ng in ngrams if ng in self.wv.ngrams]
            if use_norm:
                ngram_weights = self.wv.syn0_all_norm
            else:
                ngram_weights = self.wv.syn0_all
            for ngram in ngrams:
                word_vec += ngram_weights[self.wv.ngrams[ngram]]
            if word_vec.any():
                return word_vec / len(ngrams)
            else:  # No ngrams of the word are present in self.ngrams
                raise KeyError('all ngrams for word %s absent from model' % word)
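
A brief usage sketch of the behaviour implemented above (the toy corpus and query words are hypothetical): an in-vocabulary word is read straight from wv.syn0, an out-of-vocabulary word is averaged from whichever of its ngrams were seen during training, and a KeyError is raised only when none of its ngrams are known.

    from gensim.models.fasttext import FastText as FT_gensim

    toy_sentences = [['human', 'interface', 'computer'], ['survey', 'user', 'computer', 'system']]
    model = FT_gensim(toy_sentences, size=12, min_count=1, iter=1)

    vec_known = model['user']  # in-vocab: looked up directly in wv.syn0
    vec_oov = model['use']     # OOV: mean of the ngram vectors it shares with 'user' (e.g. '<us', 'use')
    try:
        model['qqqqqqqq']      # shares no ngrams with the vocabulary -> KeyError
    except KeyError:
        pass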

    def build_vocab(self, sentences, keep_raw_vocab=False, trim_rule=None, progress_per=10000, update=False):
        self.scan_vocab(sentences, progress_per=progress_per, trim_rule=trim_rule)  # initial survey
        self.scale_vocab(keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, update=update)  # trim by min_count & precalculate downsampling
        self.finalize_vocab(update=update)  # build tables & arrays
        # super(build_vocab, self, sentences, keep_raw_vocab=False, trim_rule=None, progress_per=10000, update=False)
        self.init_ngrams()

    def reset_ngram_weights(self):
        for ngram in self.wv.ngrams:
            self.wv.syn0_all[self.wv.ngrams[ngram]] = self.seeded_vector(ngram + str(self.seed))

    def init_ngrams(self):
        self.wv.ngrams = {}
        self.wv.syn0_all = empty((self.bucket + len(self.wv.vocab), self.vector_size), dtype=REAL)
        self.syn0_all_lockf = ones((self.bucket + len(self.wv.vocab), self.vector_size), dtype=REAL)

        all_ngrams = []
        for w, v in self.wv.vocab.items():
            all_ngrams += ['<' + w + '>']
            all_ngrams += Ft_Wrapper.compute_ngrams(w, self.min_n, self.max_n)
        all_ngrams = list(set(all_ngrams))
        self.num_ngram_vectors = len(all_ngrams)
        logger.info("Total number of ngrams in the vocab is %d", self.num_ngram_vectors)

        ngram_indices = []
        for i, ngram in enumerate(all_ngrams):
            ngram_hash = Ft_Wrapper.ft_hash(ngram)
            ngram_indices.append(len(self.wv.vocab) + ngram_hash % self.bucket)
            self.wv.ngrams[ngram] = i

        self.wv.syn0_all = self.wv.syn0_all.take(ngram_indices, axis=0)
        self.reset_ngram_weights()
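
For reference, a small sketch of the indexing scheme used above (the vocabulary size and example word are hypothetical): each ngram string is hashed with the wrapper's ft_hash and mapped to one of `bucket` rows placed after the len(vocab) word rows of the pre-selection matrix; take() then gathers those rows into the compact per-ngram matrix.

    from gensim.models.wrappers.fasttext import FastText as Ft_Wrapper

    bucket = 2000000
    vocab_size = 10000  # hypothetical vocabulary size
    rows = [vocab_size + Ft_Wrapper.ft_hash(ng) % bucket
            for ng in Ft_Wrapper.compute_ngrams('war', 3, 6)]  # character ngrams of '<war>'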

    def _do_train_job(self, sentences, alpha, inits):
        work, neu1 = inits
        tally = 0
        # if self.sg:
        #     tally += train_batch_sg(self, sentences, alpha, work)
        # else:
        tally += train_batch_cbow(self, sentences, alpha, work, neu1)

        return tally, self._raw_word_count(sentences)

gensim/test/test_fasttext.py (new file: 111 additions, 0 deletions)
@@ -0,0 +1,111 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import unittest
import os

import numpy as np

from gensim import utils
from gensim.models.fasttext import FastText as FT_gensim
from gensim.models.wrappers.fasttext import FastText as FT_wrapper

module_path = os.path.dirname(__file__) # needed because sample data files are located in the same folder
datapath = lambda fname: os.path.join(module_path, 'test_data', fname)

class LeeCorpus(object):
    def __iter__(self):
        with open(datapath('lee_background.cor')) as f:
            for line in f:
                yield utils.simple_preprocess(line)

list_corpus = list(LeeCorpus())

sentences = [
    ['human', 'interface', 'computer'],
    ['survey', 'user', 'computer', 'system', 'response', 'time'],
    ['eps', 'user', 'interface', 'system'],
    ['system', 'human', 'system', 'eps'],
    ['user', 'response', 'time'],
    ['trees'],
    ['graph', 'trees'],
    ['graph', 'minors', 'trees'],
    ['graph', 'minors', 'survey']
]


class TestFastTextModel(unittest.TestCase):

    def models_equal(self, model, model2):
        self.assertEqual(len(model.wv.vocab), len(model2.wv.vocab))
        self.assertTrue(np.allclose(model.wv.syn0, model2.wv.syn0))
        self.assertTrue(np.allclose(model.wv.syn0_all, model2.wv.syn0_all))
        if model.hs:
            self.assertTrue(np.allclose(model.syn1, model2.syn1))
        if model.negative:
            self.assertTrue(np.allclose(model.syn1neg, model2.syn1neg))
        most_common_word = max(model.wv.vocab.items(), key=lambda item: item[1].count)[0]
        self.assertTrue(np.allclose(model[most_common_word], model2[most_common_word]))

    def testTraining(self):
        model = FT_gensim(size=2, min_count=1, hs=1, negative=0)
        model.build_vocab(sentences)

        self.assertTrue(model.wv.syn0_all.shape == (len(model.wv.ngrams), 2))
        self.assertTrue(model.syn1.shape == (len(model.wv.vocab), 2))

        model.train(sentences, total_examples=model.corpus_count, epochs=model.iter)
        sims = model.most_similar('graph', topn=10)

        # test querying for "most similar" by vector
        graph_vector = model.wv.syn0norm[model.wv.vocab['graph'].index]
        sims2 = model.most_similar(positive=[graph_vector], topn=11)
        sims2 = [(w, sim) for w, sim in sims2 if w != 'graph']  # ignore 'graph' itself
        self.assertEqual(sims, sims2)

        # build vocab and train in one step; must be the same as above
        model2 = FT_gensim(sentences, size=2, min_count=1, hs=1, negative=0)
        self.models_equal(model, model2)

    def test_against_fasttext_wrapper(self, model_gensim, model_wrapper):
        sims_gensim = model_gensim.most_similar('war', topn=50)
        sims_wrapper = model_wrapper.most_similar('war', topn=50)
        self.assertEqual(sims_gensim, sims_wrapper)

    def test_cbow_hs(self):
        model_wrapper = FT_wrapper.train(ft_path='/home/chinmaya/GSOC/Gensim/fastText/fasttext',
            corpus_file=datapath('lee_background.cor'), output_file='/home/chinmaya/GSOC/Gensim/fasttext_out1', model='cbow', size=50,
            alpha=0.05, window=8, min_count=5, word_ngrams=1, loss='hs', sample=1e-3, negative=0, iter=5, min_n=3, max_n=6, sorted_vocab=1, threads=12)

        model_gensim = FT_gensim(size=50, sg=0, cbow_mean=1, alpha=0.05, window=8, hs=1, negative=0,
            min_count=5, iter=5, batch_words=1000, word_ngrams=1, sample=1e-3, min_n=3, max_n=6,
            sorted_vocab=1, workers=12, min_alpha=0.0001)

        model_gensim.build_vocab(list_corpus)
        orig0 = np.copy(model_gensim.wv.syn0[0])
        model_gensim.train(list_corpus, total_examples=model_gensim.corpus_count, epochs=model_gensim.iter)
        self.assertFalse((orig0 == model_gensim.wv.syn0[1]).all())  # vector should vary after training

        self.test_against_fasttext_wrapper(model_gensim, model_wrapper)

    def test_cbow_neg(self):
        model_wrapper = FT_wrapper.train(ft_path='/home/chinmaya/GSOC/Gensim/fastText/fasttext',
            corpus_file=datapath('lee_background.cor'), output_file='/home/chinmaya/GSOC/Gensim/fasttext_out1', model='cbow', size=50,
            alpha=0.05, window=8, min_count=5, word_ngrams=1, loss='ns', sample=1e-3, negative=15, iter=5, min_n=3, max_n=6, sorted_vocab=1, threads=12)

        model_gensim = FT_gensim(size=50, sg=0, cbow_mean=1, alpha=0.05, window=8, hs=0, negative=15,
            min_count=5, iter=5, batch_words=1000, word_ngrams=1, sample=1e-3, min_n=3, max_n=6,
            sorted_vocab=1, workers=12, min_alpha=0.0001)

        model_gensim.build_vocab(list_corpus)
        orig0 = np.copy(model_gensim.wv.syn0[0])
        model_gensim.train(list_corpus, total_examples=model_gensim.corpus_count, epochs=model_gensim.iter)
        self.assertFalse((orig0 == model_gensim.wv.syn0[1]).all())  # vector should vary after training

        self.test_against_fasttext_wrapper(model_gensim, model_wrapper)


if __name__ == '__main__':
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()