refactor init/update vectors/vectors_vocab; bulk randomization
gojomo committed Sep 14, 2020
1 parent 9ccb97e commit cd2e326
Showing 4 changed files with 68 additions and 211 deletions.
13 changes: 7 additions & 6 deletions gensim/models/doc2vec.py
@@ -286,6 +286,9 @@ def __init__(self, documents=None, corpus_file=None, vector_size=100, dm_mean=No

self.vector_size = vector_size
self.dv = dv or KeyedVectors(self.vector_size, mapfile_path=dv_mapfile)
# EXPERIMENTAL lockf feature; create minimal no-op lockf arrays (1 element of 1.0)
# advanced users should directly resize/adjust as desired after any vocab growth
self.dv.vectors_lockf = np.ones(1, dtype=REAL) # 0.0 values suppress word-backprop-updates; 1.0 allows

super(Doc2Vec, self).__init__(
sentences=corpus_iterable,
@@ -329,11 +332,9 @@ def _clear_post_train(self):
self.wv.norms = None
self.dv.norms = None

def reset_weights(self):
super(Doc2Vec, self).reset_weights()
self.dv.resize_vectors()
self.dv.randomly_initialize_vectors()
self.dv.vectors_lockf = np.ones(1, dtype=REAL) # 0.0 values suppress word-backprop-updates; 1.0 allows
def init_weights(self):
super(Doc2Vec, self).init_weights()
self.dv.resize_vectors(seed=self.seed)

def reset_from(self, other_model):
"""Copy shareable data structures from another (possibly pre-trained) model.
@@ -358,7 +359,7 @@ def reset_from(self, other_model):
self.dv.key_to_index = other_model.dv.key_to_index
self.dv.index_to_key = other_model.dv.index_to_key
self.dv.expandos = other_model.dv.expandos
self.reset_weights()
self.init_weights()

def _do_train_epoch(self, corpus_file, thread_id, offset, cython_vocab, thread_private_mem, cur_epoch,
total_examples=None, total_words=None, offsets=None, start_doctags=None, **kwargs):
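The EXPERIMENTAL lockf arrays introduced above default to a single 1.0 element, so every doc-vector stays trainable. Below is a minimal, hypothetical sketch of the kind of resizing/adjustment the comment leaves to advanced users after vocabulary growth; it assumes the training routines honor a full-length per-tag lockf array, and the corpus and tag names are made up:

import numpy as np
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

corpus = [TaggedDocument(words=["hello", "world"], tags=["doc0"]),
          TaggedDocument(words=["foo", "bar"], tags=["doc1"])]
model = Doc2Vec(vector_size=10, min_count=1)
model.build_vocab(corpus)

# Expand the no-op lockf array to one entry per doc-tag, then freeze one tag:
# 0.0 suppresses backprop updates to that row, 1.0 allows them.
model.dv.vectors_lockf = np.ones(len(model.dv), dtype=np.float32)
model.dv.vectors_lockf[model.dv.get_index("doc0")] = 0.0

model.train(corpus, total_examples=model.corpus_count, epochs=model.epochs)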
181 changes: 20 additions & 161 deletions gensim/models/fasttext.py
@@ -284,7 +284,7 @@

import gensim.models._fasttext_bin
from gensim.models.word2vec import Word2Vec
from gensim.models.keyedvectors import KeyedVectors
from gensim.models.keyedvectors import KeyedVectors, prep_vectors
from gensim import utils
from gensim.utils import deprecated
try:
@@ -455,7 +455,10 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
bucket = 0

self.wv = FastTextKeyedVectors(vector_size, min_n, max_n, bucket)
self.wv.bucket = bucket
# EXPERIMENTAL lockf feature; create minimal no-op lockf arrays (1 element of 1.0)
# advanced users should directly resize/adjust as desired after any vocab growth
self.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
self.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)

super(FastText, self).__init__(
sentences=sentences, corpus_file=corpus_file, workers=workers, vector_size=vector_size, epochs=epochs,
@@ -465,29 +468,6 @@ def __init__(self, sentences=None, corpus_file=None, sg=0, hs=0, vector_size=100
null_word=null_word, ns_exponent=ns_exponent, hashfxn=hashfxn,
seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha)

def prepare_weights(self, update=False):
"""In addition to superclass allocations, compute ngrams of all words present in vocabulary.
Parameters
----------
update : bool
If True, the new vocab words and their new ngrams word vectors are initialized
with random uniform distribution and updated/added to the existing vocab word and ngram vectors.
"""
super(FastText, self).prepare_weights(update=update)
if not update:
self.wv.init_ngrams_weights(self.seed)
# EXPERIMENTAL lockf feature; create minimal no-op lockf arrays (1 element of 1.0)
# advanced users should directly resize/adjust as necessary
self.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
self.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)
else:
self.wv.update_ngrams_weights(self.seed, self.old_vocab_len)
# EXPERIMENTAL lockf feature; create minimal no-op lockf arrays (1 element of 1.0)
# advanced users should directly resize/adjust as necessary
self.wv.vectors_vocab_lockf = ones(1, dtype=REAL)
self.wv.vectors_ngrams_lockf = ones(1, dtype=REAL)

def _init_post_load(self, hidden_output):
num_vectors = len(self.wv.vectors)
vocab_size = len(self.wv)
@@ -508,85 +488,6 @@ def _init_post_load(self, hidden_output):

self.layer1_size = vector_size

def build_vocab(self, corpus_iterable=None, corpus_file=None, update=False, progress_per=10000,
keep_raw_vocab=False, trim_rule=None, **kwargs):
"""Build vocabulary from a sequence of sentences (can be a once-only generator stream).
Each sentence must be a list of unicode strings.
Parameters
----------
corpus_iterable : iterable of list of str, optional
Can be simply a list of lists of tokens, but for larger corpora,
consider an iterable that streams the sentences directly from disk/network.
See :class:`~gensim.models.word2vec.BrownCorpus`, :class:`~gensim.models.word2vec.Text8Corpus`
or :class:`~gensim.models.word2vec.LineSentence` in :mod:`~gensim.models.word2vec` module for such examples.
corpus_file : str, optional
Path to a corpus file in :class:`~gensim.models.word2vec.LineSentence` format.
You may use this argument instead of `sentences` to get performance boost. Only one of `sentences` or
`corpus_file` arguments need to be passed (not both of them).
update : bool
If true, the new words in `sentences` will be added to model's vocab.
progress_per : int
Indicates how many words to process before showing/updating the progress.
keep_raw_vocab : bool
If not true, delete the raw vocabulary after the scaling is done and free up RAM.
trim_rule : function, optional
Vocabulary trimming rule, specifies whether certain words should remain in the vocabulary,
be trimmed away, or handled using the default (discard if word count < min_count).
Can be None (min_count will be used, look to :func:`~gensim.utils.keep_vocab_item`),
or a callable that accepts parameters (word, count, min_count) and returns either
:attr:`gensim.utils.RULE_DISCARD`, :attr:`gensim.utils.RULE_KEEP` or :attr:`gensim.utils.RULE_DEFAULT`.
The rule, if given, is only used to prune vocabulary during
:meth:`~gensim.models.fasttext.FastText.build_vocab` and is not stored as part of the model.
The input parameters are of the following types:
* `word` (str) - the word we are examining
* `count` (int) - the word's frequency count in the corpus
* `min_count` (int) - the minimum count threshold.
**kwargs
Additional key word parameters passed to
:meth:`~gensim.models.word2vec.Word2Vec.build_vocab`.
Examples
--------
Train a model and update vocab for online training:
.. sourcecode:: pycon
>>> from gensim.models import FastText
>>> sentences_1 = [["cat", "say", "meow"], ["dog", "say", "woof"]]
>>> sentences_2 = [["dude", "say", "wazzup!"]]
>>>
>>> model = FastText(min_count=1)
>>> model.build_vocab(sentences_1)
>>> model.train(sentences_1, total_examples=model.corpus_count, epochs=model.epochs)
>>>
>>> model.build_vocab(sentences_2, update=True)
>>> model.train(sentences_2, total_examples=model.corpus_count, epochs=model.epochs)
"""
if not update:
self.wv.init_ngrams_weights(self.seed)
elif not len(self.wv):
raise RuntimeError(
"You cannot do an online vocabulary-update of a model which has no prior vocabulary. "
"First build the vocabulary of your model with a corpus "
"by calling the gensim.models.fasttext.FastText.build_vocab method "
"before doing an online update."
)
else:
self.old_vocab_len = len(self.wv)

retval = super(FastText, self).build_vocab(
corpus_iterable=corpus_iterable, corpus_file=corpus_file, update=update, progress_per=progress_per,
keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, **kwargs)

if update:
self.wv.update_ngrams_weights(self.seed, self.old_vocab_len)

return retval

def _clear_post_train(self):
"""Clear the model's internal structures after training has finished to free up RAM."""
self.wv.adjust_vectors() # ensure composite-word vecs reflect latest training
@@ -1146,7 +1047,7 @@ def save_facebook_model(model, path, encoding="utf-8", lr_update_rate=100, word_


class FastTextKeyedVectors(KeyedVectors):
def __init__(self, vector_size, min_n, max_n, bucket):
def __init__(self, vector_size, min_n, max_n, bucket, count=0, dtype=REAL):
"""Vectors and vocab for :class:`~gensim.models.fasttext.FastText`.
Implements significant parts of the FastText algorithm. For example,
@@ -1192,12 +1093,12 @@ def __init__(self, vector_size, min_n, max_n, bucket):
"""
super(FastTextKeyedVectors, self).__init__(vector_size=vector_size)
self.vectors_vocab = None # fka syn0_vocab
self.vectors_ngrams = None # fka syn0_ngrams
self.buckets_word = None
self.min_n = min_n
self.max_n = max_n
self.bucket = bucket # count of buckets, fka num_ngram_vectors
self.buckets_word = None # precalculated cache of buckets for each word's ngrams
self.vectors_vocab = np.zeros((count, vector_size), dtype=dtype) # fka (formerly known as) syn0_vocab
self.vectors_ngrams = None # must be initialized later
self.compatible_hash = True

@classmethod
@@ -1311,63 +1212,21 @@ def get_vector(self, word, norm=False):
else:
return word_vec

def init_ngrams_weights(self, seed):
"""Initialize the vocabulary and ngrams weights prior to training.
Creates the weight matrices and initializes them with uniform random values.
Parameters
----------
seed : float
The seed for the PRNG.
Note
----
Call this **after** the vocabulary has been fully initialized.
"""
self.recalc_char_ngram_buckets()
def resize_vectors(self, seed=0):
"""Make underlying vectors match `index_to_key size; random-initialize any new rows.
rand_obj = np.random.default_rng(seed=seed) # use new instance of numpy's recommended generator/algorithm
Unlike in the superclass, the 'vectors_vocab' array is of primary importance, with
'vectors' derived from it. Also, the 'vectors_ngrams' array may need allocation."""

lo, hi = -1.0 / self.vector_size, 1.0 / self.vector_size
vocab_shape = (len(self), self.vector_size)
vocab_shape = (len(self.index_to_key), self.vector_size)
self.vectors_vocab = prep_vectors(vocab_shape, prior_vectors=self.vectors_vocab, seed=seed)
ngrams_shape = (self.bucket, self.vector_size)
self.vectors_vocab = rand_obj.uniform(lo, hi, vocab_shape).astype(REAL)

#
# We could have initialized vectors_ngrams at construction time, but we
# do it here for two reasons:
#
# 1. The constructor does not have access to the random seed
# 2. We want to use the same rand_obj to fill vectors_vocab _and_
# vectors_ngrams, and vectors_vocab cannot happen at construction
# time because the vocab is not initialized at that stage.
#
self.vectors_ngrams = rand_obj.uniform(lo, hi, ngrams_shape).astype(REAL)

def update_ngrams_weights(self, seed, old_vocab_len):
"""Update the vocabulary weights for training continuation.
Parameters
----------
seed : float
The seed for the PRNG.
old_vocab_length : int
The length of the vocabulary prior to its update.
Note
----
Call this **after** the vocabulary has been updated.
"""
self.recalc_char_ngram_buckets()

rand_obj = np.random
rand_obj.seed(seed)
self.vectors_ngrams = prep_vectors(ngrams_shape, prior_vectors=self.vectors_ngrams, seed=seed + 1)

new_vocab = len(self) - old_vocab_len
self.vectors_vocab = _pad_random(self.vectors_vocab, new_vocab, rand_obj)
self.allocate_vecattrs()
self.norms = None
self.recalc_char_ngram_buckets() # ensure new words have precalc buckets
self.adjust_vectors() # ensure `vectors` filled as well (though may be nonsense pre-training)

def init_post_load(self, fb_vectors):
"""Perform initialization after loading a native Facebook model.
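With prepare_weights and the FastText-specific build_vocab override gone, both initial allocation and later growth flow through resize_vectors. A rough sketch of the intended usage, reusing the online-update example from the deleted docstring; the behavior described in the comments is my reading of this diff, not verified output:

from gensim.models import FastText

sentences_1 = [["cat", "say", "meow"], ["dog", "say", "woof"]]
sentences_2 = [["dude", "say", "wazzup!"]]

model = FastText(vector_size=16, min_count=1)
model.build_vocab(sentences_1)
model.train(sentences_1, total_examples=model.corpus_count, epochs=model.epochs)

# Online vocabulary update: the superclass vocab-building now reaches
# FastTextKeyedVectors.resize_vectors(), which pads vectors_vocab and
# vectors_ngrams with fresh random rows via prep_vectors(), recalculates
# char-ngram buckets, and refreshes the composite `vectors` via adjust_vectors().
model.build_vocab(sentences_2, update=True)
model.train(sentences_2, total_examples=model.corpus_count, epochs=model.epochs)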
49 changes: 23 additions & 26 deletions gensim/models/keyedvectors.py
@@ -286,34 +286,16 @@ def get_vecattr(self, key, attr):
index = self.get_index(key)
return self.expandos[attr][index]

def resize_vectors(self):
"""Make underlying vectors match index_to_key size."""
target_count = len(self.index_to_key)
prev_count = len(self.vectors)
if prev_count == target_count:
return ()
prev_vectors = self.vectors
if hasattr(self, 'mapfile_path') and self.mapfile_path:
self.vectors = np.memmap(self.mapfile_path, shape=(target_count, self.vector_size), mode='w+', dtype=REAL)
else:
self.vectors = np.zeros((target_count, self.vector_size), dtype=REAL)
self.vectors[0: min(prev_count, target_count), ] = prev_vectors[0: min(prev_count, target_count), ]
self.allocate_vecattrs()
self.norms = None
return range(prev_count, target_count)
def resize_vectors(self, seed=0):
"""Make underlying vectors match index_to_key size; random-initialize any new rows."""

def randomly_initialize_vectors(self, indexes=None, seed=0):
"""Initialize vectors with low-magnitude random vectors, as is typical for pre-trained
Word2Vec and related models.
target_shape = (len(self.index_to_key), self.vector_size)
self.vectors = prep_vectors(target_shape, prior_vectors=self.vectors, seed=seed)
# TODO: support memmap?
# if hasattr(self, 'mapfile_path') and self.mapfile_path:
# self.vectors = np.memmap(self.mapfile_path, shape=(target_count, self.vector_size), mode='w+', dtype=REAL)

"""
if indexes is None:
indexes = range(0, len(self.vectors))
for i in indexes:
self.vectors[i] = pseudorandom_weak_vector(
self.vectors.shape[1],
seed_string=str(self.index_to_key[i]) + str(seed),
)
self.allocate_vecattrs()
self.norms = None

def __len__(self):
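The new resize_vectors subsumes the deleted randomly_initialize_vectors: rows beyond the prior array come from a seeded bulk draw (prep_vectors) instead of per-key seed strings. A small sketch of the call pattern, assuming the attribute layout shown in this diff; in normal use the model's own vocab-building does this key bookkeeping, and the keys here are hypothetical:

import numpy as np
from gensim.models import KeyedVectors

kv = KeyedVectors(vector_size=8)

# Register keys the way vocab-building does, then let resize_vectors
# allocate and randomly initialize the matching rows.
for key in ["alpha", "beta", "gamma"]:
    kv.key_to_index[key] = len(kv.index_to_key)
    kv.index_to_key.append(key)
kv.resize_vectors(seed=42)
assert kv.vectors.shape == (3, 8)

# Growing the vocab later keeps existing rows; only the new row is random-initialized.
kv.key_to_index["delta"] = len(kv.index_to_key)
kv.index_to_key.append("delta")
old_rows = kv.vectors.copy()
kv.resize_vectors(seed=42)
assert np.allclose(kv.vectors[:3], old_rows)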
@@ -1829,3 +1811,18 @@ def pseudorandom_weak_vector(size, seed_string=None, hashfxn=hash):
else:
once = utils.default_prng
return (once.random(size).astype(REAL) - 0.5) / size


def prep_vectors(target_shape, prior_vectors=None, seed=0, dtype=REAL):
"""Return a numpy array of the given shape. Reuse prior_vectors values instance or values
to extent possible. Initialize new values randomly if requested."""
if prior_vectors is None:
prior_vectors = np.zeros((0, 0))
if prior_vectors.shape == target_shape:
return prior_vectors
target_count, vector_size = target_shape
rng = np.random.default_rng(seed=seed) # use new instance of numpy's recommended generator/algorithm
new_vectors = rng.uniform(-1.0, 1.0, target_shape).astype(dtype)
new_vectors /= vector_size
new_vectors[0:prior_vectors.shape[0], 0:prior_vectors.shape[1]] = prior_vectors
return new_vectors
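prep_vectors is the bulk-randomization helper named in the commit title: new rows are drawn uniformly and scaled to roughly ±1/vector_size, any prior values are copied into the top-left block, and an exact shape match is returned unchanged. A quick sketch of that behavior (shapes chosen arbitrarily):

import numpy as np
from gensim.models.keyedvectors import prep_vectors

prior = np.zeros((2, 4), dtype=np.float32)
grown = prep_vectors((5, 4), prior_vectors=prior, seed=1)

assert grown.shape == (5, 4)
assert np.all(grown[:2] == 0.0)              # prior rows copied over the random fill
assert np.all(np.abs(grown[2:]) <= 1.0 / 4)  # new rows are low-magnitude
assert prep_vectors((2, 4), prior_vectors=prior, seed=1) is prior  # same shape: reused as-is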