Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up word2vec model loading #2671

Merged
merged 6 commits into from
Nov 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 84 additions & 48 deletions gensim/models/utils_any2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
#
# Author: Shiva Manne <s.manne@rare-technologies.com>
# Copyright (C) 2018 RaRe Technologies s.r.o.
# Copyright (C) 2019 RaRe Technologies s.r.o.

"""General functions used for any2vec models.

Expand All @@ -28,7 +28,7 @@
import logging
from gensim import utils

from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, fromstring
from numpy import zeros, dtype, float32 as REAL, ascontiguousarray, frombuffer

from six.moves import range
from six import iteritems, PY2
Expand Down Expand Up @@ -146,8 +146,83 @@ def _save_word2vec_format(fname, vocab, vectors, fvocab=None, binary=False, tota
fout.write(utils.to_utf8("%s %s\n" % (word, ' '.join(repr(val) for val in row))))


# Functions for internal use by _load_word2vec_format function


def _add_word_to_result(result, counts, word, weights, vocab_size):
    """Register one word with its vector in ``result``, ignoring duplicates.

    The word's frequency comes from ``counts`` when available; otherwise a
    decreasing bogus count is fabricated so word order still reflects rank.
    """
    # Imported locally to avoid a circular import at module load time.
    from gensim.models.keyedvectors import Vocab

    if word in result.vocab:
        logger.warning("duplicate word '%s' in word2vec file, ignoring all but first", word)
        return
    index = len(result.vocab)

    if counts is None:
        # Most common scenario: no vocab file given. Fabricate descending counts.
        frequency = vocab_size - index
    elif word in counts:
        # Use the count from the supplied vocab file.
        frequency = counts[word]
    else:
        logger.warning("vocabulary file is incomplete: '%s' is missing", word)
        frequency = None

    result.vocab[word] = Vocab(index=index, count=frequency)
    result.vectors[index] = weights
    result.index2word.append(word)


def _add_bytes_to_result(result, counts, chunk, vocab_size, vector_size, datatype, unicode_errors):
    """Consume as many complete (word, vector) records from ``chunk`` as possible.

    Returns the number of words parsed and the unconsumed tail of ``chunk``
    (a partial record left for the caller to complete with more data).
    """
    vector_bytes = vector_size * dtype(REAL).itemsize
    remaining = vocab_size - len(result.vocab)
    cursor = 0
    parsed = 0
    while parsed < remaining:
        space_at = chunk.find(b' ', cursor)
        vec_start = space_at + 1

        # Stop once the chunk no longer holds a full word plus a full vector.
        if space_at == -1 or len(chunk) - vec_start < vector_bytes:
            break

        word = chunk[cursor:space_at].decode("utf-8", errors=unicode_errors)
        # Some binary files are reported to have obsolete new line in the beginning of word, remove it
        word = word.lstrip('\n')
        weights = frombuffer(chunk, offset=vec_start, count=vector_size, dtype=REAL).astype(datatype)
        _add_word_to_result(result, counts, word, weights, vocab_size)
        cursor = vec_start + vector_bytes
        parsed += 1

    return parsed, chunk[cursor:]


def _word2vec_read_binary(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size):
    """Populate ``result`` from a word2vec binary stream, reading it in chunks.

    Raises EOFError if the stream ends before ``vocab_size`` words are read.
    """
    buffer = b''
    words_read = 0

    while words_read < vocab_size:
        data = fin.read(binary_chunk_size)
        buffer += data
        n_parsed, buffer = _add_bytes_to_result(
            result, counts, buffer, vocab_size, vector_size, datatype, unicode_errors)
        words_read += n_parsed
        # A short read means the underlying stream is exhausted.
        if len(data) < binary_chunk_size:
            break
    if words_read != vocab_size:
        raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")


def _word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding):
    """Populate ``result`` from a word2vec text stream, one word-vector per line.

    Raises EOFError on premature end of input and ValueError on a malformed line.
    """
    for line_no in range(vocab_size):
        raw = fin.readline()
        if raw == b'':
            raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
        tokens = utils.to_unicode(raw.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
        if len(tokens) != vector_size + 1:
            raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
        word = tokens[0]
        weights = [datatype(token) for token in tokens[1:]]
        _add_word_to_result(result, counts, word, weights, vocab_size)


def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8', unicode_errors='strict',
limit=None, datatype=REAL):
limit=None, datatype=REAL, binary_chunk_size=100 * 1024):
"""Load the input-hidden weight matrix from the original C word2vec-tool format.

Note that the information stored in the file is incomplete (the binary tree is missing),
Expand Down Expand Up @@ -176,14 +251,16 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
datatype : type, optional
(Experimental) Can coerce dimensions to a non-default float type (such as `np.float16`) to save memory.
Such types may result in much slower bulk operations or incompatibility with optimized routines.)
binary_chunk_size : int, optional
Read input file in chunks of this many bytes for performance reasons.

Returns
-------
object
Returns the loaded model as an instance of :class:`cls`.

"""
from gensim.models.keyedvectors import Vocab

counts = None
if fvocab is not None:
logger.info("loading word counts from %s", fvocab)
Expand All @@ -203,52 +280,11 @@ def _load_word2vec_format(cls, fname, fvocab=None, binary=False, encoding='utf8'
result.vector_size = vector_size
result.vectors = zeros((vocab_size, vector_size), dtype=datatype)

def add_word(word, weights):
word_id = len(result.vocab)
if word in result.vocab:
logger.warning("duplicate word '%s' in %s, ignoring all but first", word, fname)
return
if counts is None:
# most common scenario: no vocab file given. just make up some bogus counts, in descending order
result.vocab[word] = Vocab(index=word_id, count=vocab_size - word_id)
elif word in counts:
# use count from the vocab file
result.vocab[word] = Vocab(index=word_id, count=counts[word])
else:
# vocab file given, but word is missing -- set count to None (TODO: or raise?)
logger.warning("vocabulary file is incomplete: '%s' is missing", word)
result.vocab[word] = Vocab(index=word_id, count=None)
result.vectors[word_id] = weights
result.index2word.append(word)

if binary:
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
binary_len = dtype(REAL).itemsize * vector_size
for _ in range(vocab_size):
# mixed text and binary: read text first, then binary
word = []
while True:
ch = fin.read(1) # Python uses I/O buffering internally
if ch == b' ':
break
if ch == b'':
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
if ch != b'\n': # ignore newlines in front of words (some binary files have)
word.append(ch)
word = utils.to_unicode(b''.join(word), encoding=encoding, errors=unicode_errors)
with utils.ignore_deprecation_warning():
# TODO use frombuffer or something similar
weights = fromstring(fin.read(binary_len), dtype=REAL).astype(datatype)
add_word(word, weights)
_word2vec_read_binary(fin, result, counts,
vocab_size, vector_size, datatype, unicode_errors, binary_chunk_size)
else:
for line_no in range(vocab_size):
line = fin.readline()
if line == b'':
raise EOFError("unexpected end of input; is count incorrect or file otherwise damaged?")
parts = utils.to_unicode(line.rstrip(), encoding=encoding, errors=unicode_errors).split(" ")
if len(parts) != vector_size + 1:
raise ValueError("invalid vector on line %s (is this really the text format?)" % line_no)
word, weights = parts[0], [datatype(x) for x in parts[1:]]
add_word(word, weights)
_word2vec_read_text(fin, result, counts, vocab_size, vector_size, datatype, unicode_errors, encoding)
if result.vectors.shape[0] != len(result.vocab):
logger.info(
"duplicate words detected, shrinking matrix size from %i to %i",
Expand Down
122 changes: 122 additions & 0 deletions gensim/test/test_utils_any2vec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (C) 2017 Radim Rehurek <me@radimrehurek.com>
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
# Licensed under the GNU LGPL v2.1 - http://www.gnu.org/licenses/lgpl.html

"""
Automated tests for checking utils_any2vec functionality.
"""

import logging
import unittest

import numpy as np

import gensim.utils
import gensim.test.utils

import gensim.models.utils_any2vec


logger = logging.getLogger(__name__)


def save_dict_to_word2vec_formated_file(fname, word2vec_dict):
    """Write ``word2vec_dict`` (word -> vector) to ``fname`` in word2vec binary format."""
    with gensim.utils.open(fname, "bw") as fout:

        vocab_size = len(word2vec_dict)
        vector_size = len(next(iter(word2vec_dict.values())))

        # Header line: "<num_words> <vector_length>\n", ASCII-encoded.
        fout.write(("%d %d\n" % (vocab_size, vector_size)).encode(encoding="ascii"))

        for word, vector in word2vec_dict.items():
            fout.write(word.encode())
            fout.write(b' ')
            # Vectors are stored as raw little-endian float32 bytes.
            fout.write(np.array(vector).astype(np.float32).tobytes())


class LoadWord2VecFormatTest(unittest.TestCase):
    """Round-trip tests for ``_load_word2vec_format`` on binary word2vec files."""

    def assert_dict_equal_to_model(self, d, m):
        # The model must contain exactly the words of ``d`` with equal vectors.
        self.assertEqual(len(d), len(m.vocab))
        for word in d.keys():
            self.assertSequenceEqual(list(d[word]), list(m[word]))

    def verify_load2vec_binary_result(self, w2v_dict, binary_chunk_size, limit):
        # Save ``w2v_dict`` to a temp file, load it back, and compare.
        path = gensim.test.utils.get_tmpfile("tmp_w2v")
        save_dict_to_word2vec_formated_file(path, w2v_dict)
        model = gensim.models.utils_any2vec._load_word2vec_format(
            cls=gensim.models.KeyedVectors,
            fname=path,
            binary=True,
            limit=limit,
            binary_chunk_size=binary_chunk_size)
        if limit is None:
            limit = len(w2v_dict)

        # The loader strips leading whitespace from words and honours ``limit``.
        kept_words = list(w2v_dict.keys())[:limit]
        expected = {word.lstrip(): w2v_dict[word] for word in kept_words}

        self.assert_dict_equal_to_model(expected, model)

    def test_load_word2vec_format_basic(self):
        # Equal-length and varying-length words, across chunk sizes that are
        # smaller than, comparable to, and larger than one record.
        for w2v_dict in (
                {"abc": [1, 2, 3],
                 "cde": [4, 5, 6],
                 "def": [7, 8, 9]},
                {"abc": [1, 2, 3],
                 "cdefg": [4, 5, 6],
                 "d": [7, 8, 9]},
        ):
            for chunk_size in (5, 16, 1024):
                self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=chunk_size, limit=None)

    def test_load_word2vec_format_limit(self):
        # Same dictionaries as the basic test, but loading only a prefix.
        for w2v_dict in (
                {"abc": [1, 2, 3],
                 "cde": [4, 5, 6],
                 "def": [7, 8, 9]},
        ):
            for limit in (1, 2):
                for chunk_size in (5, 16, 1024):
                    self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=chunk_size, limit=limit)

        for w2v_dict in (
                {"abc": [1, 2, 3],
                 "cdefg": [4, 5, 6],
                 "d": [7, 8, 9]},
        ):
            for limit in (1, 2):
                for chunk_size in (5, 16, 1024):
                    self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=chunk_size, limit=limit)

    def test_load_word2vec_format_space_stripping(self):
        # Leading newlines in words must be stripped on load.
        w2v_dict = {"\nabc": [1, 2, 3],
                    "cdefdg": [4, 5, 6],
                    "\n\ndef": [7, 8, 9]}
        for limit in (None, 1):
            self.verify_load2vec_binary_result(w2v_dict, binary_chunk_size=5, limit=limit)


if __name__ == '__main__':
    # Run the tests directly with verbose logging for easier debugging.
    logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.DEBUG)
    unittest.main()