diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py
index 70bc76b641..40cbdc787a 100644
--- a/gensim/models/keyedvectors.py
+++ b/gensim/models/keyedvectors.py
@@ -215,7 +215,7 @@ def __str__(self):
 class BaseKeyedVectors(utils.SaveLoad):
     """Abstract base class / interface for various types of word vectors."""
     def __init__(self, vector_size):
-        self.vectors = zeros((0, vector_size))
+        self.vectors = zeros((0, vector_size), dtype=REAL)
         self.vocab = {}
         self.vector_size = vector_size
         self.index2entity = []
@@ -308,7 +308,7 @@ def add(self, entities, weights, replace=False):
                 self.index2entity.append(entity)
 
         # add vectors for new entities
-        self.vectors = vstack((self.vectors, weights[~in_vocab_mask]))
+        self.vectors = vstack((self.vectors, weights[~in_vocab_mask].astype(self.vectors.dtype)))
 
         # change vectors for in_vocab entities if `replace` flag is specified
         if replace:
@@ -2113,7 +2113,7 @@ def word_vec(self, word, use_norm=False):
         elif self.bucket == 0:
             raise KeyError('cannot calculate vector for OOV word without ngrams')
         else:
-            word_vec = np.zeros(self.vectors_ngrams.shape[1], dtype=np.float32)
+            word_vec = np.zeros(self.vectors_ngrams.shape[1], dtype=REAL)
             ngram_hashes = ft_ngram_hashes(word, self.min_n, self.max_n, self.bucket, self.compatible_hash)
             if len(ngram_hashes) == 0:
                 #
diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py
index 3005dee6da..3b5ae532ad 100644
--- a/gensim/test/test_keyedvectors.py
+++ b/gensim/test/test_keyedvectors.py
@@ -15,8 +15,8 @@
 import numpy as np
 
 from gensim.corpora import Dictionary
-from gensim.models.keyedvectors import KeyedVectors as EuclideanKeyedVectors, WordEmbeddingSimilarityIndex, \
-    FastTextKeyedVectors
+from gensim.models.keyedvectors import KeyedVectors, WordEmbeddingSimilarityIndex, \
+    FastTextKeyedVectors, REAL
 from gensim.test.utils import datapath
 
 import gensim.models.keyedvectors
@@ -27,7 +27,7 @@
 
 class TestWordEmbeddingSimilarityIndex(unittest.TestCase):
     def setUp(self):
-        self.vectors = EuclideanKeyedVectors.load_word2vec_format(
+        self.vectors = KeyedVectors.load_word2vec_format(
             datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)
 
     def test_most_similar(self):
@@ -70,9 +70,9 @@ def test_most_similar(self):
         self.assertTrue(np.allclose(first_similarities**2.0, second_similarities))
 
 
-class TestEuclideanKeyedVectors(unittest.TestCase):
+class TestKeyedVectors(unittest.TestCase):
     def setUp(self):
-        self.vectors = EuclideanKeyedVectors.load_word2vec_format(
+        self.vectors = KeyedVectors.load_word2vec_format(
             datapath('euclidean_vectors.bin'), binary=True, datatype=np.float64)
 
     def test_similarity_matrix(self):
@@ -227,7 +227,7 @@ def test_add_single(self):
             self.assertTrue(np.allclose(self.vectors[ent], vector))
 
         # Test `add` on empty kv.
-        kv = EuclideanKeyedVectors(self.vectors.vector_size)
+        kv = KeyedVectors(self.vectors.vector_size)
         for ent, vector in zip(entities, vectors):
             kv.add(ent, vector)
 
@@ -248,13 +248,23 @@ def test_add_multiple(self):
             self.assertTrue(np.allclose(self.vectors[ent], vector))
 
         # Test `add` on empty kv.
-        kv = EuclideanKeyedVectors(self.vectors.vector_size)
+        kv = KeyedVectors(self.vectors.vector_size)
         kv[entities] = vectors
         self.assertEqual(len(kv.vocab), len(entities))
 
         for ent, vector in zip(entities, vectors):
             self.assertTrue(np.allclose(kv[ent], vector))
 
+    def test_add_type(self):
+        """Test that `add` keeps `vectors` in the REAL dtype when given float64 weights."""
+        kv = KeyedVectors(2)
+        self.assertEqual(kv.vectors.dtype, REAL)
+
+        words, vectors = ["a"], np.array([1., 1.], dtype=np.float64).reshape(1, -1)
+        kv.add(words, vectors)
+
+        self.assertEqual(kv.vectors.dtype, REAL)
+
     def test_set_item(self):
         """Test that __setitem__ works correctly."""
         vocab_size = len(self.vectors.vocab)
@@ -287,7 +297,7 @@ def test_set_item(self):
             self.assertTrue(np.allclose(self.vectors[ent], vector))
 
     def test_ft_kv_backward_compat_w_360(self):
-        kv = EuclideanKeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))
+        kv = KeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))
         ft_kv = FastTextKeyedVectors.load(datapath("ft_kv_3.6.0.model.gz"))
 
         expected = ['trees', 'survey', 'system', 'graph', 'interface']