Skip to content

Commit

Permalink
Fix KeyedVectors.add matrix type (#2761)
Browse files Browse the repository at this point in the history
* add type test

* cast internal state to passed type

* ekv -> kv

* parametrize datatype & cast embeddings passed to `add` to KV datatype

* set f32 as default type

Co-authored-by: Ivan Menshikh <imenshikh@embedika.ru>
Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
3 people authored Mar 21, 2020
1 parent 493e52f commit 30ca5b3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
6 changes: 3 additions & 3 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def __str__(self):
class BaseKeyedVectors(utils.SaveLoad):
"""Abstract base class / interface for various types of word vectors."""
def __init__(self, vector_size):
self.vectors = zeros((0, vector_size))
self.vectors = zeros((0, vector_size), dtype=REAL)
self.vocab = {}
self.vector_size = vector_size
self.index2entity = []
Expand Down Expand Up @@ -308,7 +308,7 @@ def add(self, entities, weights, replace=False):
self.index2entity.append(entity)

# add vectors for new entities
self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))
self.vectors = vstack((self.vectors, weights[~in_vocab_mask].astype(self.vectors.dtype)))

# change vectors for in_vocab entities if `replace` flag is specified
if replace:
Expand Down Expand Up @@ -2113,7 +2113,7 @@ def word_vec(self, word, use_norm=False):
elif self.bucket == 0:
raise KeyError('cannot calculate vector for OOV word without ngrams')
else:
word_vec = np.zeros(self.vectors_ngrams.shape[1], dtype=np.float32)
word_vec = np.zeros(self.vectors_ngrams.shape[1], dtype=REAL)
ngram_hashes = ft_ngram_hashes(word, self.min_n, self.max_n, self.bucket, self.compatible_hash)
if len(ngram_hashes) == 0:
#
Expand Down
25 changes: 17 additions & 8 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
import numpy as np

from gensim.corpora import Dictionary
from gensim.models.keyedvectors import KeyedVectors as EuclideanKeyedVectors, WordEmbeddingSimilarityIndex, \
FastTextKeyedVectors
from gensim.models.keyedvectors import KeyedVectors, WordEmbeddingSimilarityIndex, \
FastTextKeyedVectors, REAL
from gensim.test.utils import datapath

import gensim.models.keyedvectors
Expand All @@ -27,7 +27,7 @@

class TestWordEmbeddingSimilarityIndex(unittest.TestCase):
def setUp(self):
self.vectors = EuclideanKeyedVectors.load_word2vec_format(
self.vectors = KeyedVectors.load_word2vec_format(
datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)

def test_most_similar(self):
Expand Down Expand Up @@ -70,9 +70,9 @@ def test_most_similar(self):
self.assertTrue(np.allclose(first_similarities**2.0, second_similarities))


class TestEuclideanKeyedVectors(unittest.TestCase):
class TestKeyedVectors(unittest.TestCase):
def setUp(self):
self.vectors = EuclideanKeyedVectors.load_word2vec_format(
self.vectors = KeyedVectors.load_word2vec_format(
datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)

def test_similarity_matrix(self):
Expand Down Expand Up @@ -227,7 +227,7 @@ def test_add_single(self):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
kv = KeyedVectors(self.vectors.vector_size)
for ent, vector in zip(entities, vectors):
kv.add(ent, vector)

Expand All @@ -248,13 +248,22 @@ def test_add_multiple(self):
self.assertTrue(np.allclose(self.vectors[ent], vector))

# Test `add` on empty kv.
kv = EuclideanKeyedVectors(self.vectors.vector_size)
kv = KeyedVectors(self.vectors.vector_size)
kv[entities] = vectors
self.assertEqual(len(kv.vocab), len(entities))

for ent, vector in zip(entities, vectors):
self.assertTrue(np.allclose(kv[ent], vector))

def test_add_type(self):
kv = KeyedVectors(2)
assert kv.vectors.dtype == REAL

words, vectors = ["a"], np.array([1., 1.], dtype=np.float64).reshape(1, -1)
kv.add(words, vectors)

assert kv.vectors.dtype == REAL

def test_set_item(self):
"""Test that __setitem__ works correctly."""
vocab_size = len(self.vectors.vocab)
Expand Down Expand Up @@ -287,7 +296,7 @@ def test_set_item(self):
self.assertTrue(np.allclose(self.vectors[ent], vector))

def test_ft_kv_backward_compat_w_360(self):
kv = EuclideanKeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))
kv = KeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))
ft_kv = FastTextKeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))

expected = ['trees', 'survey', 'system', 'graph', 'interface']
Expand Down

0 comments on commit 30ca5b3

Please sign in to comment.