From 10d72d7cb1825bff1b014178967a7769e46dd374 Mon Sep 17 00:00:00 2001 From: Simon Wiles Date: Fri, 13 Aug 2021 05:00:09 -0700 Subject: [PATCH] Tidy up KeyedVectors.most_similar() API (#3000) * Allow supplying a string-key as the negative arg. to most_similar() * Allow a single vector as a positive or negative arg. to most_similar() * Update comments * Accept single arguments when positive and negative are both supplied * Update most_similar_cosmul to match most_similar I'm not sure if this fully addresses the `# TODO: Update to better match & share code with most_similar()` at line #981 or not, so I've left it in. * minor code cleanup * add unit tests * Update CHANGELOG.md * remove redundant variable declaration * enforce consistency * respond to review feedback * Update keyedvectors.py Co-authored-by: Michael Penkov Co-authored-by: Michael Penkov --- CHANGELOG.md | 5 ++- gensim/models/keyedvectors.py | 54 +++++++++++++++++--------------- gensim/test/test_keyedvectors.py | 39 +++++++++++++++++++++++ 3 files changed, 70 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c23172893e..0bf7acb7c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,6 @@ Nevertheless, we describe them below. ### Improved parameter edge-case handling in KeyedVectors most_similar and most_similar_cosmul methods We now handle both ``positive`` and ``negative`` keyword parameters consistently. -These parameters typically specify They may now be either: 1. A string, in which case the value is reinterpreted as a list of one element (the string value) @@ -28,7 +27,7 @@ So you can now simply do: ```python model.most_similar(positive='war', negative='peace') ``` - + instead of the slightly more involved ```python @@ -73,7 +72,7 @@ Plus a large number of smaller improvements and fixes, as usual. * [#3091](https://github.com/RaRe-Technologies/gensim/pull/3091): LsiModel: Only log top words that actually exist in the dictionary, by [@kmurphy4](https://github.com/kmurphy4) * [#2980](https://github.com/RaRe-Technologies/gensim/pull/2980): Added EnsembleLda for stable LDA topics, by [@sezanzeb](https://github.com/sezanzeb) * [#2978](https://github.com/RaRe-Technologies/gensim/pull/2978): Optimize performance of Author-Topic model, by [@horpto](https://github.com/horpto) - +* [#3000](https://github.com/RaRe-Technologies/gensim/pull/3000): Tidy up KeyedVectors.most_similar() API, by [@simonwiles](https://github.com/simonwiles) ### :books: Tutorials and docs diff --git a/gensim/models/keyedvectors.py b/gensim/models/keyedvectors.py index b4620abb81..b5debb21c1 100644 --- a/gensim/models/keyedvectors.py +++ b/gensim/models/keyedvectors.py @@ -189,7 +189,21 @@ logger = logging.getLogger(__name__) -KEY_TYPES = (str, int, np.integer) +_KEY_TYPES = (str, int, np.integer) + +_EXTENDED_KEY_TYPES = (str, int, np.integer, np.ndarray) + + +def _ensure_list(value): + """Ensure that the specified value is wrapped in a list, for those supported cases + where we also accept a single key or vector.""" + if value is None: + return [] + + if isinstance(value, _KEY_TYPES) or (isinstance(value, ndarray) and len(value.shape) == 1): + return [value] + + return value class KeyedVectors(utils.SaveLoad): @@ -377,7 +391,7 @@ def __getitem__(self, key_or_keys): Vector representation for `key_or_keys` (1D if `key_or_keys` is single key, otherwise - 2D). """ - if isinstance(key_or_keys, KEY_TYPES): + if isinstance(key_or_keys, _KEY_TYPES): return self.get_vector(key_or_keys) return vstack([self.get_vector(key) for key in key_or_keys]) @@ -491,7 +505,7 @@ def add_vectors(self, keys, weights, extras=None, replace=False): if True - replace vectors, otherwise - keep old vectors. """ - if isinstance(keys, KEY_TYPES): + if isinstance(keys, _KEY_TYPES): keys = [keys] weights = np.array(weights).reshape(1, -1) elif isinstance(weights, list): @@ -729,10 +743,9 @@ def most_similar( if isinstance(topn, Integral) and topn < 1: return [] - if positive is None: - positive = [] - if negative is None: - negative = [] + # allow passing a single string-key or vector for the positive/negative arguments + positive = _ensure_list(positive) + negative = _ensure_list(negative) self.fill_norms() clip_end = clip_end or len(self.vectors) @@ -741,18 +754,14 @@ def most_similar( clip_start = 0 clip_end = restrict_vocab - if isinstance(positive, KEY_TYPES) and not negative: - # allow calls like most_similar('dog'), as a shorthand for most_similar(['dog']) - positive = [positive] - # add weights for each key, if not already present; default to 1.0 for positive and -1.0 for negative keys positive = [ - (item, 1.0) if isinstance(item, KEY_TYPES + (ndarray,)) - else item for item in positive + (item, 1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item + for item in positive ] negative = [ - (item, -1.0) if isinstance(item, KEY_TYPES + (ndarray,)) - else item for item in negative + (item, -1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item + for item in negative ] # compute the weighted average of all keys @@ -969,21 +978,16 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10): if isinstance(topn, Integral) and topn < 1: return [] - if positive is None: - positive = [] - if negative is None: - negative = [] + # allow passing a single string-key or vector for the positive/negative arguments + positive = _ensure_list(positive) + negative = _ensure_list(negative) self.fill_norms() - if isinstance(positive, str) and not negative: - # allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog']) - positive = [positive] - all_words = { self.get_index(word) for word in positive + negative if not isinstance(word, ndarray) and word in self.key_to_index - } + } positive = [ self.get_vector(word, norm=True) if isinstance(word, str) else word @@ -1101,7 +1105,7 @@ def distances(self, word_or_vector, other_words=()): If either `word_or_vector` or any word in `other_words` is absent from vocab. """ - if isinstance(word_or_vector, KEY_TYPES): + if isinstance(word_or_vector, _KEY_TYPES): input_vector = self.get_vector(word_or_vector) else: input_vector = word_or_vector diff --git a/gensim/test/test_keyedvectors.py b/gensim/test/test_keyedvectors.py index d16b6bd9df..d5eda547ea 100644 --- a/gensim/test/test_keyedvectors.py +++ b/gensim/test/test_keyedvectors.py @@ -9,6 +9,7 @@ Automated tests for checking the poincare module from the models package. """ +import functools import logging import unittest @@ -39,6 +40,44 @@ def test_most_similar(self): predicted = [result[0] for result in self.vectors.most_similar('war', topn=5)] self.assertEqual(expected, predicted) + def test_most_similar_vector(self): + """Can we pass vectors to most_similar directly?""" + positive = self.vectors.vectors[0:5] + most_similar = self.vectors.most_similar(positive=positive) + assert most_similar is not None + + def test_most_similar_parameter_types(self): + """Are the positive/negative parameter types are getting interpreted correctly?""" + partial = functools.partial(self.vectors.most_similar, topn=5) + + position = partial('war', 'peace') + position_list = partial(['war'], ['peace']) + keyword = partial(positive='war', negative='peace') + keyword_list = partial(positive=['war'], negative=['peace']) + + # + # The above calls should all yield identical results. + # + assert position == position_list + assert position == keyword + assert position == keyword_list + + def test_most_similar_cosmul_parameter_types(self): + """Are the positive/negative parameter types are getting interpreted correctly?""" + partial = functools.partial(self.vectors.most_similar_cosmul, topn=5) + + position = partial('war', 'peace') + position_list = partial(['war'], ['peace']) + keyword = partial(positive='war', negative='peace') + keyword_list = partial(positive=['war'], negative=['peace']) + + # + # The above calls should all yield identical results. + # + assert position == position_list + assert position == keyword + assert position == keyword_list + def test_vectors_for_all_list(self): """Test vectors_for_all returns expected results with a list of keys.""" words = [