Skip to content

Commit

Permalink
Tidy up KeyedVectors.most_similar() API (piskvorky#3000)
Browse files Browse the repository at this point in the history
* Allow supplying a string-key as the negative arg. to most_similar()

* Allow a single vector as a positive or negative arg. to most_similar()

* Update comments

* Accept single arguments when positive and negative are both supplied

* Update most_similar_cosmul to match most_similar

I'm not sure if this fully addresses the
`# TODO: Update to better match & share code with most_similar()`
at line piskvorky#981 or not, so I've left it in.

* minor code cleanup

* add unit tests

* Update CHANGELOG.md

* remove redundant variable declaration

* enforce consistency

* respond to review feedback

* Update keyedvectors.py

Co-authored-by: Michael Penkov <misha.penkov@gmail.com>
Co-authored-by: Michael Penkov <m@penkov.dev>
  • Loading branch information
3 people authored and tbbharaj committed Aug 19, 2021
1 parent 6c4cd0e commit 10d72d7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 28 deletions.
5 changes: 2 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Nevertheless, we describe them below.
### Improved parameter edge-case handling in KeyedVectors most_similar and most_similar_cosmul methods

We now handle both ``positive`` and ``negative`` keyword parameters consistently.
These parameters typically specify
They may now be either:

1. A string, in which case the value is reinterpreted as a list of one element (the string value)
Expand All @@ -28,7 +27,7 @@ So you can now simply do:
```python
model.most_similar(positive='war', negative='peace')
```

instead of the slightly more involved

```python
Expand Down Expand Up @@ -73,7 +72,7 @@ Plus a large number of smaller improvements and fixes, as usual.
* [#3091](https://github.com/RaRe-Technologies/gensim/pull/3091): LsiModel: Only log top words that actually exist in the dictionary, by [@kmurphy4](https://github.com/kmurphy4)
* [#2980](https://github.com/RaRe-Technologies/gensim/pull/2980): Added EnsembleLda for stable LDA topics, by [@sezanzeb](https://github.com/sezanzeb)
* [#2978](https://github.com/RaRe-Technologies/gensim/pull/2978): Optimize performance of Author-Topic model, by [@horpto](https://github.com/horpto)

* [#3000](https://github.com/RaRe-Technologies/gensim/pull/3000): Tidy up KeyedVectors.most_similar() API, by [@simonwiles](https://github.com/simonwiles)

### :books: Tutorials and docs

Expand Down
54 changes: 29 additions & 25 deletions gensim/models/keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,21 @@
logger = logging.getLogger(__name__)


KEY_TYPES = (str, int, np.integer)
_KEY_TYPES = (str, int, np.integer)

_EXTENDED_KEY_TYPES = (str, int, np.integer, np.ndarray)


def _ensure_list(value):
"""Ensure that the specified value is wrapped in a list, for those supported cases
where we also accept a single key or vector."""
if value is None:
return []

if isinstance(value, _KEY_TYPES) or (isinstance(value, ndarray) and len(value.shape) == 1):
return [value]

return value


class KeyedVectors(utils.SaveLoad):
Expand Down Expand Up @@ -377,7 +391,7 @@ def __getitem__(self, key_or_keys):
Vector representation for `key_or_keys` (1D if `key_or_keys` is single key, otherwise - 2D).
"""
if isinstance(key_or_keys, KEY_TYPES):
if isinstance(key_or_keys, _KEY_TYPES):
return self.get_vector(key_or_keys)

return vstack([self.get_vector(key) for key in key_or_keys])
Expand Down Expand Up @@ -491,7 +505,7 @@ def add_vectors(self, keys, weights, extras=None, replace=False):
if True - replace vectors, otherwise - keep old vectors.
"""
if isinstance(keys, KEY_TYPES):
if isinstance(keys, _KEY_TYPES):
keys = [keys]
weights = np.array(weights).reshape(1, -1)
elif isinstance(weights, list):
Expand Down Expand Up @@ -729,10 +743,9 @@ def most_similar(
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
positive = []
if negative is None:
negative = []
# allow passing a single string-key or vector for the positive/negative arguments
positive = _ensure_list(positive)
negative = _ensure_list(negative)

self.fill_norms()
clip_end = clip_end or len(self.vectors)
Expand All @@ -741,18 +754,14 @@ def most_similar(
clip_start = 0
clip_end = restrict_vocab

if isinstance(positive, KEY_TYPES) and not negative:
# allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
positive = [positive]

# add weights for each key, if not already present; default to 1.0 for positive and -1.0 for negative keys
positive = [
(item, 1.0) if isinstance(item, KEY_TYPES + (ndarray,))
else item for item in positive
(item, 1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item
for item in positive
]
negative = [
(item, -1.0) if isinstance(item, KEY_TYPES + (ndarray,))
else item for item in negative
(item, -1.0) if isinstance(item, _EXTENDED_KEY_TYPES) else item
for item in negative
]

# compute the weighted average of all keys
Expand Down Expand Up @@ -969,21 +978,16 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10):
if isinstance(topn, Integral) and topn < 1:
return []

if positive is None:
positive = []
if negative is None:
negative = []
# allow passing a single string-key or vector for the positive/negative arguments
positive = _ensure_list(positive)
negative = _ensure_list(negative)

self.fill_norms()

if isinstance(positive, str) and not negative:
# allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog'])
positive = [positive]

all_words = {
self.get_index(word) for word in positive + negative
if not isinstance(word, ndarray) and word in self.key_to_index
}
}

positive = [
self.get_vector(word, norm=True) if isinstance(word, str) else word
Expand Down Expand Up @@ -1101,7 +1105,7 @@ def distances(self, word_or_vector, other_words=()):
If either `word_or_vector` or any word in `other_words` is absent from vocab.
"""
if isinstance(word_or_vector, KEY_TYPES):
if isinstance(word_or_vector, _KEY_TYPES):
input_vector = self.get_vector(word_or_vector)
else:
input_vector = word_or_vector
Expand Down
39 changes: 39 additions & 0 deletions gensim/test/test_keyedvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Automated tests for checking the poincare module from the models package.
"""

import functools
import logging
import unittest

Expand Down Expand Up @@ -39,6 +40,44 @@ def test_most_similar(self):
predicted = [result[0] for result in self.vectors.most_similar('war', topn=5)]
self.assertEqual(expected, predicted)

def test_most_similar_vector(self):
"""Can we pass vectors to most_similar directly?"""
positive = self.vectors.vectors[0:5]
most_similar = self.vectors.most_similar(positive=positive)
assert most_similar is not None

def test_most_similar_parameter_types(self):
"""Are the positive/negative parameter types are getting interpreted correctly?"""
partial = functools.partial(self.vectors.most_similar, topn=5)

position = partial('war', 'peace')
position_list = partial(['war'], ['peace'])
keyword = partial(positive='war', negative='peace')
keyword_list = partial(positive=['war'], negative=['peace'])

#
# The above calls should all yield identical results.
#
assert position == position_list
assert position == keyword
assert position == keyword_list

def test_most_similar_cosmul_parameter_types(self):
"""Are the positive/negative parameter types are getting interpreted correctly?"""
partial = functools.partial(self.vectors.most_similar_cosmul, topn=5)

position = partial('war', 'peace')
position_list = partial(['war'], ['peace'])
keyword = partial(positive='war', negative='peace')
keyword_list = partial(positive=['war'], negative=['peace'])

#
# The above calls should all yield identical results.
#
assert position == position_list
assert position == keyword
assert position == keyword_list

def test_vectors_for_all_list(self):
"""Test vectors_for_all returns expected results with a list of keys."""
words = [
Expand Down

0 comments on commit 10d72d7

Please sign in to comment.