Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tidy up KeyedVectors.most_similar() API #3000

Merged: 14 commits, merged on Aug 13, 2021
16 changes: 12 additions & 4 deletions gensim/models/keyedvectors.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -742,10 +742,14 @@ def most_similar(
clip_start = 0
clip_end = restrict_vocab

- if isinstance(positive, KEY_TYPES) and not negative:
- # allow calls like most_similar('dog'), as a shorthand for most_similar(['dog'])
+ if isinstance(positive, KEY_TYPES + (ndarray,)):
+ # allow passing a single string-key or vector for the positive argument
piskvorky marked this conversation as resolved.
Show resolved Hide resolved
positive = [positive]

+ if isinstance(negative, KEY_TYPES + (ndarray,)):
+ # allow passing a single string-key or vector for the negative argument
+ negative = [negative]

# add weights for each key, if not already present; default to 1.0 for positive and -1.0 for negative keys
positive = [
(item, 1.0) if isinstance(item, KEY_TYPES + (ndarray,))
Expand Down Expand Up @@ -985,10 +989,14 @@ def most_similar_cosmul(self, positive=None, negative=None, topn=10):

self.fill_norms()

- if isinstance(positive, str) and not negative:
- # allow calls like most_similar_cosmul('dog'), as a shorthand for most_similar_cosmul(['dog'])
+ if isinstance(positive, KEY_TYPES + (ndarray,)):
+ # allow passing a single string-key or vector for the positive argument
positive = [positive]

+ if isinstance(negative, KEY_TYPES + (ndarray,)):
mpenkov marked this conversation as resolved.
Show resolved Hide resolved
+ # allow passing a single string-key or vector for the negative argument
+ negative = [negative]

all_words = {
self.get_index(word) for word in positive + negative
if not isinstance(word, ndarray) and word in self.key_to_index
Expand Down