From dee62072a26a0d1696bcc1b352c0aa6005990036 Mon Sep 17 00:00:00 2001 From: Gordon Mohr Date: Tue, 17 Dec 2019 18:33:26 -0800 Subject: [PATCH] rm Word2VecVocab class, persistent callbacks (bug #2136) --- gensim/models/base_any2vec.py | 60 +-- gensim/models/deprecated/word2vec.py | 2 +- gensim/models/word2vec.py | 549 +++++++++++++-------------- gensim/models/word2vec_inner.pyx | 6 +- gensim/test/test_word2vec.py | 20 +- 5 files changed, 312 insertions(+), 325 deletions(-) diff --git a/gensim/models/base_any2vec.py b/gensim/models/base_any2vec.py index f0a33ba7ff..62cce8d992 100644 --- a/gensim/models/base_any2vec.py +++ b/gensim/models/base_any2vec.py @@ -94,7 +94,6 @@ def __init__(self, workers=3, vector_size=100, epochs=5, callbacks=(), batch_wor self.total_train_time = 0 self.batch_words = batch_words self.model_trimmed_post_training = False - self.callbacks = callbacks def _get_job_params(self, cur_epoch): """Get job parameters required for each batch.""" @@ -193,6 +192,7 @@ def _worker_loop(self, job_queue, progress_queue): """ thread_private_mem = self._get_thread_working_mem() jobs_processed = 0 + callbacks = progress_queue.callbacks while True: job = job_queue.get() if job is None: @@ -200,12 +200,12 @@ def _worker_loop(self, job_queue, progress_queue): break # no more jobs => quit this worker data_iterable, job_parameters = job - for callback in self.callbacks: + for callback in callbacks: callback.on_batch_begin(self) tally, raw_tally = self._do_train_job(data_iterable, job_parameters, thread_private_mem) - for callback in self.callbacks: + for callback in callbacks: callback.on_batch_end(self) progress_queue.put((len(data_iterable), tally, raw_tally)) # report back progress @@ -366,7 +366,8 @@ def _log_epoch_progress(self, progress_queue=None, job_queue=None, cur_epoch=0, self.total_train_time += elapsed return trained_word_count, raw_word_count, job_tally - def _train_epoch_corpusfile(self, corpus_file, cur_epoch=0, total_examples=None, total_words=None, **kwargs): + def _train_epoch_corpusfile( + self, corpus_file, cur_epoch=0, total_examples=None, total_words=None, callbacks=(), **kwargs): """Train the model for a single epoch. Parameters @@ -430,7 +431,7 @@ def _train_epoch_corpusfile(self, corpus_file, cur_epoch=0, total_examples=None, return trained_word_count, raw_word_count, job_tally def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, total_words=None, - queue_factor=2, report_delay=1.0): + queue_factor=2, report_delay=1.0, callbacks=()): """Train the model for a single epoch. 
Parameters @@ -462,6 +463,7 @@ def _train_epoch(self, data_iterable, cur_epoch=0, total_examples=None, total_wo """ job_queue = Queue(maxsize=queue_factor * self.workers) progress_queue = Queue(maxsize=(queue_factor + 1) * self.workers) + progress_queue.callbacks = callbacks # messy way to pass along for just this session workers = [ threading.Thread( @@ -522,15 +524,13 @@ def train(self, data_iterable=None, corpus_file=None, epochs=None, total_example """ self._set_train_params(**kwargs) - if callbacks: - self.callbacks = callbacks self.epochs = epochs self._check_training_sanity( epochs=epochs, total_examples=total_examples, total_words=total_words, **kwargs) - for callback in self.callbacks: + for callback in callbacks: callback.on_train_begin(self) trained_word_count = 0 @@ -539,22 +539,24 @@ def train(self, data_iterable=None, corpus_file=None, epochs=None, total_example job_tally = 0 for cur_epoch in range(self.epochs): - for callback in self.callbacks: + for callback in callbacks: callback.on_epoch_begin(self) if data_iterable is not None: trained_word_count_epoch, raw_word_count_epoch, job_tally_epoch = self._train_epoch( data_iterable, cur_epoch=cur_epoch, total_examples=total_examples, - total_words=total_words, queue_factor=queue_factor, report_delay=report_delay) + total_words=total_words, queue_factor=queue_factor, report_delay=report_delay, + callbacks=callbacks) else: trained_word_count_epoch, raw_word_count_epoch, job_tally_epoch = self._train_epoch_corpusfile( - corpus_file, cur_epoch=cur_epoch, total_examples=total_examples, total_words=total_words, **kwargs) + corpus_file, cur_epoch=cur_epoch, total_examples=total_examples, total_words=total_words, + callbacks=callbacks, **kwargs) trained_word_count += trained_word_count_epoch raw_word_count += raw_word_count_epoch job_tally += job_tally_epoch - for callback in self.callbacks: + for callback in callbacks: callback.on_epoch_end(self) # Log overall time @@ -564,7 +566,7 @@ def train(self, data_iterable=None, corpus_file=None, epochs=None, total_example self.train_count += 1 # number of times train() has been called self._clear_post_train() - for callback in self.callbacks: + for callback in callbacks: callback.on_train_end(self) return trained_word_count, raw_word_count @@ -730,13 +732,19 @@ def __init__(self, sentences=None, corpus_file=None, workers=3, vector_size=100, self.train( sentences=sentences, corpus_file=corpus_file, total_examples=self.corpus_count, total_words=self.corpus_total_words, epochs=self.epochs, start_alpha=self.alpha, - end_alpha=self.min_alpha, compute_loss=compute_loss) + end_alpha=self.min_alpha, compute_loss=compute_loss, callbacks=callbacks) else: if trim_rule is not None: logger.warning( "The rule, if given, is only used to prune vocabulary during build_vocab() " "and is not stored as part of the model. Model initialized without sentences. " "trim_rule provided, if any, will be ignored.") + if callbacks: + logger.warning( + "Callbacks are no longer retained by the model, so must be provided whenever " + "training is triggered, as in initialization with a corpus or calling `train()`. " + "The callbacks provided in this initialization without triggering train will " + "be ignored.") def _clear_post_train(self): raise NotImplementedError() @@ -797,18 +805,16 @@ def build_vocab(self, sentences=None, corpus_file=None, update=False, progress_p * `min_count` (int) - the minimum count threshold. 
**kwargs : object - Key word arguments propagated to `self.vocabulary.prepare_vocab` + Key word arguments propagated to `self.prepare_vocab` """ - total_words, corpus_count = self.vocabulary.scan_vocab( + total_words, corpus_count = self.scan_vocab( sentences=sentences, corpus_file=corpus_file, progress_per=progress_per, trim_rule=trim_rule) self.corpus_count = corpus_count self.corpus_total_words = total_words - report_values = self.vocabulary.prepare_vocab( - self.hs, self.negative, self.wv, update=update, keep_raw_vocab=keep_raw_vocab, - trim_rule=trim_rule, **kwargs) + report_values = self.prepare_vocab(update=update, keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, **kwargs) report_values['memory'] = self.estimate_memory(vocab_size=report_values['num_retained_words']) - self.trainables.prepare_weights(self.hs, self.negative, self.wv, update=update, vocabulary=self.vocabulary) + self.trainables.prepare_weights(self.hs, self.negative, self.wv, update=update, vocabulary=self) def build_vocab_from_freq(self, word_freq, keep_raw_vocab=False, corpus_count=None, trim_rule=None, update=False): """Build vocabulary from a dictionary of word frequencies. @@ -850,15 +856,13 @@ def build_vocab_from_freq(self, word_freq, keep_raw_vocab=False, corpus_count=No # Since no sentences are provided, this is to control the corpus_count. self.corpus_count = corpus_count or 0 - self.vocabulary.raw_vocab = raw_vocab + self.raw_vocab = raw_vocab # trim by min_count & precalculate downsampling - report_values = self.vocabulary.prepare_vocab( - self.hs, self.negative, self.wv, keep_raw_vocab=keep_raw_vocab, - trim_rule=trim_rule, update=update) + report_values = self.prepare_vocab(keep_raw_vocab=keep_raw_vocab, trim_rule=trim_rule, update=update) report_values['memory'] = self.estimate_memory(vocab_size=report_values['num_retained_words']) self.trainables.prepare_weights( - self.hs, self.negative, self.wv, update=update, vocabulary=self.vocabulary) # build tables & arrays + self.hs, self.negative, self.wv, update=update, vocabulary=self) # build tables & arrays def estimate_memory(self, vocab_size=None, report=None): """Estimate required memory for a model using current settings and provided vocabulary size. @@ -1075,7 +1079,7 @@ def _check_training_sanity(self, epochs=None, total_examples=None, total_words=N "training model with %i workers on %i vocabulary and %i features, " "using sg=%s hs=%s sample=%s negative=%s window=%s", self.workers, len(self.wv.vocab), self.trainables.layer1_size, self.sg, - self.hs, self.vocabulary.sample, self.negative, self.window + self.hs, self.sample, self.negative, self.window ) @classmethod @@ -1112,10 +1116,8 @@ def load(cls, *args, **kwargs): model = super(BaseWordEmbeddingsModel, cls).load(*args, **kwargs) if not hasattr(model, 'ns_exponent'): model.ns_exponent = 0.75 - if not hasattr(model.vocabulary, 'ns_exponent'): - model.vocabulary.ns_exponent = 0.75 if model.negative and hasattr(model.wv, 'index2word'): - model.vocabulary.make_cum_table(model.wv) # rebuild cum_table from vocabulary + model.make_cum_table() # rebuild cum_table from vocabulary ## TODO: ??? 
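# Illustrative usage sketch (not part of this patch): with persistent callbacks
# removed above, callbacks must now be supplied on every call that triggers
# training -- either the corpus-supplying constructor or an explicit train().
# EpochLogger is a hypothetical CallbackAny2Vec subclass, shown only for
# demonstration; the toy corpus is the one from the Word2Vec class docstring.
from gensim.models import Word2Vec
from gensim.models.callbacks import CallbackAny2Vec

class EpochLogger(CallbackAny2Vec):
    def on_epoch_end(self, model):
        print("epoch finished")

sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]]
model = Word2Vec(size=10, min_count=1)
model.build_vocab(sentences)
model.train(sentences, total_examples=model.corpus_count, epochs=model.epochs,
            callbacks=[EpochLogger()])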
if not hasattr(model, 'corpus_count'): model.corpus_count = None if not hasattr(model, 'corpus_total_words'): diff --git a/gensim/models/deprecated/word2vec.py b/gensim/models/deprecated/word2vec.py index d57a902c55..6e17e05dc5 100644 --- a/gensim/models/deprecated/word2vec.py +++ b/gensim/models/deprecated/word2vec.py @@ -203,7 +203,7 @@ def load_old_word2vec(*args, **kwargs): # set vocabulary attributes new_model.wv.vocab = old_model.wv.vocab new_model.wv.index2word = old_model.wv.index2word - new_model.vocabulary.cum_table = old_model.__dict__.get('cum_table', None) + new_model.cum_table = old_model.__dict__.get('cum_table', None) new_model.train_count = old_model.__dict__.get('train_count', None) new_model.corpus_count = old_model.__dict__.get('corpus_count', None) diff --git a/gensim/models/word2vec.py b/gensim/models/word2vec.py index b6a6c8c2d6..bdb71c71ce 100755 --- a/gensim/models/word2vec.py +++ b/gensim/models/word2vec.py @@ -348,19 +348,12 @@ def __init__(self, sentences=None, corpus_file=None, size=100, alpha=0.025, wind >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]] >>> model = Word2Vec(sentences, min_count=1) - Some important attributes are the following: - Attributes ---------- wv : :class:`~gensim.models.keyedvectors.KeyedVectors` This object essentially contains the mapping between words and embeddings. After training, it can be used directly to query those embeddings in various ways. See the module level docstring for examples. - vocabulary : :class:`~gensim.models.word2vec.Word2VecVocab` - This object represents the vocabulary (sometimes called Dictionary in gensim) of the model. - Besides keeping track of all unique words, this object provides extra functionality, such as - constructing a huffman tree (frequent words are closer to the root), or discarding extremely rare words. - trainables : :class:`~gensim.models.word2vec.Word2VecTrainables` This object represents the inner shallow neural network used to train the embeddings. 
The semantics of the network differ slightly in the two available training modes (CBOW or SG) but you can think of it @@ -370,21 +363,273 @@ def __init__(self, sentences=None, corpus_file=None, size=100, alpha=0.025, wind """ self.max_final_vocab = max_final_vocab - - self.callbacks = callbacks - self.load = call_on_class_only + self.max_vocab_size = max_vocab_size + self.min_count = min_count + self.sample = sample + self.sorted_vocab = sorted_vocab + self.null_word = null_word + self.cum_table = None # for negative sampling + self.raw_vocab = None self.wv = KeyedVectors(size) - self.vocabulary = Word2VecVocab( - max_vocab_size=max_vocab_size, min_count=min_count, sample=sample, sorted_vocab=bool(sorted_vocab), - null_word=null_word, max_final_vocab=max_final_vocab, ns_exponent=ns_exponent) + self.trainables = Word2VecTrainables(seed=seed, vector_size=size, hashfxn=hashfxn) + self.load = call_on_class_only + super(Word2Vec, self).__init__( sentences=sentences, corpus_file=corpus_file, workers=workers, vector_size=size, epochs=iter, callbacks=callbacks, batch_words=batch_words, trim_rule=trim_rule, sg=sg, alpha=alpha, window=window, seed=seed, hs=hs, negative=negative, cbow_mean=cbow_mean, min_alpha=min_alpha, compute_loss=compute_loss) + def _scan_vocab(self, sentences, progress_per, trim_rule): + sentence_no = -1 + total_words = 0 + min_reduce = 1 + vocab = defaultdict(int) + checked_string_types = 0 + for sentence_no, sentence in enumerate(sentences): + if not checked_string_types: + if isinstance(sentence, string_types): + logger.warning( + "Each 'sentences' item should be a list of words (usually unicode strings). " + "First item here is instead plain %s.", + type(sentence) + ) + checked_string_types += 1 + if sentence_no % progress_per == 0: + logger.info( + "PROGRESS: at sentence #%i, processed %i words, keeping %i word types", + sentence_no, total_words, len(vocab) + ) + for word in sentence: + vocab[word] += 1 + total_words += len(sentence) + + if self.max_vocab_size and len(vocab) > self.max_vocab_size: + utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) + min_reduce += 1 + + corpus_count = sentence_no + 1 + self.raw_vocab = vocab + return total_words, corpus_count + + def scan_vocab(self, sentences=None, corpus_file=None, progress_per=10000, workers=None, trim_rule=None): + logger.info("collecting all words and their counts") + if corpus_file: + sentences = LineSentence(corpus_file) + + total_words, corpus_count = self._scan_vocab(sentences, progress_per, trim_rule) + + logger.info( + "collected %i word types from a corpus of %i raw words and %i sentences", + len(self.raw_vocab), total_words, corpus_count + ) + + return total_words, corpus_count + + def sort_vocab(self): + """Sort the vocabulary so the most frequent words have the lowest indexes.""" + if len(self.wv.vectors): + raise RuntimeError("cannot sort vocabulary after model weights already initialized.") + self.wv.index2key.sort(key=lambda word: self.wv.vocab[word].count, reverse=True) + for i, word in enumerate(self.wv.index2key): + self.wv.vocab[word].index = i + + def prepare_vocab( + self, update=False, keep_raw_vocab=False, trim_rule=None, + min_count=None, sample=None, dry_run=False): + """Apply vocabulary settings for `min_count` (discarding less-frequent words) + and `sample` (controlling the downsampling of more-frequent words). 
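A rough illustrative sketch of the scan-then-prepare flow now hosted directly on the model (normally both steps are driven by :meth:`build_vocab`; the toy corpus is the one from the class docstring)::

    >>> sentences = [["cat", "say", "meow"], ["dog", "say", "woof"]]
    >>> model = Word2Vec(size=10, min_count=1, sample=1e-3)
    >>> total_words, corpus_count = model.scan_vocab(sentences)  # count raw frequencies into raw_vocab
    >>> report = model.prepare_vocab()  # apply min_count & sample, build wv.vocab
    >>> report['num_retained_words']  # number of words surviving the min_count threshold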
+ + Calling with `dry_run=True` will only simulate the provided settings and + report the size of the retained vocabulary, effective corpus length, and + estimated memory requirements. Results are both printed via logging and + returned as a dict. + + Delete the raw vocabulary after the scaling is done to free up RAM, + unless `keep_raw_vocab` is set. + + """ + min_count = min_count or self.min_count + sample = sample or self.sample + drop_total = drop_unique = 0 + + # set effective_min_count to min_count in case max_final_vocab isn't set + self.effective_min_count = min_count + + # if max_final_vocab is specified instead of min_count + # pick a min_count which satisfies max_final_vocab as well as possible + if self.max_final_vocab is not None: + sorted_vocab = sorted(self.raw_vocab.keys(), key=lambda word: self.raw_vocab[word], reverse=True) + calc_min_count = 1 + + if self.max_final_vocab < len(sorted_vocab): + calc_min_count = self.raw_vocab[sorted_vocab[self.max_final_vocab]] + 1 + + self.effective_min_count = max(calc_min_count, min_count) + logger.info( + "max_final_vocab=%d and min_count=%d resulted in calc_min_count=%d, effective_min_count=%d", + self.max_final_vocab, min_count, calc_min_count, self.effective_min_count + ) + + if not update: + logger.info("Loading a fresh vocabulary") + retain_total, retain_words = 0, [] + # Discard words less-frequent than min_count + if not dry_run: + self.wv.index2key = [] + # make stored settings match these applied settings + self.min_count = min_count + self.sample = sample + self.wv.vocab = {} + + for word, v in iteritems(self.raw_vocab): + if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule): + retain_words.append(word) + retain_total += v + if not dry_run: + self.wv.vocab[word] = W2VVocab(count=v, index=len(self.wv.index2key)) + self.wv.index2key.append(word) + else: + drop_unique += 1 + drop_total += v + original_unique_total = len(retain_words) + drop_unique + retain_unique_pct = len(retain_words) * 100 / max(original_unique_total, 1) + logger.info( + "effective_min_count=%d retains %i unique words (%i%% of original %i, drops %i)", + self.effective_min_count, len(retain_words), retain_unique_pct, original_unique_total, drop_unique + ) + original_total = retain_total + drop_total + retain_pct = retain_total * 100 / max(original_total, 1) + logger.info( + "effective_min_count=%d leaves %i word corpus (%i%% of original %i, drops %i)", + self.effective_min_count, retain_total, retain_pct, original_total, drop_total + ) + else: + logger.info("Updating model with new vocabulary") + new_total = pre_exist_total = 0 + new_words = pre_exist_words = [] + for word, v in iteritems(self.raw_vocab): + if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule): + if word in self.wv.vocab: + pre_exist_words.append(word) + pre_exist_total += v + if not dry_run: + self.wv.vocab[word].count += v + else: + new_words.append(word) + new_total += v + if not dry_run: + self.wv.vocab[word] = W2VVocab(count=v, index=len(self.wv.index2key)) + self.wv.index2key.append(word) + else: + drop_unique += 1 + drop_total += v + original_unique_total = len(pre_exist_words) + len(new_words) + drop_unique + pre_exist_unique_pct = len(pre_exist_words) * 100 / max(original_unique_total, 1) + new_unique_pct = len(new_words) * 100 / max(original_unique_total, 1) + logger.info( + "New added %i unique words (%i%% of original %i) " + "and increased the count of %i pre-existing words (%i%% of original %i)", + len(new_words), new_unique_pct, 
original_unique_total, len(pre_exist_words), + pre_exist_unique_pct, original_unique_total + ) + retain_words = new_words + pre_exist_words + retain_total = new_total + pre_exist_total + + # Precalculate each vocabulary item's threshold for sampling + if not sample: + # no words downsampled + threshold_count = retain_total + elif sample < 1.0: + # traditional meaning: set parameter as proportion of total + threshold_count = sample * retain_total + else: + # new shorthand: sample >= 1 means downsample all words with higher count than sample + threshold_count = int(sample * (3 + sqrt(5)) / 2) + + downsample_total, downsample_unique = 0, 0 + for w in retain_words: + v = self.raw_vocab[w] + word_probability = (sqrt(v / threshold_count) + 1) * (threshold_count / v) + if word_probability < 1.0: + downsample_unique += 1 + downsample_total += word_probability * v + else: + word_probability = 1.0 + downsample_total += v + if not dry_run: + self.wv.vocab[w].sample_int = int(round(word_probability * 2**32)) + + if not dry_run and not keep_raw_vocab: + logger.info("deleting the raw counts dictionary of %i items", len(self.raw_vocab)) + self.raw_vocab = defaultdict(int) + + logger.info("sample=%g downsamples %i most-common words", sample, downsample_unique) + logger.info( + "downsampling leaves estimated %i word corpus (%.1f%% of prior %i)", + downsample_total, downsample_total * 100.0 / max(retain_total, 1), retain_total + ) + + # return from each step: words-affected, resulting-corpus-size, extra memory estimates + report_values = { + 'drop_unique': drop_unique, 'retain_total': retain_total, 'downsample_unique': downsample_unique, + 'downsample_total': int(downsample_total), 'num_retained_words': len(retain_words) + } + + if self.null_word: + # create null pseudo-word for padding when using concatenative L1 (run-of-words) + # this word is only ever input – never predicted – so count, huffman-point, etc doesn't matter + self.add_null_word() + + if self.sorted_vocab and not update: + self.sort_vocab() + if self.hs: + # add info about each word's Huffman encoding + self.create_binary_tree() + if self.negative: + # build the table for drawing random words (for negative sampling) + self.make_cum_table() + + return report_values + + def add_null_word(self): + word, v = '\0', W2VVocab(count=1, sample_int=0) + v.index = len(self.wv.vocab) + self.wv.index2key.append(word) + self.wv.vocab[word] = v + + def create_binary_tree(self): + """Create a `binary Huffman tree `_ using stored vocabulary + word counts. Frequent words will have shorter binary codes. + Called internally from :meth:`~gensim.models.word2vec.Word2VecVocab.build_vocab`. + + """ + _assign_binary_codes(self.wv.vocab) + + def make_cum_table(self, domain=2**31 - 1): + """Create a cumulative-distribution table using stored vocabulary word counts for + drawing random words in the negative-sampling training routines. + + To draw a word index, choose a random integer up to the maximum value in the table (cum_table[-1]), + then finding that integer's sorted insertion point (as if by `bisect_left` or `ndarray.searchsorted()`). + That insertion point is the drawn index, coming up in proportion equal to the increment at that slot. 
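A rough illustrative sketch of such a draw, assuming a `model` whose cum_table has already been built (e.g. via :meth:`build_vocab`)::

    >>> import numpy as np
    >>> rand = np.random.randint(model.cum_table[-1])    # random int up to the table maximum
    >>> word_index = model.cum_table.searchsorted(rand)  # sorted insertion point = drawn index
    >>> drawn_word = model.wv.index2key[word_index]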
+ + """ + vocab_size = len(self.wv.index2key) + self.cum_table = zeros(vocab_size, dtype=uint32) + # compute sum of all power (Z in paper) + train_words_pow = 0.0 + for word_index in range(vocab_size): + train_words_pow += self.wv.vocab[self.wv.index2key[word_index]].count**self.ns_exponent + cumulative = 0.0 + for word_index in range(vocab_size): + cumulative += self.wv.vocab[self.wv.index2key[word_index]].count**self.ns_exponent + self.cum_table[word_index] = round(cumulative / train_words_pow * domain) + if len(self.cum_table) > 0: + assert self.cum_table[-1] == domain + def _do_train_epoch(self, corpus_file, thread_id, offset, cython_vocab, thread_private_mem, cur_epoch, total_examples=None, total_words=None, **kwargs): work, neu1 = thread_private_mem @@ -548,7 +793,7 @@ def score(self, sentences, total_sentences=int(1e6), chunksize=100, queue_factor "scoring sentences with %i workers on %i vocabulary and %i features, " "using sg=%s hs=%s sample=%s and negative=%s", self.workers, len(self.wv.vocab), self.trainables.layer1_size, self.sg, self.hs, - self.vocabulary.sample, self.negative + self.sample, self.negative ) if not self.wv.vocab: @@ -771,7 +1016,7 @@ def reset_from(self, other_model): """ self.wv.vocab = other_model.wv.vocab self.wv.index2key = other_model.wv.index2key - self.vocabulary.cum_table = other_model.vocabulary.cum_table + self.cum_table = other_model.cum_table self.corpus_count = other_model.corpus_count self.trainables.reset_weights(self.hs, self.negative, self.wv) @@ -837,12 +1082,13 @@ def load(cls, *args, **kwargs): """ try: model = super(Word2Vec, cls).load(*args, **kwargs) - - # for backward compatibility for `max_final_vocab` feature + # for backward compatibility if not hasattr(model, 'max_final_vocab'): model.max_final_vocab = None - model.vocabulary.max_final_vocab = None - + if hasattr(model, 'vocabulary'): # re-integrate state that had been moved + for a in ('max_vocab_size', 'min_count', 'sample', 'sorted_vocab', 'null_word', 'raw_vocab'): + setattr(model, a, getattr(model.vocabulary, a)) + del model.vocabulary return model except AttributeError: logger.info('Model saved using code from earlier Gensim Version. Re-loading old model in a compatible way.') @@ -1077,269 +1323,8 @@ def __lt__(self, other): class Word2VecVocab(utils.SaveLoad): - def __init__( - self, max_vocab_size=None, min_count=5, sample=1e-3, sorted_vocab=True, null_word=0, - max_final_vocab=None, ns_exponent=0.75): - """Vocabulary used by :class:`~gensim.models.word2vec.Word2Vec`.""" - self.max_vocab_size = max_vocab_size - self.min_count = min_count - self.sample = sample - self.sorted_vocab = sorted_vocab - self.null_word = null_word - self.cum_table = None # for negative sampling - self.raw_vocab = None - self.max_final_vocab = max_final_vocab - self.ns_exponent = ns_exponent - - def _scan_vocab(self, sentences, progress_per, trim_rule): - sentence_no = -1 - total_words = 0 - min_reduce = 1 - vocab = defaultdict(int) - checked_string_types = 0 - for sentence_no, sentence in enumerate(sentences): - if not checked_string_types: - if isinstance(sentence, string_types): - logger.warning( - "Each 'sentences' item should be a list of words (usually unicode strings). 
" - "First item here is instead plain %s.", - type(sentence) - ) - checked_string_types += 1 - if sentence_no % progress_per == 0: - logger.info( - "PROGRESS: at sentence #%i, processed %i words, keeping %i word types", - sentence_no, total_words, len(vocab) - ) - for word in sentence: - vocab[word] += 1 - total_words += len(sentence) - - if self.max_vocab_size and len(vocab) > self.max_vocab_size: - utils.prune_vocab(vocab, min_reduce, trim_rule=trim_rule) - min_reduce += 1 - - corpus_count = sentence_no + 1 - self.raw_vocab = vocab - return total_words, corpus_count - - def scan_vocab(self, sentences=None, corpus_file=None, progress_per=10000, workers=None, trim_rule=None): - logger.info("collecting all words and their counts") - if corpus_file: - sentences = LineSentence(corpus_file) - - total_words, corpus_count = self._scan_vocab(sentences, progress_per, trim_rule) - - logger.info( - "collected %i word types from a corpus of %i raw words and %i sentences", - len(self.raw_vocab), total_words, corpus_count - ) - - return total_words, corpus_count - - def sort_vocab(self, wv): - """Sort the vocabulary so the most frequent words have the lowest indexes.""" - if len(wv.vectors): - raise RuntimeError("cannot sort vocabulary after model weights already initialized.") - wv.index2key.sort(key=lambda word: wv.vocab[word].count, reverse=True) - for i, word in enumerate(wv.index2key): - wv.vocab[word].index = i - - def prepare_vocab( - self, hs, negative, wv, update=False, keep_raw_vocab=False, trim_rule=None, - min_count=None, sample=None, dry_run=False): - """Apply vocabulary settings for `min_count` (discarding less-frequent words) - and `sample` (controlling the downsampling of more-frequent words). - - Calling with `dry_run=True` will only simulate the provided settings and - report the size of the retained vocabulary, effective corpus length, and - estimated memory requirements. Results are both printed via logging and - returned as a dict. - - Delete the raw vocabulary after the scaling is done to free up RAM, - unless `keep_raw_vocab` is set. 
- - """ - min_count = min_count or self.min_count - sample = sample or self.sample - drop_total = drop_unique = 0 - - # set effective_min_count to min_count in case max_final_vocab isn't set - self.effective_min_count = min_count - - # if max_final_vocab is specified instead of min_count - # pick a min_count which satisfies max_final_vocab as well as possible - if self.max_final_vocab is not None: - sorted_vocab = sorted(self.raw_vocab.keys(), key=lambda word: self.raw_vocab[word], reverse=True) - calc_min_count = 1 - - if self.max_final_vocab < len(sorted_vocab): - calc_min_count = self.raw_vocab[sorted_vocab[self.max_final_vocab]] + 1 - - self.effective_min_count = max(calc_min_count, min_count) - logger.info( - "max_final_vocab=%d and min_count=%d resulted in calc_min_count=%d, effective_min_count=%d", - self.max_final_vocab, min_count, calc_min_count, self.effective_min_count - ) - - if not update: - logger.info("Loading a fresh vocabulary") - retain_total, retain_words = 0, [] - # Discard words less-frequent than min_count - if not dry_run: - wv.index2key = [] - # make stored settings match these applied settings - self.min_count = min_count - self.sample = sample - wv.vocab = {} - - for word, v in iteritems(self.raw_vocab): - if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule): - retain_words.append(word) - retain_total += v - if not dry_run: - wv.vocab[word] = W2VVocab(count=v, index=len(wv.index2key)) - wv.index2key.append(word) - else: - drop_unique += 1 - drop_total += v - original_unique_total = len(retain_words) + drop_unique - retain_unique_pct = len(retain_words) * 100 / max(original_unique_total, 1) - logger.info( - "effective_min_count=%d retains %i unique words (%i%% of original %i, drops %i)", - self.effective_min_count, len(retain_words), retain_unique_pct, original_unique_total, drop_unique - ) - original_total = retain_total + drop_total - retain_pct = retain_total * 100 / max(original_total, 1) - logger.info( - "effective_min_count=%d leaves %i word corpus (%i%% of original %i, drops %i)", - self.effective_min_count, retain_total, retain_pct, original_total, drop_total - ) - else: - logger.info("Updating model with new vocabulary") - new_total = pre_exist_total = 0 - new_words = pre_exist_words = [] - for word, v in iteritems(self.raw_vocab): - if keep_vocab_item(word, v, self.effective_min_count, trim_rule=trim_rule): - if word in wv.vocab: - pre_exist_words.append(word) - pre_exist_total += v - if not dry_run: - wv.vocab[word].count += v - else: - new_words.append(word) - new_total += v - if not dry_run: - wv.vocab[word] = W2VVocab(count=v, index=len(wv.index2key)) - wv.index2key.append(word) - else: - drop_unique += 1 - drop_total += v - original_unique_total = len(pre_exist_words) + len(new_words) + drop_unique - pre_exist_unique_pct = len(pre_exist_words) * 100 / max(original_unique_total, 1) - new_unique_pct = len(new_words) * 100 / max(original_unique_total, 1) - logger.info( - "New added %i unique words (%i%% of original %i) " - "and increased the count of %i pre-existing words (%i%% of original %i)", - len(new_words), new_unique_pct, original_unique_total, len(pre_exist_words), - pre_exist_unique_pct, original_unique_total - ) - retain_words = new_words + pre_exist_words - retain_total = new_total + pre_exist_total - - # Precalculate each vocabulary item's threshold for sampling - if not sample: - # no words downsampled - threshold_count = retain_total - elif sample < 1.0: - # traditional meaning: set parameter as proportion of total - 
threshold_count = sample * retain_total - else: - # new shorthand: sample >= 1 means downsample all words with higher count than sample - threshold_count = int(sample * (3 + sqrt(5)) / 2) - - downsample_total, downsample_unique = 0, 0 - for w in retain_words: - v = self.raw_vocab[w] - word_probability = (sqrt(v / threshold_count) + 1) * (threshold_count / v) - if word_probability < 1.0: - downsample_unique += 1 - downsample_total += word_probability * v - else: - word_probability = 1.0 - downsample_total += v - if not dry_run: - wv.vocab[w].sample_int = int(round(word_probability * 2**32)) - - if not dry_run and not keep_raw_vocab: - logger.info("deleting the raw counts dictionary of %i items", len(self.raw_vocab)) - self.raw_vocab = defaultdict(int) - - logger.info("sample=%g downsamples %i most-common words", sample, downsample_unique) - logger.info( - "downsampling leaves estimated %i word corpus (%.1f%% of prior %i)", - downsample_total, downsample_total * 100.0 / max(retain_total, 1), retain_total - ) - - # return from each step: words-affected, resulting-corpus-size, extra memory estimates - report_values = { - 'drop_unique': drop_unique, 'retain_total': retain_total, 'downsample_unique': downsample_unique, - 'downsample_total': int(downsample_total), 'num_retained_words': len(retain_words) - } - - if self.null_word: - # create null pseudo-word for padding when using concatenative L1 (run-of-words) - # this word is only ever input – never predicted – so count, huffman-point, etc doesn't matter - self.add_null_word(wv) - - if self.sorted_vocab and not update: - self.sort_vocab(wv) - if hs: - # add info about each word's Huffman encoding - self.create_binary_tree(wv) - if negative: - # build the table for drawing random words (for negative sampling) - self.make_cum_table(wv) - - return report_values - - def add_null_word(self, wv): - word, v = '\0', W2VVocab(count=1, sample_int=0) - v.index = len(wv.vocab) - wv.index2key.append(word) - wv.vocab[word] = v - - def create_binary_tree(self, wv): - """Create a `binary Huffman tree `_ using stored vocabulary - word counts. Frequent words will have shorter binary codes. - Called internally from :meth:`~gensim.models.word2vec.Word2VecVocab.build_vocab`. - - """ - _assign_binary_codes(wv.vocab) - - def make_cum_table(self, wv, domain=2**31 - 1): - """Create a cumulative-distribution table using stored vocabulary word counts for - drawing random words in the negative-sampling training routines. - - To draw a word index, choose a random integer up to the maximum value in the table (cum_table[-1]), - then finding that integer's sorted insertion point (as if by `bisect_left` or `ndarray.searchsorted()`). - That insertion point is the drawn index, coming up in proportion equal to the increment at that slot. - - Called internally from :meth:`~gensim.models.word2vec.Word2VecVocab.build_vocab`. 
- - """ - vocab_size = len(wv.index2key) - self.cum_table = zeros(vocab_size, dtype=uint32) - # compute sum of all power (Z in paper) - train_words_pow = 0.0 - for word_index in range(vocab_size): - train_words_pow += wv.vocab[wv.index2key[word_index]].count**self.ns_exponent - cumulative = 0.0 - for word_index in range(vocab_size): - cumulative += wv.vocab[wv.index2key[word_index]].count**self.ns_exponent - self.cum_table[word_index] = round(cumulative / train_words_pow * domain) - if len(self.cum_table) > 0: - assert self.cum_table[-1] == domain + """Obsolete class retained for now as load-compatibility state capture""" + pass class Heapitem(namedtuple('Heapitem', 'count, index, left, right')): diff --git a/gensim/models/word2vec_inner.pyx b/gensim/models/word2vec_inner.pyx index 0576773bd5..776d4b2308 100755 --- a/gensim/models/word2vec_inner.pyx +++ b/gensim/models/word2vec_inner.pyx @@ -467,7 +467,7 @@ cdef unsigned long long w2v_fast_sentence_cbow_neg( cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1=None): c[0].hs = model.hs c[0].negative = model.negative - c[0].sample = (model.vocabulary.sample != 0) + c[0].sample = (model.sample != 0) c[0].cbow_mean = model.cbow_mean c[0].window = model.window c[0].workers = model.workers @@ -485,8 +485,8 @@ cdef init_w2v_config(Word2VecConfig *c, model, alpha, compute_loss, _work, _neu1 if c[0].negative: c[0].syn1neg = (np.PyArray_DATA(model.trainables.syn1neg)) - c[0].cum_table = (np.PyArray_DATA(model.vocabulary.cum_table)) - c[0].cum_table_len = len(model.vocabulary.cum_table) + c[0].cum_table = (np.PyArray_DATA(model.cum_table)) + c[0].cum_table_len = len(model.cum_table) if c[0].negative or c[0].sample: c[0].next_random = (2**24) * model.random.randint(0, 2**24) + model.random.randint(0, 2**24) diff --git a/gensim/test/test_word2vec.py b/gensim/test/test_word2vec.py index b610047a84..c2e0900f99 100644 --- a/gensim/test/test_word2vec.py +++ b/gensim/test/test_word2vec.py @@ -143,29 +143,29 @@ def testPruneVocab(self): def testTotalWordCount(self): model = word2vec.Word2Vec(size=10, min_count=0, seed=42) - total_words = model.vocabulary.scan_vocab(sentences)[0] + total_words = model.scan_vocab(sentences)[0] self.assertEqual(total_words, 29) def testMaxFinalVocab(self): # Test for less restricting effect of max_final_vocab # max_final_vocab is specified but has no effect model = word2vec.Word2Vec(size=10, max_final_vocab=4, min_count=4, sample=0) - model.vocabulary.scan_vocab(sentences) - reported_values = model.vocabulary.prepare_vocab(wv=model.wv, hs=0, negative=0) + model.scan_vocab(sentences) + reported_values = model.prepare_vocab() self.assertEqual(reported_values['drop_unique'], 11) self.assertEqual(reported_values['retain_total'], 4) self.assertEqual(reported_values['num_retained_words'], 1) - self.assertEqual(model.vocabulary.effective_min_count, 4) + self.assertEqual(model.effective_min_count, 4) # Test for more restricting effect of max_final_vocab # results in setting a min_count more restricting than specified min_count model = word2vec.Word2Vec(size=10, max_final_vocab=4, min_count=2, sample=0) - model.vocabulary.scan_vocab(sentences) - reported_values = model.vocabulary.prepare_vocab(wv=model.wv, hs=0, negative=0) + model.scan_vocab(sentences) + reported_values = model.prepare_vocab() self.assertEqual(reported_values['drop_unique'], 8) self.assertEqual(reported_values['retain_total'], 13) self.assertEqual(reported_values['num_retained_words'], 4) - 
self.assertEqual(model.vocabulary.effective_min_count, 3) + self.assertEqual(model.effective_min_count, 3) def testOnlineLearning(self): """Test that the algorithm is able to add new words to the @@ -873,7 +873,7 @@ def testLoadOldModel(self): self.assertTrue(len(model.wv.index2word) == 12) self.assertTrue(model.trainables.syn1neg.shape == (len(model.wv.vocab), model.wv.vector_size)) self.assertTrue(model.trainables.vectors_lockf.shape == (12,)) - self.assertTrue(model.vocabulary.cum_table.shape == (12,)) + self.assertTrue(model.cum_table.shape == (12,)) self.onlineSanity(model, trained_model=True) @@ -888,7 +888,7 @@ def testLoadOldModelSeparates(self): self.assertTrue(len(model.wv.index2word) == 12) self.assertTrue(model.trainables.syn1neg.shape == (len(model.wv.vocab), model.wv.vector_size)) self.assertTrue(model.trainables.vectors_lockf.shape == (12,)) - self.assertTrue(model.vocabulary.cum_table.shape == (12,)) + self.assertTrue(model.cum_table.shape == (12,)) self.onlineSanity(model, trained_model=True) @@ -934,7 +934,7 @@ def test_load_old_models_3_x(self): model_file = 'word2vec_3.3' model = word2vec.Word2Vec.load(datapath(model_file)) self.assertEqual(model.max_final_vocab, None) - self.assertEqual(model.vocabulary.max_final_vocab, None) + self.assertEqual(model.max_final_vocab, None) old_versions = [ '3.0.0', '3.1.0', '3.2.0', '3.3.0', '3.4.0'