From f5d8e44e827d2682c7631e70653b44cc3f580113 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Sat, 9 Oct 2021 11:22:16 -0700 Subject: [PATCH 1/9] Return numpy arrays from recommend methods (#482) Change recommend/rank_items/similar_items/similar_users to return a numpy array of ids and a numpy array of scores, rather than return a python list of (itemid, score) tuples. This opens up options for batch processing in the future, and can be trivially converted by users into the previous format (liked ```zip(*model.recommend(...))```. --- examples/lastfm.py | 4 +- examples/movielens.py | 2 +- implicit/datasets/reddit.py | 4 -- implicit/gpu/_cuda.pyx | 5 ++- implicit/gpu/matrix_factorization_base.py | 40 ++++++++++++------- implicit/nearest_neighbours.py | 37 ++++++++++-------- implicit/recommender_base.pyx | 47 ++++++++++++----------- tests/als_test.py | 13 +++++-- tests/recommender_base_test.py | 43 ++++++++++----------- 9 files changed, 108 insertions(+), 87 deletions(-) diff --git a/examples/lastfm.py b/examples/lastfm.py index eb898123..83fc057a 100644 --- a/examples/lastfm.py +++ b/examples/lastfm.py @@ -104,7 +104,7 @@ def calculate_similar_artists(output_filename, model_name="als"): with codecs.open(output_filename, "w", "utf8") as o: for artistid in to_generate: artist = artists[artistid] - for other, score in model.similar_items(artistid, 11): + for other, score in zip(*model.similar_items(artistid, 11)): o.write("%s\t%s\t%s\n" % (artist, artists[other], score)) progress.update(1) @@ -143,7 +143,7 @@ def calculate_recommendations(output_filename, model_name="als"): with tqdm.tqdm(total=len(users)) as progress: with codecs.open(output_filename, "w", "utf8") as o: for userid, username in enumerate(users): - for artistid, score in model.recommend(userid, user_plays): + for artistid, score in zip(*model.recommend(userid, user_plays)): o.write("%s\t%s\t%s\n" % (username, artists[artistid], score)) progress.update(1) logging.debug("generated recommendations in %0.2fs", time.time() - start) diff --git a/examples/movielens.py b/examples/movielens.py index 7ecb66c4..c87e4af8 100644 --- a/examples/movielens.py +++ b/examples/movielens.py @@ -91,7 +91,7 @@ def calculate_similar_movies(output_filename, model_name="als", min_rating=4.0, # no ratings > 4 meaning we've filtered out all data for it. if ratings.indptr[movieid] != ratings.indptr[movieid + 1]: title = titles[movieid] - for other, score in model.similar_items(movieid, 11): + for other, score in zip(*model.similar_items(movieid, 11)): o.write("%s\t%s\t%s\n" % (title, titles[other], score)) progress.update(1) diff --git a/implicit/datasets/reddit.py b/implicit/datasets/reddit.py index 3b29ff4e..297026c3 100644 --- a/implicit/datasets/reddit.py +++ b/implicit/datasets/reddit.py @@ -73,10 +73,6 @@ def _hfd5_from_dataframe(data, outputfilename): (data["item"].cat.codes.copy(), data["user"].cat.codes.copy()), ) ).tocsr() - print(repr(ratings)) - print(repr(ratings.indices)) - print(repr(ratings.indptr)) - with h5py.File(outputfilename, "w") as f: g = f.create_group("item_user_ratings") g.create_dataset("data", data=ratings.data) diff --git a/implicit/gpu/_cuda.pyx b/implicit/gpu/_cuda.pyx index e57893a4..4fb7df5b 100644 --- a/implicit/gpu/_cuda.pyx +++ b/implicit/gpu/_cuda.pyx @@ -117,8 +117,11 @@ cdef class Matrix(object): except Exception: raise ValueError(f"don't know how to handle __getitem__ on {idx}") + if len(idx.shape) == 0: + idx = idx.reshape([1]) + if len(idx.shape) != 1: - raise ValueError(f"don't know how to handle __getitem__ on {idx}") + raise ValueError(f"don't know how to handle __getitem__ on {idx} - shape={idx.shape}") if ((idx < 0) | (idx >= self.c_matrix.rows)).any(): raise ValueError(f"row id out of range for selecting items from matrix") diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index e46d97e6..67ba5f9a 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -1,4 +1,3 @@ -import itertools import time import numpy as np @@ -48,11 +47,16 @@ def recommend( count = N + len(liked) # calculate the top N items, removing the users own liked items from the results - # TODO: own like filtering (direct in topk class + # TODO: own like filtering (direct in topk class) ids, scores = self._knn.topk(self.item_factors, self.user_factors[userid], count) - return list( - itertools.islice((rec for rec in zip(ids[0], scores[0]) if rec[0] not in liked), N) - ) + + # TODO: handle batch mode + ids, scores = ids[0], scores[0] + + if liked: + mask = np.in1d(ids, list(liked), invert=True) + ids, scores = ids[mask][:N], scores[mask][:N] + return ids, scores recommend.__doc__ = RecommenderBase.recommend.__doc__ @@ -69,9 +73,9 @@ def rank_items(self, userid, user_items, selected_items, recalculate_user=False) # once we have item_factors here, this should work ids, scores = self._knn.topk(item_factors, user, len(selected_items)) + ids, scores = ids[0], scores[0] ids = np.array(selected_items)[ids] - - return list(zip(ids[0], scores[0])) + return ids, scores rank_items.__doc__ = RecommenderBase.rank_items.__doc__ @@ -91,10 +95,15 @@ def item_norms(self): def similar_users(self, userid, N=10): ids, scores = self._knn.topk( - self.user_factors, self.user_factors[int(userid)], N, self.user_norms + self.user_factors, self.user_factors[userid], N, self.user_norms ) - scores /= self._user_norms_host[userid] - return list(zip(ids[0], scores[0])) + ids, scores = ids[0], scores[0] + + user_norms = self._user_norms_host[userid] + if not np.isscalar(user_norms): + user_norms = user_norms.reshape((len(user_norms), 1)) + scores /= user_norms + return ids, scores similar_users.__doc__ = RecommenderBase.similar_users.__doc__ @@ -102,10 +111,15 @@ def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): if recalculate_item: raise NotImplementedError("recalculate_item isn't support on GPU yet") ids, scores = self._knn.topk( - self.item_factors, self.item_factors[int(itemid)], N, self.item_norms + self.item_factors, self.item_factors[itemid], N, self.item_norms ) - scores /= self._item_norms_host[itemid] - return list(zip(ids[0], scores[0])) + ids, scores = ids[0], scores[0] + + item_norms = self._item_norms_host[itemid] + if not np.isscalar(item_norms): + item_norms = item_norms.reshape((len(item_norms), 1)) + scores /= item_norms + return ids, scores similar_items.__doc__ = RecommenderBase.similar_items.__doc__ diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 4c22812f..2946ecfc 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -1,12 +1,9 @@ -import itertools - import numpy from numpy import bincount, log, log1p, sqrt from scipy.sparse import coo_matrix, csr_matrix from ._nearest_neighbours import NearestNeighboursScorer, all_pairs_knn from .recommender_base import RecommenderBase -from .utils import nonzeros class ItemItemRecommender(RecommenderBase): @@ -54,7 +51,7 @@ def recommend( if filter_items: items += len(filter_items) - indices, data = self.scorer.recommend( + ids, scores = self.scorer.recommend( userid, user_items.indptr, user_items.indices, @@ -62,13 +59,12 @@ def recommend( K=items, remove_own_likes=filter_already_liked_items, ) - best = sorted(zip(indices, data), key=lambda x: -x[1]) - if not filter_items: - return best + if filter_items: + mask = numpy.in1d(ids, filter_items, invert=True) + ids, scores = ids[mask][:N], scores[mask][:N] - liked = set(filter_items) - return list(itertools.islice((rec for rec in best if rec[0] not in liked), N)) + return ids, scores def rank_items(self, userid, user_items, selected_items, recalculate_user=False): """Rank given items for a user and returns sorted item list""" @@ -76,19 +72,23 @@ def rank_items(self, userid, user_items, selected_items, recalculate_user=False) if max(selected_items) >= user_items.shape[1] or min(selected_items) < 0: raise IndexError("Some of selected itemids are not in the model") + selected_items = numpy.array(selected_items) + # calculate the relevance scores liked_vector = user_items.getrow(userid) recommendations = liked_vector.dot(self.similarity) # remove items that are not in the selected_items - best = sorted(zip(recommendations.indices, recommendations.data), key=lambda x: -x[1]) - ret = [rec for rec in best if rec[0] in selected_items] + ids, scores = recommendations.indices, recommendations.data + mask = numpy.in1d(ids, selected_items) + ids, scores = ids[mask], scores[mask] # returned items should be equal to input selected items - for itemid in selected_items: - if itemid not in recommendations.indices: - ret.append((itemid, -1.0)) - return ret + missing = selected_items[numpy.in1d(selected_items, ids, invert=True)] + if missing.size: + ids = numpy.append(ids, missing) + scores = numpy.append(scores, numpy.full(missing.size, -numpy.finfo(scores.dtype).max)) + return ids, scores def similar_users(self, userid, N=10): raise NotImplementedError("Not implemented Yet") @@ -96,9 +96,12 @@ def similar_users(self, userid, N=10): def similar_items(self, itemid, N=10): """Returns a list of the most similar other items""" if itemid >= self.similarity.shape[0]: - return [] + return numpy.array([]), numpy.array([]) - return sorted(list(nonzeros(self.similarity, itemid)), key=lambda x: -x[1])[:N] + ids = self.similarity[itemid].indices + scores = self.similarity[itemid].data + best = numpy.argsort(scores)[::-1][:N] + return ids[best], scores[best] def __getstate__(self): state = self.__dict__.copy() diff --git a/implicit/recommender_base.pyx b/implicit/recommender_base.pyx index 82830f4e..5a40dfd0 100644 --- a/implicit/recommender_base.pyx +++ b/implicit/recommender_base.pyx @@ -71,8 +71,8 @@ class RecommenderBase(object): Returns ------- - list - List of (itemid, score) tuples + tuple + Tuple of (itemids, scores) arrays """ pass @@ -96,8 +96,8 @@ class RecommenderBase(object): Returns ------- - list - List of (itemid, score) tuples. it only contains items that appears in + tuple + Tuple of (itemids, scores) arrays. it only contains items that appears in input parameter selected_items """ pass @@ -116,8 +116,8 @@ class RecommenderBase(object): Returns ------- - list - List of (userid, score) tuples + tuple + Tuple of (itemids, scores) arrays """ pass @@ -142,8 +142,8 @@ class RecommenderBase(object): Returns ------- - list - List of (itemid, score) tuples + tuple + Tuple of (itemids, scores) arrays """ pass @@ -170,22 +170,22 @@ class MatrixFactorizationBase(RecommenderBase): N=10, filter_already_liked_items=True, filter_items=None, recalculate_user=False): user = self._user_factor(userid, user_items, recalculate_user) - liked = set() - if filter_already_liked_items: - liked.update(user_items[userid].indices) - if filter_items: - liked.update(filter_items) - # calculate the top N items, removing the users own liked items from the results scores = self.item_factors.dot(user) - count = N + len(liked) - if count < len(scores): - ids = np.argpartition(scores, -count)[-count:] - best = sorted(zip(ids, scores[ids]), key=lambda x: -x[1]) + # filter out liked items + if filter_already_liked_items: + scores[user_items[userid].indices] = -np.finfo(scores.dtype).max + if filter_items: + scores[filter_items] = -np.finfo(scores.dtype).max + + if N < len(scores): + ids = np.argpartition(scores, -N)[-N:] else: - best = sorted(enumerate(scores), key=lambda x: -x[1]) - return list(itertools.islice((rec for rec in best if rec[0] not in liked), N)) + ids = np.arange(len(scores)) + + ids = ids[np.argsort(scores[ids])[::-1]] + return ids, scores[ids] @cython.boundscheck(False) @cython.wraparound(False) @@ -310,7 +310,9 @@ class MatrixFactorizationBase(RecommenderBase): scores = item_factors.dot(user) # return sorted results - return sorted(zip(selected_items, scores), key=lambda x: -x[1]) + selected_items = np.array(selected_items) + best = np.argsort(scores)[::-1] + return selected_items[best], scores[best] recommend.__doc__ = RecommenderBase.recommend.__doc__ @@ -357,7 +359,8 @@ class MatrixFactorizationBase(RecommenderBase): def _get_similarity_score(self, factor, norm, factors, norms, N): scores = factors.dot(factor) / (norm * norms) best = np.argpartition(scores, -N)[-N:] - return sorted(zip(best, scores[best]), key=lambda x: -x[1]) + ids = best[np.argsort(scores[best])[::-1]] + return ids, scores[ids] @property def user_norms(self): diff --git a/tests/als_test.py b/tests/als_test.py index cffd3d0f..73a75a86 100644 --- a/tests/als_test.py +++ b/tests/als_test.py @@ -171,14 +171,19 @@ def test_explain(): userid = 0 # Assert recommendation is the the same if we recompute user vectors - recs = model.recommend(userid, item_users, N=10) - recalculated_recs = model.recommend(userid, item_users, N=10, recalculate_user=True) - for (item1, score1), (item2, score2) in zip(recs, recalculated_recs): + # TODO: this doesn't quite work with N=10 (because we returns items that should have been + # filtered with large negative score?) also seems like the dtype is different between + # recalculate and not + ids, scores = model.recommend(userid, item_users, N=5) + recalculated_ids, recalculated_scores = model.recommend( + userid, item_users, N=5, recalculate_user=True + ) + for item1, score1, item2, score2 in zip(ids, scores, recalculated_ids, recalculated_scores): assert item1 == item2 assert pytest.approx(score1, abs=1e-4) == score2 # Assert explanation makes sense - top_rec, score = recalculated_recs[0] + top_rec, score = recalculated_ids[0], recalculated_scores[0] score_explained, contributions, W = model.explain(userid, item_users, itemid=top_rec) scores = [s for _, s in contributions] items = [i for i, _ in contributions] diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index 77c185f4..c9ed9505 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -37,24 +37,24 @@ def test_recommend(self): model.fit(item_users, show_progress=False) for userid in range(50): - recs = model.recommend(userid, user_items, N=1) - self.assertEqual(len(recs), 1) + ids, scores = model.recommend(userid, user_items, N=1) + self.assertEqual(len(ids), 1) # the top item recommended should be the same as the userid: # its the one withheld item for the user that is liked by # all the other similar users - self.assertEqual(recs[0][0], userid) + self.assertEqual(ids[0], userid) # try asking for more items than possible, # should return only the available items # https://github.com/benfred/implicit/issues/22 - recs = model.recommend(0, user_items, N=10000) - self.assertTrue(len(recs)) + ids, scores = model.recommend(0, user_items, N=10000) + self.assertTrue(len(ids)) # filter recommended items using an additional filter list # https://github.com/benfred/implicit/issues/26 - recs = model.recommend(0, user_items, N=1, filter_items=[0]) - self.assertTrue(0 not in dict(recs)) + ids, scores = model.recommend(0, user_items, N=1, filter_items=[0]) + self.assertTrue(0 not in set(ids)) def test_recalculate_user(self): item_users = get_checker_board(50) @@ -64,21 +64,21 @@ def test_recalculate_user(self): model.fit(item_users, show_progress=False) for userid in range(item_users.shape[1]): - recs = model.recommend(userid, user_items, N=1) - self.assertEqual(len(recs), 1) + ids, scores = model.recommend(userid, user_items, N=1) + self.assertEqual(len(ids), 1) user_vector = user_items[userid] # we should get the same item if we recalculate_user try: - recs_from_liked = model.recommend( + ids_from_liked, scores_from_liked = model.recommend( userid=0, user_items=user_vector, N=1, recalculate_user=True ) - self.assertEqual(recs[0][0], recs_from_liked[0][0]) + self.assertEqual(ids[0], ids_from_liked[0]) # TODO: if we set regularization for the model to be sufficiently high, the # scores from recalculate_user are slightly different. Investigate # (could be difference between CG and cholesky optimizers?) - self.assertAlmostEqual(recs[0][1], recs_from_liked[0][1], places=4) + self.assertAlmostEqual(scores[0], scores_from_liked[0], places=4) except NotImplementedError: # some models don't support recalculating user on the fly, and thats ok pass @@ -98,23 +98,22 @@ def test_evaluation(self): self.assertEqual(p, 1) def test_similar_users(self): - model = self._get_model() # calculating similar users in nearest-neighbours is not implemented yet if isinstance(model, ItemItemRecommender): return model.fit(get_checker_board(50), show_progress=False) for userid in range(50): - recs = model.similar_users(userid, N=10) - for r, _ in recs: + ids, _ = model.similar_users(userid, N=10) + for r in ids: self.assertEqual(r % 2, userid % 2) def test_similar_items(self): model = self._get_model() model.fit(get_checker_board(256), show_progress=False) for itemid in range(50): - recs = model.similar_items(itemid, N=10) - for r, _ in recs: + ids, _ = model.similar_items(itemid, N=10) + for r in ids: self.assertEqual(r % 2, itemid % 2) def test_zero_length_row(self): @@ -133,8 +132,8 @@ def test_zero_length_row(self): # item 42 has no users, shouldn't be similar to anything for itemid in range(40): - recs = model.similar_items(itemid, 10) - self.assertTrue(42 not in [r for r, _ in recs]) + ids, _ = model.similar_items(itemid, 10) + self.assertTrue(42 not in ids) def test_dtype(self): # models should be able to accept input of either float32 or float64 @@ -154,17 +153,15 @@ def test_rank_items(self): for userid in range(50): selected_items = np.random.randint(50, size=10).tolist() - ranked_list = model.rank_items(userid, user_items, selected_items) - ordered_items = [itemid for (itemid, score) in ranked_list] + ids, scores = model.rank_items(userid, user_items, selected_items) # ranked list should have same items - self.assertEqual(set(ordered_items), set(selected_items)) + self.assertEqual(set(ids), set(selected_items)) wrong_neg_items = [-1, -3, -5] wrong_pos_items = [51, 300, 200] # rank_items should raise IndexError if selected items contains wrong itemids - with self.assertRaises(IndexError): wrong_item_list = selected_items + wrong_neg_items model.rank_items(userid, user_items, wrong_item_list) From 19a55bea4d6d765ae73179d705dc23714664b883 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Sat, 16 Oct 2021 10:05:04 -0700 Subject: [PATCH 2/9] Change model.fit to take a user_items sparse matrix (#484) Previously, model.fit took an item_users sparse matrix, while model.recommend took a user_items sparse matrix. This was a source of confusion, so change the model.fit method to be consistent with model.recommend --- cuda_setup.py | 2 -- examples/lastfm.py | 7 ++++--- examples/movielens.py | 4 +++- implicit/_nearest_neighbours.pyx | 6 +++--- implicit/approximate_als.py | 12 ++++++------ implicit/cpu/als.py | 28 ++++++++++++++-------------- implicit/cpu/bpr.pyx | 19 ++++++++----------- implicit/gpu/als.py | 28 ++++++++++++++-------------- implicit/lmf.pyx | 21 +++++++++++---------- implicit/recommender_base.pyx | 6 +++--- tests/als_test.py | 14 +++++++------- tests/knn_test.py | 4 ++-- 12 files changed, 75 insertions(+), 76 deletions(-) diff --git a/cuda_setup.py b/cuda_setup.py index b12d98d5..443a510b 100644 --- a/cuda_setup.py +++ b/cuda_setup.py @@ -61,8 +61,6 @@ def locate_cuda(): post_args = [ "-arch=sm_60", - "-gencode=arch=compute_50,code=sm_50", - "-gencode=arch=compute_52,code=sm_52", "-gencode=arch=compute_60,code=sm_60", "-gencode=arch=compute_61,code=sm_61", "-gencode=arch=compute_70,code=sm_70", diff --git a/examples/lastfm.py b/examples/lastfm.py index 83fc057a..f35917a2 100644 --- a/examples/lastfm.py +++ b/examples/lastfm.py @@ -85,10 +85,11 @@ def calculate_similar_artists(output_filename, model_name="als"): # this is actually disturbingly expensive: plays = plays.tocsr() + user_plays = plays.T.tocsr() logging.debug("training model %s", model_name) start = time.time() - model.fit(plays) + model.fit(user_plays) logging.debug("trained model '%s' in %0.2fs", model_name, time.time() - start) # write out similar artists by popularity @@ -131,15 +132,15 @@ def calculate_recommendations(output_filename, model_name="als"): # this is actually disturbingly expensive: plays = plays.tocsr() + user_plays = plays.T.tocsr() logging.debug("training model %s", model_name) start = time.time() - model.fit(plays) + model.fit(user_plays) logging.debug("trained model '%s' in %0.2fs", model_name, time.time() - start) # generate recommendations for each user and write out to a file start = time.time() - user_plays = plays.T.tocsr() with tqdm.tqdm(total=len(users)) as progress: with codecs.open(output_filename, "w", "utf8") as o: for userid, username in enumerate(users): diff --git a/examples/movielens.py b/examples/movielens.py index c87e4af8..e12f71fd 100644 --- a/examples/movielens.py +++ b/examples/movielens.py @@ -73,10 +73,12 @@ def calculate_similar_movies(output_filename, model_name="als", min_rating=4.0, else: raise NotImplementedError("TODO: model %s" % model_name) + user_ratings = ratings.T.tocsr() + # train the model log.debug("training model %s", model_name) start = time.time() - model.fit(ratings) + model.fit(user_ratings) log.debug("trained model '%s' in %s", model_name, time.time() - start) log.debug("calculating top movies") diff --git a/implicit/_nearest_neighbours.pyx b/implicit/_nearest_neighbours.pyx index 6bfbb7d3..46d83a93 100644 --- a/implicit/_nearest_neighbours.pyx +++ b/implicit/_nearest_neighbours.pyx @@ -97,11 +97,11 @@ cdef class NearestNeighboursScorer(object): @cython.boundscheck(False) -def all_pairs_knn(items, unsigned int K=100, int num_threads=0, show_progress=True): +def all_pairs_knn(users, unsigned int K=100, int num_threads=0, show_progress=True): """ Returns the top K nearest neighbours for each row in the matrix. """ - items = items.tocsr() - users = items.T.tocsr() + users = users.tocsr() + items = users.T.tocsr() cdef int item_count = items.shape[0] cdef int i, u, index1, index2, j diff --git a/implicit/approximate_als.py b/implicit/approximate_als.py index 5395c67c..eba15807 100644 --- a/implicit/approximate_als.py +++ b/implicit/approximate_als.py @@ -97,14 +97,14 @@ def __init__( *args, random_state=random_state, **kwargs ) - def fit(self, Ciu, show_progress=True): + def fit(self, Cui, show_progress=True): # nmslib can be a little chatty when first imported, disable some of # the logging logging.getLogger("nmslib").setLevel(logging.WARNING) import nmslib # train the model - super(NMSLibAlternatingLeastSquares, self).fit(Ciu, show_progress) + super(NMSLibAlternatingLeastSquares, self).fit(Cui, show_progress) # create index for similar_items if self.approximate_similar_items: @@ -238,12 +238,12 @@ def __init__( self.n_trees = n_trees self.search_k = search_k - def fit(self, Ciu, show_progress=True): + def fit(self, Cui, show_progress=True): # delay loading the annoy library in case its not installed here import annoy # train the model - super(AnnoyAlternatingLeastSquares, self).fit(Ciu, show_progress) + super(AnnoyAlternatingLeastSquares, self).fit(Cui, show_progress) # build up an Annoy Index with all the item_factors (for calculating # similar items) @@ -377,11 +377,11 @@ def __init__( *args, random_state=random_state, **kwargs ) - def fit(self, Ciu, show_progress=True): + def fit(self, Cui, show_progress=True): import faiss # train the model - super(FaissAlternatingLeastSquares, self).fit(Ciu, show_progress) + super(FaissAlternatingLeastSquares, self).fit(Cui, show_progress) self.quantizer = faiss.IndexFlat(self.factors) diff --git a/implicit/cpu/als.py b/implicit/cpu/als.py index d40b46f0..887130c5 100644 --- a/implicit/cpu/als.py +++ b/implicit/cpu/als.py @@ -93,27 +93,27 @@ def __init__( check_blas_config() - def fit(self, item_users, show_progress=True): - """Factorizes the item_users matrix. + def fit(self, user_items, show_progress=True): + """Factorizes the user_items matrix. After calling this method, the members 'user_factors' and 'item_factors' will be initialized with a latent factor model of the input data. - The item_users matrix does double duty here. It defines which items are liked by which - users (P_iu in the original paper), as well as how much confidence we have that the user - liked the item (C_iu). + The user_items matrix does double duty here. It defines which items are liked by which + users (P_ui in the original paper), as well as how much confidence we have that the user + liked the item (C_ui). The negative items are implicitly defined: This code assumes that positive items in the - item_users matrix means that the user liked the item. The negatives are left unset in this + user_items matrix means that the user liked the item. The negatives are left unset in this sparse matrix: the library will assume that means Piu = 0 and Ciu = 1 for all these items. Negative items can also be passed with a higher confidence value by passing a negative value, indicating that the user disliked the item. Parameters ---------- - item_users: csr_matrix + user_items: csr_matrix Matrix of confidences for the liked items. This matrix should be a csr_matrix where - the rows of the matrix are the item, the columns are the users that liked that item, + the rows of the matrix are the users, the columns are the items liked that user, and the value is the confidence that the user liked the item. show_progress : bool, optional Whether to show a progress bar during fitting @@ -121,18 +121,18 @@ def fit(self, item_users, show_progress=True): # initialize the random state random_state = check_random_state(self.random_state) - Ciu = item_users - if not isinstance(Ciu, scipy.sparse.csr_matrix): + Cui = user_items + if not isinstance(Cui, scipy.sparse.csr_matrix): s = time.time() log.debug("Converting input to CSR format") - Ciu = Ciu.tocsr() + Cui = Cui.tocsr() log.debug("Converted input to CSR in %.3fs", time.time() - s) - if Ciu.dtype != np.float32: - Ciu = Ciu.astype(np.float32) + if Cui.dtype != np.float32: + Cui = Cui.astype(np.float32) s = time.time() - Cui = Ciu.T.tocsr() + Ciu = Cui.T.tocsr() log.debug("Calculated transpose in %.3fs", time.time() - s) items, users = Ciu.shape diff --git a/implicit/cpu/bpr.pyx b/implicit/cpu/bpr.pyx index ac9049b1..09f95dee 100644 --- a/implicit/cpu/bpr.pyx +++ b/implicit/cpu/bpr.pyx @@ -121,14 +121,14 @@ class BayesianPersonalizedRanking(MatrixFactorizationBase): @cython.cdivision(True) @cython.boundscheck(False) - def fit(self, item_users, show_progress=True): - """ Factorizes the item_users matrix + def fit(self, user_items, show_progress=True): + """ Factorizes the user_items matrix Parameters ---------- - item_users: coo_matrix - Matrix of confidences for the liked items. This matrix should be a coo_matrix where - the rows of the matrix are the item, and the columns are the users that liked that item. + user_items: csr_matrix + Matrix of confidences for the liked items. This matrix should be a csr_matrix where + the rows of the matrix are the user, and the columns are the items liked by that user. BPR ignores the weight value of the matrix right now - it treats non zero entries as a binary signal that the user liked the item. show_progress : bool, optional @@ -137,15 +137,12 @@ class BayesianPersonalizedRanking(MatrixFactorizationBase): rs = check_random_state(self.random_state) # for now, all we handle is float 32 values - if item_users.dtype != np.float32: - item_users = item_users.astype(np.float32) + if user_items.dtype != np.float32: + user_items = user_items.astype(np.float32) - items, users = item_users.shape + users, items = user_items.shape # We need efficient user lookup for case of removing own likes - # TODO: might make more sense to just changes inputs to be users by items instead - # but that would be a major breaking API change - user_items = item_users.T.tocsr() if not user_items.has_sorted_indices: user_items.sort_indices() diff --git a/implicit/gpu/als.py b/implicit/gpu/als.py index 3b999138..0f380ead 100644 --- a/implicit/gpu/als.py +++ b/implicit/gpu/als.py @@ -67,27 +67,27 @@ def __init__( self.random_state = random_state self.cg_steps = 3 - def fit(self, item_users, show_progress=True): - """Factorizes the item_users matrix. + def fit(self, user_items, show_progress=True): + """Factorizes the user_items matrix. After calling this method, the members 'user_factors' and 'item_factors' will be initialized with a latent factor model of the input data. - The item_users matrix does double duty here. It defines which items are liked by which - users (P_iu in the original paper), as well as how much confidence we have that the user - liked the item (C_iu). + The user_items matrix does double duty here. It defines which items are liked by which + users (P_ui in the original paper), as well as how much confidence we have that the user + liked the item (C_ui. The negative items are implicitly defined: This code assumes that positive items in the - item_users matrix means that the user liked the item. The negatives are left unset in this + user_items matrix means that the user liked the item. The negatives are left unset in this sparse matrix: the library will assume that means Piu = 0 and Ciu = 1 for all these items. Negative items can also be passed with a higher confidence value by passing a negative value, indicating that the user disliked the item. Parameters ---------- - item_users: csr_matrix + user_items: csr_matrix Matrix of confidences for the liked items. This matrix should be a csr_matrix where - the rows of the matrix are the item, the columns are the users that liked that item, + the rows of the matrix are the user, the columns are the items liked by that user, and the value is the confidence that the user liked the item. show_progress : bool, optional Whether to show a progress bar during fitting @@ -96,18 +96,18 @@ def fit(self, item_users, show_progress=True): random_state = check_random_state(self.random_state) # TODO: allow passing in cupy arrays on gpu - Ciu = item_users - if not isinstance(Ciu, scipy.sparse.csr_matrix): + Cui = user_items + if not isinstance(Cui, scipy.sparse.csr_matrix): s = time.time() log.debug("Converting input to CSR format") - Ciu = Ciu.tocsr() + Cui = Cui.tocsr() log.debug("Converted input to CSR in %.3fs", time.time() - s) - if Ciu.dtype != np.float32: - Ciu = Ciu.astype(np.float32) + if Cui.dtype != np.float32: + Cui = Cui.astype(np.float32) s = time.time() - Cui = Ciu.T.tocsr() + Ciu = Cui.T.tocsr() log.debug("Calculated transpose in %.3fs", time.time() - s) items, users = Ciu.shape diff --git a/implicit/lmf.pyx b/implicit/lmf.pyx index 01c614ba..974eb6b4 100644 --- a/implicit/lmf.pyx +++ b/implicit/lmf.pyx @@ -117,14 +117,15 @@ class LogisticMatrixFactorization(MatrixFactorizationBase): @cython.cdivision(True) @cython.boundscheck(False) - def fit(self, item_users, show_progress=True): - """ Factorizes the item_users matrix + def fit(self, user_items, show_progress=True): + """ Factorizes the user_items matrix Parameters ---------- - item_users: coo_matrix - Matrix of confidences for the liked items. This matrix should be a coo_matrix where - the rows of the matrix are the item, and the columns are the users that liked that item. + user_items: csr_matrix + Matrix of confidences for the liked items. This matrix should be a csr_matrix where + the rows of the matrix are the user, and the columns are the items that are liked by + the user. BPR ignores the weight value of the matrix right now - it treats non zero entries as a binary signal that the user liked the item. show_progress : bool, optional @@ -133,13 +134,13 @@ class LogisticMatrixFactorization(MatrixFactorizationBase): rs = check_random_state(self.random_state) # for now, all we handle is float 32 values - if item_users.dtype != np.float32: - item_users = item_users.astype(np.float32) + if user_items.dtype != np.float32: + user_items = user_items.astype(np.float32) - items, users = item_users.shape + users, items = user_items.shape - item_users = item_users.tocsr() - user_items = item_users.T.tocsr() + user_items = user_items.tocsr() + item_users = user_items.T.tocsr() if not item_users.has_sorted_indices: item_users.sort_indices() diff --git a/implicit/recommender_base.pyx b/implicit/recommender_base.pyx index 5a40dfd0..8cc9ad6b 100644 --- a/implicit/recommender_base.pyx +++ b/implicit/recommender_base.pyx @@ -28,14 +28,14 @@ class RecommenderBase(object): __metaclass__ = ABCMeta @abstractmethod - def fit(self, item_users): + def fit(self, user_items): """ Trains the model on a sparse matrix of item/user/weight Parameters ---------- - item_user : csr_matrix - A matrix of shape (number_of_items, number_of_users). The nonzero + user_items : csr_matrix + A matrix of shape (number_of_users, number_of_items). The nonzero entries in this matrix are the items that are liked by each user. The values are how confident you are that the item is liked by the user. """ diff --git a/tests/als_test.py b/tests/als_test.py index 73a75a86..b97d17e4 100644 --- a/tests/als_test.py +++ b/tests/als_test.py @@ -126,7 +126,7 @@ def test_factorize(use_native, use_gpu, use_cg, dtype): random_state=42, ) model.fit(user_items, show_progress=False) - rows, cols = model.item_factors, model.user_factors + rows, cols = model.user_factors, model.item_factors if use_gpu: rows, cols = rows.to_numpy(), cols.to_numpy() @@ -154,8 +154,8 @@ def test_explain(): ], dtype=np.float64, ) - user_items = counts * 2 - item_users = user_items.T + item_users = counts * 2 + user_items = item_users.T.tocsr() model = AlternatingLeastSquares( factors=4, @@ -174,9 +174,9 @@ def test_explain(): # TODO: this doesn't quite work with N=10 (because we returns items that should have been # filtered with large negative score?) also seems like the dtype is different between # recalculate and not - ids, scores = model.recommend(userid, item_users, N=5) + ids, scores = model.recommend(userid, user_items, N=3) recalculated_ids, recalculated_scores = model.recommend( - userid, item_users, N=5, recalculate_user=True + userid, user_items, N=3, recalculate_user=True ) for item1, score1, item2, score2 in zip(ids, scores, recalculated_ids, recalculated_scores): assert item1 == item2 @@ -184,7 +184,7 @@ def test_explain(): # Assert explanation makes sense top_rec, score = recalculated_ids[0], recalculated_scores[0] - score_explained, contributions, W = model.explain(userid, item_users, itemid=top_rec) + score_explained, contributions, W = model.explain(userid, user_items, itemid=top_rec) scores = [s for _, s in contributions] items = [i for i, _ in contributions] assert pytest.approx(score, abs=1e-4) == score_explained @@ -196,7 +196,7 @@ def test_explain(): # Assert explanation with precomputed user weights is correct top_score_explained, top_contributions, W = model.explain( - userid, item_users, itemid=top_rec, user_weights=W, N=2 + userid, user_items, itemid=top_rec, user_weights=W, N=2 ) top_scores = [s for _, s in top_contributions] top_items = [i for i, _ in top_contributions] diff --git a/tests/knn_test.py b/tests/knn_test.py index b5dc32bf..b0e71160 100644 --- a/tests/knn_test.py +++ b/tests/knn_test.py @@ -42,11 +42,11 @@ def test_all_pairs_knn(): counts = implicit.nearest_neighbours.tfidf_weight(counts).tocsr() # compute all neighbours using matrix dot product - all_neighbours = counts.dot(counts.T).tocsr() + all_neighbours = counts.T.dot(counts).tocsr() K = 3 knn = implicit.nearest_neighbours.all_pairs_knn(counts, K, show_progress=False).tocsr() - for rowid in range(counts.shape[0]): + for rowid in range(counts.shape[1]): # make sure values match for colid, data in zip(knn[rowid].indices, knn[rowid].data): pytest.approx(all_neighbours[rowid, colid]) == data From 6f7b80b75e1eafece564dae4b5104fa95e15fabc Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 25 Nov 2021 16:44:57 -0800 Subject: [PATCH 3/9] Unify recommend/rank_items/recommend_all apis (#493) Add batch mode operations for recommend, add an item list parameter and deprecate the rank_items and recommend_all methods. --- examples/lastfm.py | 28 +- implicit/__init__.py | 2 +- implicit/_nearest_neighbours.pyx | 2 +- implicit/cpu/als.py | 2 +- implicit/cpu/bpr.h | 8 +- implicit/cpu/bpr.pyx | 2 +- implicit/cpu/matrix_factorization_base.py | 213 ++++++++++++ implicit/cpu/select.h | 43 +++ implicit/cpu/topk.pyx | 68 ++++ implicit/evaluation.pyx | 4 +- implicit/gpu/_cuda.pyx | 22 +- implicit/gpu/bpr.py | 21 +- implicit/gpu/knn.cu | 28 +- implicit/gpu/knn.h | 4 +- implicit/gpu/knn.pxd | 7 +- implicit/gpu/matrix_factorization_base.py | 86 ++--- implicit/gpu/random.cu | 2 +- implicit/lmf.pyx | 2 +- implicit/nearest_neighbours.py | 111 ++++--- implicit/recommender_base.py | 136 ++++++++ implicit/recommender_base.pyx | 385 ---------------------- implicit/topnc.cpp | 30 -- implicit/topnc.h | 7 - setup.cfg | 1 + setup.py | 14 +- tests/als_test.py | 42 +-- tests/evaluation_test.py | 159 +++++---- tests/recommender_base_test.py | 135 +++++++- 28 files changed, 885 insertions(+), 679 deletions(-) create mode 100644 implicit/cpu/matrix_factorization_base.py create mode 100644 implicit/cpu/select.h create mode 100644 implicit/cpu/topk.pyx create mode 100644 implicit/recommender_base.py delete mode 100644 implicit/recommender_base.pyx delete mode 100644 implicit/topnc.cpp delete mode 100644 implicit/topnc.h diff --git a/examples/lastfm.py b/examples/lastfm.py index f35917a2..f8b167bd 100644 --- a/examples/lastfm.py +++ b/examples/lastfm.py @@ -103,11 +103,15 @@ def calculate_similar_artists(output_filename, model_name="als"): logging.debug("writing similar items") with tqdm.tqdm(total=len(to_generate)) as progress: with codecs.open(output_filename, "w", "utf8") as o: - for artistid in to_generate: - artist = artists[artistid] - for other, score in zip(*model.similar_items(artistid, 11)): - o.write("%s\t%s\t%s\n" % (artist, artists[other], score)) - progress.update(1) + batch_size = 1000 + for startidx in range(0, len(to_generate), batch_size): + batch = to_generate[startidx : startidx + batch_size] + ids, scores = model.similar_items(batch, 11) + for i, artistid in enumerate(batch): + artist = artists[artistid] + for other, score in zip(ids[i], scores[i]): + o.write("%s\t%s\t%s\n" % (artist, artists[other], score)) + progress.update(batch_size) logging.debug("generated similar artists in %0.2fs", time.time() - start) @@ -143,10 +147,16 @@ def calculate_recommendations(output_filename, model_name="als"): start = time.time() with tqdm.tqdm(total=len(users)) as progress: with codecs.open(output_filename, "w", "utf8") as o: - for userid, username in enumerate(users): - for artistid, score in zip(*model.recommend(userid, user_plays)): - o.write("%s\t%s\t%s\n" % (username, artists[artistid], score)) - progress.update(1) + batch_size = 1000 + to_generate = np.arange(len(users)) + for startidx in range(0, len(to_generate), batch_size): + batch = to_generate[startidx : startidx + batch_size] + ids, scores = model.recommend(batch, user_plays, filter_already_liked_items=True) + for i, userid in enumerate(batch): + username = users[userid] + for other, score in zip(ids[i], scores[i]): + o.write("%s\t%s\t%s\n" % (username, artists[other], score)) + progress.update(batch_size) logging.debug("generated recommendations in %0.2fs", time.time() - start) diff --git a/implicit/__init__.py b/implicit/__init__.py index 47a16062..419ffbc1 100644 --- a/implicit/__init__.py +++ b/implicit/__init__.py @@ -2,4 +2,4 @@ __version__ = "0.4.8" -__all__ = [als, approximate_als, bpr, nearest_neighbours, lmf, __version__] +__all__ = ["als", "approximate_als", "bpr", "nearest_neighbours", "lmf", "__version__"] diff --git a/implicit/_nearest_neighbours.pyx b/implicit/_nearest_neighbours.pyx index 46d83a93..a897e130 100644 --- a/implicit/_nearest_neighbours.pyx +++ b/implicit/_nearest_neighbours.pyx @@ -30,7 +30,7 @@ cdef extern from "nearest_neighbours.h" namespace "implicit" nogil: cdef class NearestNeighboursScorer(object): """ Class to return the top K items from multipying a users likes - by a precomputed similarity vector. """ + by a precomputed sparse similarity matrix. """ cdef SparseMatrixMultiplier[int, double] * neighbours cdef int[:] similarity_indptr diff --git a/implicit/cpu/als.py b/implicit/cpu/als.py index 887130c5..c2d15479 100644 --- a/implicit/cpu/als.py +++ b/implicit/cpu/als.py @@ -9,9 +9,9 @@ import scipy.sparse from tqdm.auto import tqdm -from ..recommender_base import MatrixFactorizationBase from ..utils import check_blas_config, check_random_state, nonzeros from . import _als +from .matrix_factorization_base import MatrixFactorizationBase log = logging.getLogger("implicit") diff --git a/implicit/cpu/bpr.h b/implicit/cpu/bpr.h index 34fee868..d8f45c19 100644 --- a/implicit/cpu/bpr.h +++ b/implicit/cpu/bpr.h @@ -1,7 +1,7 @@ -// Copyright 2018 Ben Frederickson +// Copyright 2018-2021 Ben Frederickson -#ifndef IMPLICIT_BPR_H_ -#define IMPLICIT_BPR_H_ +#ifndef IMPLICIT_CPU_BPR_H_ +#define IMPLICIT_CPU_BPR_H_ // We need to get the thread number to figure out which RNG to use, // but this will fail on OSX etc if we have no openmp enabled compiler. @@ -18,4 +18,4 @@ inline int get_thread_num() { return omp_get_thread_num(); } inline int get_thread_num() { return 0; } #endif } // namespace implicit -#endif // IMPLICIT_BPR_H_ +#endif // IMPLICIT_CPU_BPR_H_ diff --git a/implicit/cpu/bpr.pyx b/implicit/cpu/bpr.pyx index 09f95dee..fee32204 100644 --- a/implicit/cpu/bpr.pyx +++ b/implicit/cpu/bpr.pyx @@ -20,8 +20,8 @@ import scipy.sparse from libcpp.vector cimport vector -from ..recommender_base import MatrixFactorizationBase from ..utils import check_random_state +from .matrix_factorization_base import MatrixFactorizationBase log = logging.getLogger("implicit") diff --git a/implicit/cpu/matrix_factorization_base.py b/implicit/cpu/matrix_factorization_base.py new file mode 100644 index 00000000..d905bc2e --- /dev/null +++ b/implicit/cpu/matrix_factorization_base.py @@ -0,0 +1,213 @@ +""" Base class for recommendation algorithms in this package """ +import warnings + +import numpy as np +from scipy.sparse import lil_matrix + +from ..recommender_base import ModelFitError, RecommenderBase +from .topk import topk + + +class MatrixFactorizationBase(RecommenderBase): + """MatrixFactorizationBase contains common functionality for recommendation models. + + Attributes + ---------- + item_factors : ndarray + Array of latent factors for each item in the training set + user_factors : ndarray + Array of latent factors for each user in the training set + """ + + def __init__(self): + # learned parameters + self.item_factors = None + self.user_factors = None + + # cache of user, item norms (useful for calculating similar items) + self._user_norms, self._item_norms = None, None + + def recommend( + self, + userid, + user_items, + N=10, + filter_already_liked_items=True, + filter_items=None, + recalculate_user=False, + items=None, + ): + user = self._user_factor(userid, user_items, recalculate_user) + + item_factors = self.item_factors + + # if we have an item list to restrict down to, we need to filter the item_factors + # and filter_query_items + if items is not None: + if filter_items: + raise ValueError("Can't set both items and filter_items in recommend call") + + items = np.array(items) + items.sort() + item_factors = item_factors[items] + + # check selected items are in the model + if items.max() >= self.item_factors.shape[0] or items.min() < 0: + raise IndexError("Some itemids are not in the model") + + # get a CSR matrix of items to filter per-user + filter_query_items = None + if filter_already_liked_items: + filter_query_items = user_items[userid] + + # if we've been given a list of explicit itemids to rank, we need to filter down + if items is not None: + filter_query_items = _filter_items_from_sparse_matrix(items, filter_query_items) + + ids, scores = topk( + item_factors, + user, + N, + filter_query_items=filter_query_items, + filter_items=filter_items, + ) + + if np.isscalar(userid): + ids, scores = ids[0], scores[0] + + # if we've been given an explicit items list, remap the ids + if items is not None: + ids = items[ids] + + return ids, scores + + def recommend_all( + self, + user_items, + N=10, + recalculate_user=False, + filter_already_liked_items=True, + filter_items=None, + users_items_offset=0, + ): + warnings.warn( + "recommend_all is deprecated. Use recommend with an array of userids instead", + DeprecationWarning, + ) + + userids = np.arange(user_items.shape[0]) + users_items_offset + if users_items_offset: + adjusted = lil_matrix( + (user_items.shape[0] + users_items_offset, user_items.shape[1]), + dtype=user_items.dtype, + ) + adjusted[users_items_offset:] = user_items + user_items = adjusted.tocsr() + + ids, _ = self.recommend( + userids, + user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + ) + return ids + + recommend.__doc__ = RecommenderBase.recommend.__doc__ + + def _user_factor(self, userid, user_items, recalculate_user=False): + if recalculate_user: + if np.isscalar(userid): + return self.recalculate_user(userid, user_items) + else: + return np.stack([self.recalculate_user(i, user_items) for i in userid]) + + return self.user_factors[userid] + + def _item_factor(self, itemid, react_users, recalculate_item=False): + if recalculate_item: + if np.isscalar(itemid): + return self.recalculate_item(itemid, react_users) + else: + return np.stack([self.recalculate_item(i, react_users) for i in itemid]) + + return self.item_factors[itemid] + + def recalculate_user(self, userid, user_items): + raise NotImplementedError("recalculate_user is not supported with this model") + + def recalculate_item(self, itemid, react_users): + raise NotImplementedError("recalculate_item is not supported with this model") + + def similar_users(self, userid, N=10): + factor = self.user_factors[userid] + factors = self.user_factors + norms = self.user_norms + norm = norms[userid] + return self._get_similarity_score(factor, norm, factors, norms, N) + + similar_users.__doc__ = RecommenderBase.similar_users.__doc__ + + def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + factor = self._item_factor(itemid, react_users, recalculate_item) + factors = self.item_factors + norms = self.item_norms + if recalculate_item: + if np.isscalar(itemid): + norm = np.linalg.norm(factor) + norm = norm if norm != 0 else 1e-10 + else: + norm = np.linalg.norm(factor, axis=1) + norm[norm == 0] = 1e-10 + else: + norm = norms[itemid] + + return self._get_similarity_score(factor, norm, factors, norms, N) + + similar_items.__doc__ = RecommenderBase.similar_items.__doc__ + + def _get_similarity_score(self, factor, norm, factors, norms, N): + ids, scores = topk(factors, factor, N, item_norms=norms) + if np.isscalar(norm): + ids, scores = ids[0], scores[0] + scores /= norm + else: + scores /= norm[:, None] + return ids, scores + + @property + def user_norms(self): + if self._user_norms is None: + self._user_norms = np.linalg.norm(self.user_factors, axis=-1) + # don't divide by zero in similar_items, replace with small value + self._user_norms[self._user_norms == 0] = 1e-10 + return self._user_norms + + @property + def item_norms(self): + if self._item_norms is None: + self._item_norms = np.linalg.norm(self.item_factors, axis=-1) + # don't divide by zero in similar_items, replace with small value + self._item_norms[self._item_norms == 0] = 1e-10 + return self._item_norms + + def _check_fit_errors(self): + is_nan = np.any(np.isnan(self.user_factors), axis=None) + is_nan |= np.any(np.isnan(self.item_factors), axis=None) + if is_nan: + raise ModelFitError("NaN encountered in factors") + + +def _filter_items_from_sparse_matrix(items, query_items): + """Remaps all the ids in query_items down to match the position + in the items filter. Requires items to be sorted""" + filter_query_items = query_items.tocoo() + + positions = np.searchsorted(items, filter_query_items.col) + positions = np.clip(positions, 0, len(items) - 1) + + filter_query_items.col = positions + filter_query_items.data[items[positions] != filter_query_items.col] = 0 + filter_query_items.eliminate_zeros() + return filter_query_items.tocsr() diff --git a/implicit/cpu/select.h b/implicit/cpu/select.h new file mode 100644 index 00000000..0aee4495 --- /dev/null +++ b/implicit/cpu/select.h @@ -0,0 +1,43 @@ +// Copyright 2021 Ben Frederickson +#ifndef IMPLICIT_CPU_SELECT_H_ +#define IMPLICIT_CPU_SELECT_H_ +#include +#include +#include +#include +#include + +namespace implicit { + +template +inline void select(const T * batch, int rows, int cols, int k, + int * ids, T * distances) { + std::vector> results; + std::greater > heap_order; + + + for (int row = 0; row < rows; ++row) { + results.clear(); + for (int col = 0; col < cols; ++col) { + T score = batch[row * cols + col]; + + if ((results.size() < k) || (score > results[0].first)) { + if (results.size() >= k) { + std::pop_heap(results.begin(), results.end(), heap_order); + results.pop_back(); + } + results.push_back(std::make_pair(score, col)); + std::push_heap(results.begin(), results.end(), heap_order); + } + } + + std::sort_heap(results.begin(), results.end(), heap_order); + + for (size_t i = 0; i < results.size(); ++i) { + ids[row * k + i] = results[i].second; + distances[row * k + i] = results[i].first; + } + } +} +} // namespace implicit +#endif // IMPLICIT_CPU_SELECT_H_ diff --git a/implicit/cpu/topk.pyx b/implicit/cpu/topk.pyx new file mode 100644 index 00000000..7b0442ff --- /dev/null +++ b/implicit/cpu/topk.pyx @@ -0,0 +1,68 @@ +import cython +import numpy as np + +from cython cimport floating, integral + +from cython.parallel import parallel, prange + + +cdef extern from "select.h" namespace "implicit" nogil: + cdef void select[T](const T * batch, + int batch_rows, int batch_columns, int k, + int * ids, T * distances) nogil except * + + +def topk(items, query, int k, item_norms=None, filter_query_items=None, filter_items=None, int num_threads=0): + if len(query.shape) == 1: + query = query.reshape((1, len(query))) + + cdef int query_rows = query.shape[0] + indices = np.zeros((query_rows, k), dtype="int32") + distances = np.zeros((query_rows, k), dtype=query.dtype) + + # TODO: figure out appropiate batch size from available memory + cdef int batch_size = 100 # TODO + + cdef int batches = (query_rows / batch_size) + if query_rows % batch_size: + batches += 1 + + # if we're only running one batch, don't create a threadpool + if batches == 1: + _topk_batch(items, query, k, 0, query_rows, indices, distances, item_norms=item_norms, filter_query_items=filter_query_items, filter_items=filter_items) + return indices, distances + + cdef int startidx, endidx, batch + + for batch in prange(batches, schedule="guided", num_threads=num_threads, nogil=True): + startidx = batch * batch_size + endidx = min(startidx + batch_size, query_rows) + with gil: + _topk_batch(items, query, k, startidx, endidx, indices, distances, item_norms=item_norms, filter_query_items=filter_query_items, filter_items=filter_items) + + return indices, distances + +def _topk_batch(items, query, int k, int startidx, int endidx, int[:, :] indices, floating[:, :] distances, + item_norms=None, filter_query_items=None, filter_items=None): + batch_distances = query[startidx: endidx].dot(items.T) + if item_norms is not None: + batch_distances = batch_distances / item_norms + + neginf = -np.finfo(batch_distances.dtype).max + if filter_query_items is not None: + for i, idx in enumerate(range(startidx, endidx)): + batch_distances[i, filter_query_items[idx].indices] = neginf + if filter_items is not None: + batch_distances[:, filter_items] = neginf + + cdef floating * c_distances = &distances[0, 0] + cdef int * c_indices = &indices[0, 0] + + cdef floating[:, :] batch_view = batch_distances + cdef floating * c_batch = &batch_view[0, 0] + cdef int rows = batch_view.shape[0] + cdef int cols = batch_view.shape[1] + + with nogil: + select(c_batch, rows, cols, k, c_indices + startidx * k, c_distances + startidx * k) + diff --git a/implicit/evaluation.pyx b/implicit/evaluation.pyx index fbc59202..bf870b4d 100644 --- a/implicit/evaluation.pyx +++ b/implicit/evaluation.pyx @@ -436,9 +436,9 @@ def ranking_metrics_at_k(model, train_user_items, test_user_items, int K=10, memset(ids, -1, sizeof(int) * K) with gil: - recs = model.recommend(u, train_user_items, N=K) + recs, _ = model.recommend(u, train_user_items, N=K) for i in range(len(recs)): - ids[i] = recs[i][0] + ids[i] = recs[i] progress.update(1) # mostly we're going to be blocked on the gil here, diff --git a/implicit/gpu/_cuda.pyx b/implicit/gpu/_cuda.pyx index 4fb7df5b..da227975 100644 --- a/implicit/gpu/_cuda.pyx +++ b/implicit/gpu/_cuda.pyx @@ -48,23 +48,33 @@ cdef class KnnQuery(object): def __dealloc__(self): del self.c_knn - def topk(self, Matrix items, Matrix m, int k, Matrix item_norms=None): + def topk(self, Matrix items, Matrix m, int k, Matrix item_norms=None, + COOMatrix query_filter=None, IntVector item_filter=None): cdef CppMatrix * queries = m.c_matrix + cdef CppCOOMatrix * c_query_filter = NULL + cdef CppVector[int] * c_item_filter = NULL cdef int rows = queries.rows cdef int[:, :] x cdef float[:, :] y cdef float * c_item_norms = NULL - if item_norms: + if item_norms is not None: c_item_norms = item_norms.c_matrix.data + if query_filter is not None: + c_query_filter = query_filter.c_matrix + + if item_filter is not None: + c_item_filter = item_filter.c_vector + + indices = np.zeros((rows, k), dtype="int32") distances = np.zeros((rows, k), dtype="float32") x = indices y = distances self.c_knn.topk(dereference(items.c_matrix), dereference(queries), k, - &x[0, 0], &y[0, 0], c_item_norms) + &x[0, 0], &y[0, 0], c_item_norms, c_query_filter, c_item_filter) return indices, distances @@ -115,16 +125,16 @@ cdef class Matrix(object): try: idx = np.array(idx).astype("int32") except Exception: - raise ValueError(f"don't know how to handle __getitem__ on {idx}") + raise IndexError(f"don't know how to handle __getitem__ on {idx}") if len(idx.shape) == 0: idx = idx.reshape([1]) if len(idx.shape) != 1: - raise ValueError(f"don't know how to handle __getitem__ on {idx} - shape={idx.shape}") + raise IndexError(f"don't know how to handle __getitem__ on {idx} - shape={idx.shape}") if ((idx < 0) | (idx >= self.c_matrix.rows)).any(): - raise ValueError(f"row id out of range for selecting items from matrix") + raise IndexError(f"row id out of range for selecting items from matrix") ids = IntVector(idx) ret.c_matrix = new CppMatrix(dereference(self.c_matrix), dereference(ids.c_vector)) diff --git a/implicit/gpu/bpr.py b/implicit/gpu/bpr.py index 7eafcbd0..5e6a7322 100644 --- a/implicit/gpu/bpr.py +++ b/implicit/gpu/bpr.py @@ -65,14 +65,14 @@ def __init__( self.verify_negative_samples = verify_negative_samples self.random_state = random_state - def fit(self, item_users, show_progress=True): - """Factorizes the item_users matrix + def fit(self, user_items, show_progress=True): + """Factorizes the user_items matrix Parameters ---------- - item_users: coo_matrix - Matrix of confidences for the liked items. This matrix should be a coo_matrix where - the rows of the matrix are the item, and the columns are the users that liked that item. + user_items: csr_matrix + Matrix of confidences for the liked items. This matrix should be a csr_matrix where + the rows of the matrix are the user, and the columns are the items liked by that user. BPR ignores the weight value of the matrix right now - it treats non zero entries as a binary signal that the user liked the item. show_progress : bool, optional @@ -81,15 +81,12 @@ def fit(self, item_users, show_progress=True): rs = check_random_state(self.random_state) # for now, all we handle is float 32 values - if item_users.dtype != np.float32: - item_users = item_users.astype(np.float32) + if user_items.dtype != np.float32: + user_items = user_items.astype(np.float32) - items, users = item_users.shape + users, items = user_items.shape # We need efficient user lookup for case of removing own likes - # TODO: might make more sense to just changes inputs to be users by items instead - # but that would be a major breaking API change - user_items = item_users.T.tocsr() if not user_items.has_sorted_indices: user_items.sort_indices() @@ -131,7 +128,7 @@ def fit(self, item_users, show_progress=True): log.debug("Running %i BPR training epochs", self.iterations) with tqdm(total=self.iterations, disable=not show_progress) as progress: - for epoch in range(self.iterations): + for _epoch in range(self.iterations): correct, skipped = implicit.gpu.bpr_update( userids, itemids, diff --git a/implicit/gpu/knn.cu b/implicit/gpu/knn.cu index 1212d103..e48fa2d3 100644 --- a/implicit/gpu/knn.cu +++ b/implicit/gpu/knn.cu @@ -89,7 +89,9 @@ KnnQuery::KnnQuery(size_t temp_memory) const static int MAX_SELECT_K = 128; void KnnQuery::topk(const Matrix & items, const Matrix & query, int k, - int * indices, float * distances, float * item_norms) { + int * indices, float * distances, float * item_norms, + const COOMatrix * query_filter, + Vector * item_filter) { if (query.cols != items.cols) { throw std::invalid_argument("Must have same number of columns in each matrix for topk"); } @@ -156,6 +158,30 @@ void KnnQuery::topk(const Matrix & items, const Matrix & query, int k, }); } + if (item_filter != NULL) { + auto count = thrust::make_counting_iterator(0); + float * data = temp_distances.data; + int * items = item_filter->data; + thrust::for_each(count, count + item_filter->size, + [=] __device__(int i) { + data[items[i]] = -FLT_MAX; + }); + } + + if (query_filter != NULL) { + auto count = thrust::make_counting_iterator(0); + int * row = query_filter->row; + int * col = query_filter->col; + float * data = temp_distances.data; + int items = temp_distances.cols; + thrust::for_each(count, count + query_filter->nonzeros, + [=] __device__(int i) { + if ((row[i] >= start) && (row[i] < end)) { + data[(row[i] -start) * items + col[i]] = -FLT_MAX; + } + }); + } + argpartition(temp_distances, k, indices + start * k, distances + start * k); // TODO: callback per batch (show progress etc) diff --git a/implicit/gpu/knn.h b/implicit/gpu/knn.h index 31526d09..d4e499d2 100644 --- a/implicit/gpu/knn.h +++ b/implicit/gpu/knn.h @@ -17,7 +17,9 @@ class KnnQuery { void topk(const Matrix & items, const Matrix & query, int k, int * indices, float * distances, - float * item_norms = NULL); + float * item_norms = NULL, + const COOMatrix * query_filter = NULL, + Vector * item_filter = NULL); void argpartition(const Matrix & items, int k, int * indices, float * distances); void argsort(const Matrix & items, int * indices, float * distances); diff --git a/implicit/gpu/knn.pxd b/implicit/gpu/knn.pxd index 3cd949b2..a53b7334 100644 --- a/implicit/gpu/knn.pxd +++ b/implicit/gpu/knn.pxd @@ -1,4 +1,4 @@ -from .matrix cimport Matrix +from .matrix cimport COOMatrix, Matrix, Vector cdef extern from "knn.h" namespace "implicit::gpu" nogil: @@ -6,4 +6,7 @@ cdef extern from "knn.h" namespace "implicit::gpu" nogil: KnnQuery(size_t max_temp_memory) except + void topk(const Matrix & items, const Matrix & query, int k, - int * indices, float * distances, float * item_norms) except + + int * indices, float * distances, + float * item_norms, + COOMatrix * query_filter, + Vector[int] * item_filter) except + diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index 67ba5f9a..f895425c 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -4,6 +4,7 @@ import implicit.gpu +from ..cpu.matrix_factorization_base import _filter_items_from_sparse_matrix from ..recommender_base import RecommenderBase @@ -35,49 +36,58 @@ def recommend( filter_already_liked_items=True, filter_items=None, recalculate_user=False, + items=None, ): if recalculate_user: raise NotImplementedError("recalculate_user isn't support on GPU yet") - liked = set() - if filter_already_liked_items: - liked.update(user_items[userid].indices) - if filter_items: - liked.update(filter_items) - count = N + len(liked) + item_factors = self.item_factors + if items is not None: + if filter_items: + raise ValueError("Can't set both items and filter_items in recommend call") - # calculate the top N items, removing the users own liked items from the results - # TODO: own like filtering (direct in topk class) - ids, scores = self._knn.topk(self.item_factors, self.user_factors[userid], count) + items = np.array(items) + items.sort() + item_factors = item_factors[items] - # TODO: handle batch mode - ids, scores = ids[0], scores[0] + # check selected items are in the model + if items.max() >= self.item_factors.shape[0] or items.min() < 0: + raise IndexError("Some itemids are not in the model") - if liked: - mask = np.in1d(ids, list(liked), invert=True) - ids, scores = ids[mask][:N], scores[mask][:N] - return ids, scores + if filter_items: + filter_items = implicit.gpu.IntVector(np.array(filter_items, dtype="int32")) - recommend.__doc__ = RecommenderBase.recommend.__doc__ + query_filter = None + if filter_already_liked_items: + query_filter = user_items[userid] - def rank_items(self, userid, user_items, selected_items, recalculate_user=False): - if recalculate_user: - raise NotImplementedError("recalculate_user isn't support on GPU yet") + # if we've been given a list of explicit itemids to rank, we need to filter down + if items is not None: + query_filter = _filter_items_from_sparse_matrix(items, query_filter) - # check selected items are in the model - if max(selected_items) >= self.item_factors.shape[0] or min(selected_items) < 0: - raise IndexError("Some of selected itemids are not in the model") + if query_filter.nnz: + query_filter = implicit.gpu.COOMatrix(query_filter.tocoo()) + else: + query_filter = None - item_factors = self.item_factors[selected_items] - user = self.user_factors[userid] + # calculate the top N items, removing the users own liked items from the results + ids, scores = self._knn.topk( + item_factors, + self.user_factors[userid], + N, + query_filter=query_filter, + item_filter=filter_items, + ) + + if np.isscalar(userid): + ids, scores = ids[0], scores[0] + + if items is not None: + ids = items[ids] - # once we have item_factors here, this should work - ids, scores = self._knn.topk(item_factors, user, len(selected_items)) - ids, scores = ids[0], scores[0] - ids = np.array(selected_items)[ids] return ids, scores - rank_items.__doc__ = RecommenderBase.rank_items.__doc__ + recommend.__doc__ = RecommenderBase.recommend.__doc__ @property def user_norms(self): @@ -97,12 +107,13 @@ def similar_users(self, userid, N=10): ids, scores = self._knn.topk( self.user_factors, self.user_factors[userid], N, self.user_norms ) - ids, scores = ids[0], scores[0] user_norms = self._user_norms_host[userid] - if not np.isscalar(user_norms): - user_norms = user_norms.reshape((len(user_norms), 1)) - scores /= user_norms + if np.isscalar(userid): + ids, scores = ids[0], scores[0] + scores /= user_norms + else: + scores /= user_norms[:, None] return ids, scores similar_users.__doc__ = RecommenderBase.similar_users.__doc__ @@ -113,12 +124,13 @@ def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): ids, scores = self._knn.topk( self.item_factors, self.item_factors[itemid], N, self.item_norms ) - ids, scores = ids[0], scores[0] item_norms = self._item_norms_host[itemid] - if not np.isscalar(item_norms): - item_norms = item_norms.reshape((len(item_norms), 1)) - scores /= item_norms + if np.isscalar(itemid): + ids, scores = ids[0], scores[0] + scores /= item_norms + else: + scores /= item_norms[:, None] return ids, scores similar_items.__doc__ = RecommenderBase.similar_items.__doc__ diff --git a/implicit/gpu/random.cu b/implicit/gpu/random.cu index 1cdef862..83cf7750 100644 --- a/implicit/gpu/random.cu +++ b/implicit/gpu/random.cu @@ -21,7 +21,7 @@ Matrix RandomState::uniform(int rows, int cols, float low, float high) { if ((low != 0.0) || (high != 1.0)) { auto start = thrust::device_pointer_cast(ret.data); - thrust::transform(start, start + rows*cols, start, + thrust::transform(start, start + rows*cols, start, thrust::placeholders::_1 = thrust::placeholders::_1 * (high - low) + low); } diff --git a/implicit/lmf.pyx b/implicit/lmf.pyx index 974eb6b4..51022243 100644 --- a/implicit/lmf.pyx +++ b/implicit/lmf.pyx @@ -22,7 +22,7 @@ import scipy.sparse from libcpp.vector cimport vector -from .recommender_base import MatrixFactorizationBase +from .cpu.matrix_factorization_base import MatrixFactorizationBase from .utils import check_random_state log = logging.getLogger("implicit") diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 2946ecfc..1a9e8b7d 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -1,4 +1,4 @@ -import numpy +import numpy as np from numpy import bincount, log, log1p, sqrt from scipy.sparse import coo_matrix, csr_matrix @@ -41,66 +41,78 @@ def recommend( filter_already_liked_items=True, filter_items=None, recalculate_user=False, + items=None, ): """returns the best N recommendations for a user given its id""" + if not np.isscalar(userid): + return _batch( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + if userid >= user_items.shape[0]: raise ValueError("userid is out of bounds of the user_items matrix") - # recalculate_user is ignored because this is not a model based algorithm - items = N - if filter_items: - items += len(filter_items) + if filter_items and items: + raise ValueError("Can't specify both filter_items and items") + + if filter_items is not None: + N += len(filter_items) + elif items is not None: + items = np.array(items) + N = self.similarity.shape[0] + # check if items contains itemids that are not in the model(user_items) + if items.max() >= N or items.min() < 0: + raise IndexError("Some of selected itemids are not in the model") ids, scores = self.scorer.recommend( userid, user_items.indptr, user_items.indices, user_items.data, - K=items, + K=N, remove_own_likes=filter_already_liked_items, ) - if filter_items: - mask = numpy.in1d(ids, filter_items, invert=True) + if filter_items is not None: + mask = np.in1d(ids, filter_items, invert=True) ids, scores = ids[mask][:N], scores[mask][:N] - return ids, scores - - def rank_items(self, userid, user_items, selected_items, recalculate_user=False): - """Rank given items for a user and returns sorted item list""" - # check if selected_items contains itemids that are not in the model(user_items) - if max(selected_items) >= user_items.shape[1] or min(selected_items) < 0: - raise IndexError("Some of selected itemids are not in the model") + elif items is not None: + mask = np.in1d(ids, items) + ids, scores = ids[mask], scores[mask] - selected_items = numpy.array(selected_items) + # returned items should be equal to input selected items + missing = items[np.in1d(items, ids, invert=True)] + if missing.size: + ids = np.append(ids, missing) + scores = np.append(scores, np.full(missing.size, -np.finfo(scores.dtype).max)) - # calculate the relevance scores - liked_vector = user_items.getrow(userid) - recommendations = liked_vector.dot(self.similarity) - - # remove items that are not in the selected_items - ids, scores = recommendations.indices, recommendations.data - mask = numpy.in1d(ids, selected_items) - ids, scores = ids[mask], scores[mask] - - # returned items should be equal to input selected items - missing = selected_items[numpy.in1d(selected_items, ids, invert=True)] - if missing.size: - ids = numpy.append(ids, missing) - scores = numpy.append(scores, numpy.full(missing.size, -numpy.finfo(scores.dtype).max)) return ids, scores def similar_users(self, userid, N=10): raise NotImplementedError("Not implemented Yet") - def similar_items(self, itemid, N=10): + def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): """Returns a list of the most similar other items""" + if recalculate_item: + raise NotImplementedError("Recalculate_item isn't implemented") + + if not np.isscalar(itemid): + return _batch(self.similar_items, itemid, N=N) + if itemid >= self.similarity.shape[0]: - return numpy.array([]), numpy.array([]) + return np.array([]), np.array([]) ids = self.similarity[itemid].indices scores = self.similarity[itemid].data - best = numpy.argsort(scores)[::-1][:N] + best = np.argsort(scores)[::-1][:N] return ids[best], scores[best] def __getstate__(self): @@ -118,9 +130,7 @@ def __setstate__(self, state): def save(self, filename): m = self.similarity - numpy.savez( - filename, data=m.data, indptr=m.indptr, indices=m.indices, shape=m.shape, K=self.K - ) + np.savez(filename, data=m.data, indptr=m.indptr, indices=m.indices, shape=m.shape, K=self.K) @classmethod def load(cls, filename): @@ -128,7 +138,7 @@ def load(cls, filename): if not filename.endswith(".npz"): filename = filename + ".npz" - m = numpy.load(filename) + m = np.load(filename) similarity = csr_matrix((m["data"], m["indices"], m["indptr"]), shape=m["shape"]) ret = cls() @@ -182,7 +192,7 @@ def tfidf_weight(X): def normalize(X): """equivalent to scipy.preprocessing.normalize on sparse matrices - , but lets avoid another depedency just for a small utility function""" + , but lets avoid another dependency just for a small utility function""" X = coo_matrix(X) X.data = X.data / sqrt(bincount(X.row, X.data ** 2))[X.row] return X @@ -197,10 +207,33 @@ def bm25_weight(X, K1=100, B=0.8): idf = log(N) - log1p(bincount(X.col)) # calculate length_norm per document (artist) - row_sums = numpy.ravel(X.sum(axis=1)) + row_sums = np.ravel(X.sum(axis=1)) average_length = row_sums.mean() length_norm = (1.0 - B) + B * row_sums / average_length # weight matrix rows by bm25 X.data = X.data * (K1 + 1.0) / (K1 * length_norm[X.row] + X.data) * idf[X.col] return X + + +def _batch(func, ids, *args, N=10, **kwargs): + # we're running in batch mode, just loop over each item and call the scalar version of the + # function + output_ids = np.zeros((len(ids), N), dtype=np.int32) + output_scores = np.zeros((len(ids), N), dtype=np.float32) + + for i, idx in enumerate(ids): + batch_ids, batch_scores = func(idx, *args, N=N, **kwargs) + + # pad out to N items if we're returned fewer + missing_items = len(batch_ids) - N + if missing_items: + batch_ids = np.append(batch_ids, np.full(missing_items, -1)) + batch_scores = np.append( + batch_scores, np.full(missing_items, -np.finfo(np.float32).max) + ) + + output_ids[i] = batch_ids[:N] + output_scores[i] = batch_scores[:N] + + return output_ids, output_scores diff --git a/implicit/recommender_base.py b/implicit/recommender_base.py new file mode 100644 index 00000000..06b3ef82 --- /dev/null +++ b/implicit/recommender_base.py @@ -0,0 +1,136 @@ +""" Base class for recommendation algorithms in this package """ +import warnings +from abc import ABCMeta, abstractmethod + + +class ModelFitError(Exception): + pass + + +class RecommenderBase(object): + """Defines the interface that all recommendations models here expose""" + + __metaclass__ = ABCMeta + + @abstractmethod + def fit(self, user_items): + """ + Trains the model on a sparse matrix of item/user/weight + + Parameters + ---------- + user_items : csr_matrix + A matrix of shape (number_of_users, number_of_items). The nonzero + entries in this matrix are the items that are liked by each user. + The values are how confident you are that the item is liked by the user. + """ + + @abstractmethod + def recommend( + self, + userid, + user_items, + N=10, + filter_already_liked_items=True, + filter_items=None, + recalculate_user=False, + items=None, + ): + """ + Recommends items for a user + + Calculates the N best recommendations for a user, and returns a list of itemids, score. + + Parameters + ---------- + userid : Union[int, array_like] + The userid or array of userids to calculate recommendations for + user_items : csr_matrix + A sparse matrix of shape (number_users, number_items). This lets us look + up the liked items and their weights for the user. This is used to filter out + items that have already been liked from the output, and to also potentially + calculate the best items for this user. + N : int, optional + The number of results to return + filter_already_liked_items: bool, optional + When true, don't return items present in the training set that were rated + by the specificed user. + filter_items : sequence of ints, optional + List of extra item ids to filter out from the output + recalculate_user : bool, optional + When true, don't rely on stored user state and instead recalculate from the + passed in user_items + items: array_like, optional + Array of extra item ids. When set this will only rank the items in this array instead + of ranking every item the model was fit for. This parameter cannot be used with + filter_items + + Returns + ------- + tuple + Tuple of (itemids, scores) arrays. When calculating for a single user these array will + be 1-dimensional with N items. When passed an array of userids, these will be + 2-dimensional arrays with a row for each user. + """ + + @abstractmethod + def similar_users(self, userid, N=10): + """ + Calculates a list of similar users + + Parameters + ---------- + userid : Union[int, array_like] + The userid or an array of userids to retrieve similar users for + N : int, optional + The number of similar users to return + + Returns + ------- + tuple + Tuple of (itemids, scores) arrays + """ + + @abstractmethod + def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + + """ + Calculates a list of similar items + + Parameters + ---------- + itemid : Union[int, array_like] + The item id or an array of item ids to retrieve similar items for + N : int, optional + The number of similar items to return + react_users : csr_matrix, optional + A sparse matrix of shape (number_items, number_users). This lets us look + up the reacted users and their weights for the item. + recalculate_item : bool, optional + When true, don't rely on stored item state and instead recalculate from the + passed in react_users + + Returns + ------- + tuple + Tuple of (itemids, scores) arrays + """ + + def rank_items(self, userid, user_items, selected_items, recalculate_user=False): + """ + Rank given items for a user and returns sorted item list. + + Deprecated. Use recommend with the 'items' parameter instead + """ + warnings.warn( + "rank_items is deprecated. Use recommend with the 'items' parameter instead", + DeprecationWarning, + stacklevel=2, + ) + return self.recommend( + userid, + user_items, + recalculate_user=recalculate_user, + items=selected_items, + filter_already_liked_items=False, + ) diff --git a/implicit/recommender_base.pyx b/implicit/recommender_base.pyx deleted file mode 100644 index 8cc9ad6b..00000000 --- a/implicit/recommender_base.pyx +++ /dev/null @@ -1,385 +0,0 @@ -""" Base class for recommendation algorithms in this package """ -# distutils: language = c++ -# cython: language_level=3 - -import itertools -import multiprocessing -from abc import ABCMeta, abstractmethod -from math import ceil - -import cython -import numpy as np -from cython.parallel import prange -from scipy.sparse import csr_matrix -from tqdm.auto import tqdm - - -# Define wrapper for C++ sorting function -cdef extern from "topnc.h": - cdef void fargsort_c(float A[], int n_row, int m_row, int m_cols, int ktop, int B[]) nogil - - -class ModelFitError(Exception): - pass - - -class RecommenderBase(object): - """ Defines the interface that all recommendations models here expose """ - __metaclass__ = ABCMeta - - @abstractmethod - def fit(self, user_items): - """ - Trains the model on a sparse matrix of item/user/weight - - Parameters - ---------- - user_items : csr_matrix - A matrix of shape (number_of_users, number_of_items). The nonzero - entries in this matrix are the items that are liked by each user. - The values are how confident you are that the item is liked by the user. - """ - pass - - @abstractmethod - def recommend(self, userid, user_items, - N=10, filter_already_liked_items=True, filter_items=None, recalculate_user=False): - """ - Recommends items for a user - - Calculates the N best recommendations for a user, and returns a list of itemids, score. - - Parameters - ---------- - userid : int - The userid to calculate recommendations for - user_items : csr_matrix - A sparse matrix of shape (number_users, number_items). This lets us look - up the liked items and their weights for the user. This is used to filter out - items that have already been liked from the output, and to also potentially - calculate the best items for this user. - N : int, optional - The number of results to return - filter_already_liked_items: bool, optional - When true, don't return items present in the training set that were rated - by the specificed user. - filter_items : sequence of ints, optional - List of extra item ids to filter out from the output - recalculate_user : bool, optional - When true, don't rely on stored user state and instead recalculate from the - passed in user_items - - Returns - ------- - tuple - Tuple of (itemids, scores) arrays - """ - pass - - @abstractmethod - def rank_items(self, userid, user_items, selected_items, recalculate_user=False): - """ - Rank given items for a user and returns sorted item list. - - Parameters - ---------- - userid : int - The userid to calculate recommendations for - user_items : csr_matrix - A sparse matrix of shape (number_users, number_items). This lets us - (optionally) recalculate user factors (see `recalculate_user` parameter) as - required - selected_items : List of itemids - recalculate_user : bool, optional - When true, don't rely on stored user state and instead recalculate from the - passed in user_items - - Returns - ------- - tuple - Tuple of (itemids, scores) arrays. it only contains items that appears in - input parameter selected_items - """ - pass - - @abstractmethod - def similar_users(self, userid, N=10): - """ - Calculates a list of similar users - - Parameters - ---------- - userid : int - The row id of the user to retrieve similar users for - N : int, optional - The number of similar users to return - - Returns - ------- - tuple - Tuple of (itemids, scores) arrays - """ - pass - - @abstractmethod - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): - - """ - Calculates a list of similar items - - Parameters - ---------- - itemid : int - The row id of the item to retrieve similar items for - N : int, optional - The number of similar items to return - react_users : csr_matrix, optional - A sparse matrix of shape (number_items, number_users). This lets us look - up the reacted users and their weights for the item. - recalculate_item : bool, optional - When true, don't rely on stored item state and instead recalculate from the - passed in react_users - - Returns - ------- - tuple - Tuple of (itemids, scores) arrays - """ - pass - - -class MatrixFactorizationBase(RecommenderBase): - """ MatrixFactorizationBase contains common functionality for recommendation models. - - Attributes - ---------- - item_factors : ndarray - Array of latent factors for each item in the training set - user_factors : ndarray - Array of latent factors for each user in the training set - """ - def __init__(self): - # learned parameters - self.item_factors = None - self.user_factors = None - - # cache of user, item norms (useful for calculating similar items) - self._user_norms, self._item_norms = None, None - - def recommend(self, userid, user_items, - N=10, filter_already_liked_items=True, filter_items=None, recalculate_user=False): - user = self._user_factor(userid, user_items, recalculate_user) - - # calculate the top N items, removing the users own liked items from the results - scores = self.item_factors.dot(user) - - # filter out liked items - if filter_already_liked_items: - scores[user_items[userid].indices] = -np.finfo(scores.dtype).max - if filter_items: - scores[filter_items] = -np.finfo(scores.dtype).max - - if N < len(scores): - ids = np.argpartition(scores, -N)[-N:] - else: - ids = np.arange(len(scores)) - - ids = ids[np.argsort(scores[ids])[::-1]] - return ids, scores[ids] - - @cython.boundscheck(False) - @cython.wraparound(False) - @cython.nonecheck(False) - def recommend_all(self, user_items, int N=10, - recalculate_user=False, filter_already_liked_items=True, filter_items=None, - int num_threads=0, show_progress=True, int batch_size=0, - int users_items_offset=0): - """ - Recommends items for all users - - Calculates the N best recommendations for all users, and returns numpy ndarray of - shape (number_users, N) with item's ids in reversed probability order - - Parameters - ---------- - self : implicit.als.AlternatingLeastSquares - The fitted recommendation model - user_items : csr_matrix - A sparse matrix of shape (number_users, number_items). This lets us look - up the liked items and their weights for the user. This is used to filter out - items that have already been liked from the output, and to also potentially - calculate the best items for this user. - N : int, optional - The number of results to return - recalculate_user : bool, optional - When true, don't rely on stored user state and instead recalculate from the - passed in user_items - filter_already_liked_items : bool, optional - This is used to filter out items that have already been liked from the user_items - filter_items: list, optional - List of item id's to exclude from recommendations for all users - num_threads : int, optional - The number of threads to use for sorting scores in parallel by users. Default is - number of cores on machine - show_progress : bool, optional - Whether to show a progress bar - batch_size : int, optional - To optimise memory usage while matrix multiplication, users are separated into groups - and scored iteratively. By default batch_size == num_threads * 100 - users_items_offset : int, optional - Allow to pass a slice of user_items matrix to split calculations - - Returns - ------- - numpy ndarray - Array of (number_users, N) with item's ids in descending probability order - """ - - # Check N possibility - if filter_already_liked_items: - max_row_n = user_items.getnnz(axis=1).max() - if max_row_n > user_items.shape[1] - N: - raise ValueError(f"filter_already_liked_items:\ - cannot filter {max_row_n} and recommend {N} items\ - out of {user_items.shape[1]} available.") - if filter_items: - filter_items = list(set(filter_items)) - if len(filter_items) > user_items.shape[1] - N: - raise ValueError(f"filter_items:\ - cannot filter {len(filter_items)} and recommend {N} items\ - out of {user_items.shape[1]} available.") - - if num_threads==0: - num_threads=multiprocessing.cpu_count() - - if not isinstance(user_items, csr_matrix): - user_items = user_items.tocsr() - - factors_items = self.item_factors.T - - cdef: - int users_c = user_items.shape[0], items_c = user_items.shape[1] - int batch = num_threads * 100 if batch_size==0 else batch_size - int u_b, u_low, u_high, u_len, u - A = np.zeros((batch, items_c), dtype=np.float32) - cdef: - int users_c_b = ceil(users_c / float(batch)) - float[:, ::1] A_mv = A - float * A_mv_p = &A_mv[0, 0] - int[:, ::1] B_mv = np.zeros((users_c, N), dtype=np.intc) - int * B_mv_p = &B_mv[0, 0] - - progress = tqdm(total=users_c, disable=not show_progress) - # Separate all users in batches - for u_b in range(users_c_b): - u_low = u_b * batch - u_high = min([(u_b + 1) * batch, users_c]) - u_len = u_high - u_low - # Prepare array with scores for batch of users - users_factors = np.vstack([ - self._user_factor(u+users_items_offset, user_items, recalculate_user) - for u - in range(u_low, u_high, 1) - ]).astype(np.float32) - users_factors.dot(factors_items, out=A[:u_len]) - # Precalculate min if needed later - if filter_already_liked_items or filter_items: - A_min = np.amin(A) - # Filter out items from user_items if needed - if filter_already_liked_items: - A[user_items[u_low:u_high].nonzero()] = A_min - 1 - # Filter out constant items - if filter_items: - A[:, filter_items] = A_min - 1 - # Sort array of scores in parallel - for u in prange(u_len, nogil=True, num_threads=num_threads, schedule='dynamic'): - fargsort_c(A_mv_p, u, batch * u_b + u, items_c, N, B_mv_p) - progress.update(u_len) - progress.close() - return np.asarray(B_mv) - - def rank_items(self, userid, user_items, selected_items, recalculate_user=False): - user = self._user_factor(userid, user_items, recalculate_user) - - # check selected items are in the model - if max(selected_items) >= self.item_factors.shape[0] or min(selected_items) < 0: - raise IndexError("Some of selected itemids are not in the model") - - item_factors = self.item_factors[selected_items] - # calculate relevance scores of given items w.r.t the user - scores = item_factors.dot(user) - - # return sorted results - selected_items = np.array(selected_items) - best = np.argsort(scores)[::-1] - return selected_items[best], scores[best] - - recommend.__doc__ = RecommenderBase.recommend.__doc__ - - def _user_factor(self, userid, user_items, recalculate_user=False): - if recalculate_user: - return self.recalculate_user(userid, user_items) - else: - return self.user_factors[userid] - - def _item_factor(self, itemid, react_users, recalculate_item=False): - if recalculate_item: - return self.recalculate_item(itemid, react_users) - else: - return self.item_factors[itemid] - - def recalculate_user(self, userid, user_items): - raise NotImplementedError("recalculate_user is not supported with this model") - - def recalculate_item(self, itemid, react_users): - raise NotImplementedError("recalculate_item is not supported with this model") - - def similar_users(self, userid, N=10): - factor = self.user_factors[userid] - factors = self.user_factors - norms = self.user_norms - norm = norms[userid] - return self._get_similarity_score(factor, norm, factors, norms, N) - - similar_users.__doc__ = RecommenderBase.similar_users.__doc__ - - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): - factor = self._item_factor(itemid, react_users, recalculate_item) - factors = self.item_factors - norms = self.item_norms - if recalculate_item: - norm = np.linalg.norm(factor) - norm = norm if norm != 0 else 1e-10 - else: - norm = norms[itemid] - return self._get_similarity_score(factor, norm, factors, norms, N) - - similar_items.__doc__ = RecommenderBase.similar_items.__doc__ - - def _get_similarity_score(self, factor, norm, factors, norms, N): - scores = factors.dot(factor) / (norm * norms) - best = np.argpartition(scores, -N)[-N:] - ids = best[np.argsort(scores[best])[::-1]] - return ids, scores[ids] - - @property - def user_norms(self): - if self._user_norms is None: - self._user_norms = np.linalg.norm(self.user_factors, axis=-1) - # don't divide by zero in similar_items, replace with small value - self._user_norms[self._user_norms == 0] = 1e-10 - return self._user_norms - - @property - def item_norms(self): - if self._item_norms is None: - self._item_norms = np.linalg.norm(self.item_factors, axis=-1) - # don't divide by zero in similar_items, replace with small value - self._item_norms[self._item_norms == 0] = 1e-10 - return self._item_norms - - def _check_fit_errors(self): - is_nan = np.any(np.isnan(self.user_factors), axis=None) - is_nan |= np.any(np.isnan(self.item_factors), axis=None) - if is_nan: - raise ModelFitError('NaN encountered in factors') diff --git a/implicit/topnc.cpp b/implicit/topnc.cpp deleted file mode 100644 index b540a59e..00000000 --- a/implicit/topnc.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include -#include - -#include "topnc.h" - -namespace { -struct target { - int index; - float value; -}; - -bool targets_compare(target t_i, target t_j) { return (t_i.value > t_j.value); } -} - -void fargsort_c(float A[], int n_row, int m_row, int m_cols, int ktop, int B[]) { - std::vector targets; - for ( int j = 0; j < m_cols; j++ ) { - target c; - c.index = j; - c.value = A[(n_row*m_cols) + j]; - targets.push_back(c); - } - std::partial_sort(targets.begin(), targets.begin() + ktop, targets.end(), targets_compare); - std::sort(targets.begin(), targets.begin() + ktop, targets_compare); - for (int j = 0; j < ktop; j++) { - B[(m_row*ktop) + j] = targets[j].index; - } -} diff --git a/implicit/topnc.h b/implicit/topnc.h deleted file mode 100644 index 51270cd4..00000000 --- a/implicit/topnc.h +++ /dev/null @@ -1,7 +0,0 @@ -/* "Copyright [2019] " [legal/copyright] */ -#ifndef IMPLICIT_TOPNC_H_ -#define IMPLICIT_TOPNC_H_ - -extern void fargsort_c(float A[], int n_row, int m_row, int m_cols, int ktop, int B[]); - -#endif // IMPLICIT_TOPNC_H_ diff --git a/setup.cfg b/setup.cfg index 68ea8638..64b9d377 100644 --- a/setup.cfg +++ b/setup.cfg @@ -9,6 +9,7 @@ description-file = README.md [flake8] max-line-length = 100 exclude = build,.eggs,.tox +ignore = E203 [isort] multi_line_output = 3 diff --git a/setup.py b/setup.py index 55c60632..c1346d3a 100644 --- a/setup.py +++ b/setup.py @@ -63,21 +63,9 @@ def define_extensions(): extra_compile_args=compile_args, extra_link_args=link_args, ) - for name in ["_als", "bpr"] + for name in ["_als", "bpr", "topk"] ] ) - modules.append( - Extension( - "implicit." + "recommender_base", - [ - os.path.join("implicit", "recommender_base" + src_ext), - os.path.join("implicit", "topnc.cpp"), - ], - language="c++", - extra_compile_args=compile_args, - extra_link_args=link_args, - ) - ) if CUDA: conda_prefix = os.getenv("CONDA_PREFIX") diff --git a/tests/als_test.py b/tests/als_test.py index b97d17e4..0b74de3e 100644 --- a/tests/als_test.py +++ b/tests/als_test.py @@ -9,12 +9,12 @@ from implicit.als import AlternatingLeastSquares from implicit.gpu import HAS_CUDA -from .recommender_base_test import RecommenderBaseTestMixin, get_checker_board +from .recommender_base_test import RecommenderBaseTestMixin class ALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): - return AlternatingLeastSquares(factors=3, regularization=0, use_gpu=False, random_state=23) + return AlternatingLeastSquares(factors=32, regularization=0, use_gpu=False, random_state=23) if HAS_CUDA: @@ -152,7 +152,7 @@ def test_explain(): [0, 1, 0, 0, 0, 1], [0, 0, 2, 0, 1, 1], ], - dtype=np.float64, + dtype=np.float32, ) item_users = counts * 2 user_items = item_users.T.tocsr() @@ -205,39 +205,3 @@ def test_explain(): assert pytest.approx(score, abs=1e-4) == top_score_explained assert scores[:2] == top_scores assert items[:2] == top_items - - -def test_recommend_all(): - item_users = get_checker_board(50) - user_items = item_users.T.tocsr() - - model = AlternatingLeastSquares(factors=3, regularization=0, use_gpu=False, random_state=23) - model.fit(item_users, show_progress=False) - - recs = model.recommend_all(user_items, N=1, show_progress=False) - for userid in range(50): - assert len(recs[userid]) == 1 - - # the top item recommended should be the same as the userid: - # its the one withheld item for the user that is liked by - # all the other similar users - assert recs[userid][0] == userid - - offset = 2 - recs = model.recommend_all( - user_items[[2, 3, 4]], N=1, show_progress=False, users_items_offset=offset - ) - - for userid in range(2, 5): - assert len(recs[userid - offset]) == 1 - assert recs[userid - offset][0] == userid - - # try asking for more items than possible - with pytest.raises(ValueError): - model.recommend_all(user_items, N=10000, show_progress=False) - with pytest.raises(ValueError): - model.recommend_all(user_items, filter_items=list(range(10000)), show_progress=False) - - # filter recommended items using an additional filter list - recs = model.recommend_all(user_items, N=1, filter_items=[0], show_progress=False) - assert 0 not in recs diff --git a/tests/evaluation_test.py b/tests/evaluation_test.py index ded2c4a4..b21ef939 100644 --- a/tests/evaluation_test.py +++ b/tests/evaluation_test.py @@ -1,95 +1,118 @@ from __future__ import print_function -import unittest - import numpy as np +import pytest from scipy.sparse import csr_matrix, random -from implicit.evaluation import leave_k_out_split, train_test_split +import implicit +from implicit.datasets.movielens import get_movielens +from implicit.evaluation import leave_k_out_split, precision_at_k, train_test_split + + +def _get_sample_matrix(): + return csr_matrix((np.random.random((10, 10)) > 0.5).astype(np.float64)) + + +def _get_matrix(): + mat = random(100, 100, density=0.5, format="csr", dtype=np.float32) + return mat.tocoo() + + +def test_train_test_split(): + seed = np.random.randint(1000) + mat = _get_sample_matrix() + train, test = train_test_split(mat, 0.8, seed) + train2, test2 = train_test_split(mat, 0.8, seed) + assert np.all(train.todense() == train2.todense()) + + +def test_leave_k_out_returns_correct_shape(): + """ + Test that the output matrices are of the same shape as the input matrix. + """ + + mat = _get_matrix() + train, test = leave_k_out_split(mat, K=1) + assert train.shape == mat.shape + assert test.shape == mat.shape + + +def test_leave_k_out_outputs_produce_input(): + """ + Test that the sum of the output matrices is equal to the input matrix (i.e. + that summing the output matrices produces the input matrix). + """ + mat = _get_matrix() + train, test = leave_k_out_split(mat, K=1) + assert ((train + test) - mat).nnz == 0 -class EvaluationTest(unittest.TestCase): - @staticmethod - def _get_sample_matrix(): - return csr_matrix((np.random.random((10, 10)) > 0.5).astype(np.float64)) - @staticmethod - def _get_matrix(): - mat = random(100, 100, density=0.5, format="csr", dtype=np.float32) - return mat.tocoo() +def test_leave_k_split_is_reservable(): + """ + Test that the sum of the train and test set equals the input. + """ - def test_train_test_split(self): - seed = np.random.randint(1000) - mat = self._get_sample_matrix() - train, test = train_test_split(mat, 0.8, seed) - train2, test2 = train_test_split(mat, 0.8, seed) - self.assertTrue(np.all(train.todense() == train2.todense())) + mat = _get_matrix() + train, test = leave_k_out_split(mat, K=1) - def test_leave_k_out_returns_correct_shape(self): - """ - Test that the output matrices are of the same shape as the input matrix. - """ + # check all matrices are positive, non-zero + assert mat.sum() > 0 + assert test.sum() > 0 + assert train.sum() > 0 - mat = self._get_matrix() - train, test = leave_k_out_split(mat, K=1) - self.assertTrue(train.shape == mat.shape) - self.assertTrue(test.shape == mat.shape) + # check sum of train + test = input + assert ((train + test) - mat).nnz == 0 - def test_leave_k_out_outputs_produce_input(self): - """ - Test that the sum of the output matrices is equal to the input matrix (i.e. - that summing the output matrices produces the input matrix). - """ - mat = self._get_matrix() - train, test = leave_k_out_split(mat, K=1) - self.assertTrue(((train + test) - mat).nnz == 0) +def test_leave_k_out_gets_correct_train_only_shape(): + """Test that the correct number of users appear *only* in the train set.""" - def test_leave_k_split_is_reservable(self): - """ - Test that the sum of the train and test set equals the input. - """ + mat = _get_matrix() + train, test = leave_k_out_split(mat, K=1, train_only_size=0.8) + train_only = ~np.isin(np.unique(train.tocoo().row), test.tocoo().row) - mat = self._get_matrix() - train, test = leave_k_out_split(mat, K=1) + assert train_only.sum() == int(train.shape[0] * 0.8) - # check all matrices are positive, non-zero - self.assertTrue(mat.sum() > 0) - self.assertTrue(test.sum() > 0) - self.assertTrue(train.sum() > 0) - # check sum of train + test = input - self.assertTrue(((train + test) - mat).nnz == 0) +def test_leave_k_out_raises_error_for_k_less_than_zero(): + """ + Test that an error is raised when K < 0. + """ + with pytest.raises(ValueError): + leave_k_out_split(None, K=0) - def test_leave_k_out_gets_correct_train_only_shape(self): - """Test that the correct number of users appear *only* in the train set.""" - mat = self._get_matrix() - train, test = leave_k_out_split(mat, K=1, train_only_size=0.8) - train_only = ~np.isin(np.unique(train.tocoo().row), test.tocoo().row) - self.assertTrue(train_only.sum() == int(train.shape[0] * 0.8)) +def test_leave_k_out_raises_error_for_invalid_train_only_size_lower_bound(): + """ + Test that an error is raised when train_only_size < 0. + """ + with pytest.raises(ValueError): + leave_k_out_split(None, K=1, train_only_size=-1.0) - def test_leave_k_out_raises_error_for_k_less_than_zero(self): - """ - Test that an error is raised when K < 0. - """ - self.assertRaises(ValueError, leave_k_out_split, None, K=0) +def test_leave_k_out_raises_error_for_invalid_train_only_size_upper_bound(): + """ + Test that an error is raised when train_only_size >= 1. + """ + with pytest.raises(ValueError): + leave_k_out_split(None, K=1, train_only_size=1.0) - def test_leave_k_out_raises_error_for_invalid_train_only_size_lower_bound(self): - """ - Test that an error is raised when train_only_size < 0. - """ - self.assertRaises(ValueError, leave_k_out_split, None, K=1, train_only_size=-1.0) +def test_evaluate_movielens_100k(): + _, ratings = get_movielens(variant="100k") - def test_leave_k_out_raises_error_for_invalid_train_only_size_upper_bound(self): - """ - Test that an error is raised when train_only_size >= 1. - """ + # remove things < min_rating, and convert to implicit dataset + # by considering ratings as a binary preference only + min_rating = 3.0 + ratings.data[ratings.data < min_rating] = 0 + ratings.eliminate_zeros() + ratings.data = np.ones(len(ratings.data)) - self.assertRaises(ValueError, leave_k_out_split, None, K=1, train_only_size=1.0) + user_ratings = ratings.T.tocsr() + train, test = train_test_split(user_ratings) + model = implicit.als.AlternatingLeastSquares() + model.fit(train) -if __name__ == "__main__": - unittest.main() + assert precision_at_k(model, train, test) > 0.2 diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index c9ed9505..8c431d6f 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -3,6 +3,7 @@ from __future__ import print_function import pickle +import random import numpy as np from scipy.sparse import csr_matrix @@ -56,6 +57,32 @@ def test_recommend(self): ids, scores = model.recommend(0, user_items, N=1, filter_items=[0]) self.assertTrue(0 not in set(ids)) + def test_recommend_batch(self): + user_items = get_checker_board(50) + + model = self._get_model() + model.fit(user_items, show_progress=False) + + ids, _ = model.recommend(np.arange(50), user_items, N=1) + for userid in range(50): + assert len(ids[userid]) == 1 + + # the top item recommended should be the same as the userid: + # its the one withheld item for the user that is liked by + # all the other similar users + assert ids[userid][0] == userid + + userids = np.array([2, 3, 4]) + ids, _ = model.recommend(userids, user_items, N=1) + + for i, userid in enumerate(userids): + assert ids[i][0] == userid + + # filter recommended items using an additional filter list + ids, _ = model.recommend(userids, user_items, N=1, filter_items=[0]) + for i, _ in enumerate(userids): + assert 0 not in ids[i] + def test_recalculate_user(self): item_users = get_checker_board(50) user_items = item_users.T.tocsr() @@ -63,25 +90,33 @@ def test_recalculate_user(self): model = self._get_model() model.fit(item_users, show_progress=False) + try: + batch_ids, batch_scores = model.recommend( + np.arange(50), user_items, N=1, recalculate_user=True + ) + except NotImplementedError: + # some models don't support recalculating user on the fly, and thats ok + return + for userid in range(item_users.shape[1]): ids, scores = model.recommend(userid, user_items, N=1) self.assertEqual(len(ids), 1) user_vector = user_items[userid] # we should get the same item if we recalculate_user - try: - ids_from_liked, scores_from_liked = model.recommend( - userid=0, user_items=user_vector, N=1, recalculate_user=True - ) - self.assertEqual(ids[0], ids_from_liked[0]) - - # TODO: if we set regularization for the model to be sufficiently high, the - # scores from recalculate_user are slightly different. Investigate - # (could be difference between CG and cholesky optimizers?) - self.assertAlmostEqual(scores[0], scores_from_liked[0], places=4) - except NotImplementedError: - # some models don't support recalculating user on the fly, and thats ok - pass + ids_from_liked, scores_from_liked = model.recommend( + userid=0, user_items=user_vector, N=1, recalculate_user=True + ) + self.assertEqual(ids[0], ids_from_liked[0]) + + # TODO: if we set regularization for the model to be sufficiently high, the + # scores from recalculate_user are slightly different. Investigate + # (could be difference between CG and cholesky optimizers?) + self.assertAlmostEqual(scores[0], scores_from_liked[0], places=4) + + # we should also get the same from the batch recommend call already done + self.assertEqual(batch_ids[userid][0], ids_from_liked[0]) + self.assertAlmostEqual(batch_scores[userid][0], scores_from_liked[0], places=4) def test_evaluation(self): item_users = get_checker_board(50) @@ -90,7 +125,7 @@ def test_evaluation(self): model = self._get_model() model.fit(item_users, show_progress=False) - # we've withheld the diagnoal for testing, and have verified that in test_recommend + # we've withheld the diagonal for testing, and have verified that in test_recommend # it is returned for each user. So p@1 should be 1.0 p = precision_at_k( model, user_items.tocsr(), csr_matrix(np.eye(50)), K=1, show_progress=False @@ -108,6 +143,26 @@ def test_similar_users(self): for r in ids: self.assertEqual(r % 2, userid % 2) + def test_similar_users_batch(self): + model = self._get_model() + # calculating similar users in nearest-neighbours is not implemented yet + if isinstance(model, ItemItemRecommender): + return + model.fit(get_checker_board(256), show_progress=False) + userids = np.arange(50) + ids, scores = model.similar_users(userids, N=10) + + self.assertEqual(ids.shape, (50, 10)) + + for userid in userids: + # first user returned should be itself, and score should be ~1.0 + self.assertEqual(ids[userid][0], userid) + self.assertAlmostEqual(scores[userid][0], 1.0, places=4) + + # the rest of the users should be even or odd depending on the userid + for r in ids[userid]: + self.assertEqual(r % 2, userid % 2) + def test_similar_items(self): model = self._get_model() model.fit(get_checker_board(256), show_progress=False) @@ -116,6 +171,33 @@ def test_similar_items(self): for r in ids: self.assertEqual(r % 2, itemid % 2) + def test_similar_items_batch(self): + model = self._get_model() + user_items = get_checker_board(256) + model.fit(user_items, show_progress=False) + itemids = np.arange(50) + + def check_results(ids): + self.assertEqual(ids.shape, (50, 10)) + for itemid in itemids: + # first item returned should be itself + self.assertEqual(ids[itemid][0], itemid) + + # the rest of the items should be even or odd depending on the itemid + for r in ids[itemid]: + self.assertEqual(r % 2, itemid % 2) + + ids, _ = model.similar_items(itemids, N=10) + check_results(ids) + try: + ids, _ = model.similar_items( + itemids, N=10, recalculate_item=True, react_users=user_items.T.tocsr() + ) + check_results(ids) + except NotImplementedError: + # some models don't support recalculating user on the fly, and thats ok + pass + def test_zero_length_row(self): # get a matrix where a row/column is 0 item_users = get_checker_board(50).todense() @@ -152,8 +234,11 @@ def test_rank_items(self): model.fit(item_users, show_progress=False) for userid in range(50): - selected_items = np.random.randint(50, size=10).tolist() - ids, scores = model.rank_items(userid, user_items, selected_items) + selected_items = random.sample(range(50), 10) + + ids, scores = model.recommend( + userid, user_items, items=selected_items, filter_already_liked_items=False + ) # ranked list should have same items self.assertEqual(set(ids), set(selected_items)) @@ -164,10 +249,24 @@ def test_rank_items(self): # rank_items should raise IndexError if selected items contains wrong itemids with self.assertRaises(IndexError): wrong_item_list = selected_items + wrong_neg_items - model.rank_items(userid, user_items, wrong_item_list) + model.recommend(userid, user_items, items=wrong_item_list) with self.assertRaises(IndexError): wrong_item_list = selected_items + wrong_pos_items - model.rank_items(userid, user_items, wrong_item_list) + model.recommend(userid, user_items, items=wrong_item_list) + + def test_rank_items_batch(self): + item_users = get_checker_board(50) + user_items = item_users.T.tocsr() + + model = self._get_model() + model.fit(item_users, show_progress=False) + + selected_items = np.arange(10) * 3 + ids, scores = model.recommend(np.arange(50), user_items, items=selected_items) + + for userid in range(50): + current_ids = ids[userid] + self.assertEqual(set(current_ids), set(selected_items)) def test_pickle(self): item_users = get_checker_board(50) From eea4d263263ca8e768e0b873a1a75c1361ca97cf Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 25 Nov 2021 23:01:57 -0800 Subject: [PATCH 4/9] Add pylint and codespell to linters (#495) --- .github/workflows/build.yml | 14 ++++--- .pre-commit-config.yaml | 22 +++++++++++ .pylintrc | 46 +++++++++++++++++++++++ README.md | 2 +- benchmarks/README.md | 2 +- implicit/als.py | 23 ++++++------ implicit/approximate_als.py | 43 +++++++++++---------- implicit/bpr.py | 21 +++++------ implicit/cpu/als.py | 10 ++--- implicit/cpu/bpr.pyx | 2 +- implicit/cpu/matrix_factorization_base.py | 8 +--- implicit/cpu/topk.pyx | 2 +- implicit/datasets/lastfm.py | 2 +- implicit/datasets/million_song_dataset.py | 2 +- implicit/datasets/movielens.py | 6 +-- implicit/datasets/reddit.py | 2 +- implicit/gpu/als.py | 2 +- implicit/gpu/bpr.py | 8 ++-- implicit/gpu/matrix_factorization_base.py | 2 + implicit/gpu/utils.cuh | 2 +- implicit/lmf.pyx | 2 +- implicit/nearest_neighbours.py | 2 +- implicit/recommender_base.py | 5 +-- implicit/utils.py | 2 +- setup.cfg | 5 +++ tests/recommender_base_test.py | 14 +++---- 26 files changed, 160 insertions(+), 91 deletions(-) create mode 100644 .pre-commit-config.yaml create mode 100644 .pylintrc diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 12321e4b..caaf2876 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,11 +3,7 @@ name: Build -on: - push: - branches: [ main ] - pull_request: - branches: [ main ] +on: [push, pull_request] jobs: build: @@ -27,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 isort cpplint black pytest + pip install flake8 isort cpplint black pytest codespell h5py pylint pip install -r requirements.txt - name: Lint with flake8 run: | @@ -42,6 +38,12 @@ jobs: - name: Lint with isort run: | isort -c . + - name: Lint with codespell + run: | + codespell + - name: Lint with pylint + run: | + pylint implicit - name: Build run: | python setup.py develop diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..ccb0b34e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,22 @@ +repos: + - repo: https://github.com/timothycrosley/isort + rev: 5.10.1 + hooks: + - id: isort + additional_dependencies: [toml] + - repo: https://github.com/python/black + rev: 21.11b1 + hooks: + - id: black + - repo: https://github.com/pycqa/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - repo: https://github.com/pycqa/pylint + rev: v2.12.1 + hooks: + - id: pylint + - repo: https://github.com/codespell-project/codespell + rev: v2.1.0 + hooks: + - id: codespell diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..9a140043 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,46 @@ +[MASTER] + +extension-pkg-whitelist=implicit.cpu._als,implicit._nearest_neighbours,implicit.gpu._cuda,implicit.cpu.bpr,implicit.cpu.topk,numpy.random.mtrand + +[MESSAGES CONTROL] +disable=fixme, + missing-function-docstring, + missing-module-docstring, + missing-class-docstring, + wrong-import-order, + wrong-import-position, + ungrouped-imports, + line-too-long, + superfluous-parens, + trailing-whitespace, + invalid-name, + import-error, + no-self-use, + + # disable code-complexity check + too-many-function-args, + too-many-instance-attributes, + too-many-locals, + too-many-branches, + too-many-nested-blocks, + too-many-statements, + too-many-arguments, + too-many-return-statements, + too-many-lines, + too-few-public-methods, + + # TODO: fix underlying errors for these + import-outside-toplevel, + not-callable, + unused-argument, + abstract-method, + arguments-differ, + no-member, + no-name-in-module, + arguments-renamed, + import-self, + +[SIMILARITIES] +min-similarity-lines=16 +ignore-docstrings=yes +ignore-imports=yes diff --git a/README.md b/README.md index a98a4600..172c7dc1 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ which can be installed with homebrew: ```brew install gcc```. Running on Windows 3.5+. GPU Support requires at least version 11 of the [NVidia CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). The build will use the ```nvcc``` compiler -that is found on the path, but this can be overriden by setting the CUDAHOME enviroment variable +that is found on the path, but this can be overridden by setting the CUDAHOME environment variable to point to your cuda installation. This library has been tested with Python 3.6, 3.7, 3.8 and 3.9 on Ubuntu, OSX and Windows. diff --git a/benchmarks/README.md b/benchmarks/README.md index e69394fb..346e0c81 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -66,4 +66,4 @@ Note that this dataset was filtered down for all versions to reviews that were p stars), to simulate a truly implicit dataset. Implicit on the CPU seems to suffer a bit here relative to the other options. It seems like there might -be a single threaded bottleneck at some point thats worth examining later. +be a single threaded bottleneck at some point that's worth examining later. diff --git a/implicit/als.py b/implicit/als.py index ced62558..dbabb93a 100644 --- a/implicit/als.py +++ b/implicit/als.py @@ -61,15 +61,14 @@ def AlternatingLeastSquares( calculate_training_loss=calculate_training_loss, random_state=random_state, ) - else: - return implicit.cpu.als.AlternatingLeastSquares( - factors, - regularization, - dtype, - use_native, - use_cg, - iterations, - calculate_training_loss, - num_threads, - random_state, - ) + return implicit.cpu.als.AlternatingLeastSquares( + factors, + regularization, + dtype, + use_native, + use_cg, + iterations, + calculate_training_loss, + num_threads, + random_state, + ) diff --git a/implicit/approximate_als.py b/implicit/approximate_als.py index eba15807..e18c207e 100644 --- a/implicit/approximate_als.py +++ b/implicit/approximate_als.py @@ -69,13 +69,13 @@ class NMSLibAlternatingLeastSquares(AlternatingLeastSquares): def __init__( self, + *args, approximate_similar_items=True, approximate_recommend=True, method="hnsw", index_params=None, query_params=None, random_state=None, - *args, **kwargs ): if index_params is None: @@ -93,9 +93,9 @@ def __init__( self.index_params = index_params self.query_params = query_params - super(NMSLibAlternatingLeastSquares, self).__init__( - *args, random_state=random_state, **kwargs - ) + self.max_norm = None + + super().__init__(*args, random_state=random_state, **kwargs) def fit(self, Cui, show_progress=True): # nmslib can be a little chatty when first imported, disable some of @@ -104,7 +104,7 @@ def fit(self, Cui, show_progress=True): import nmslib # train the model - super(NMSLibAlternatingLeastSquares, self).fit(Cui, show_progress) + super().fit(Cui, show_progress) # create index for similar_items if self.approximate_similar_items: @@ -137,7 +137,7 @@ def fit(self, Cui, show_progress=True): def similar_items(self, itemid, N=10): if not self.approximate_similar_items: - return super(NMSLibAlternatingLeastSquares, self).similar_items(itemid, N) + return super().similar_items(itemid, N) neighbours, distances = self.similar_items_index.knnQuery(self.item_factors[itemid], N) return zip(neighbours, 1.0 - distances) @@ -152,7 +152,7 @@ def recommend( recalculate_user=False, ): if not self.approximate_recommend: - return super(NMSLibAlternatingLeastSquares, self).recommend( + return super().recommend( userid, user_items, N=N, @@ -216,21 +216,20 @@ class AnnoyAlternatingLeastSquares(AlternatingLeastSquares): def __init__( self, + *args, approximate_similar_items=True, approximate_recommend=True, n_trees=50, search_k=-1, random_state=None, - *args, **kwargs ): - super(AnnoyAlternatingLeastSquares, self).__init__( - *args, random_state=random_state, **kwargs - ) + super().__init__(*args, random_state=random_state, **kwargs) self.similar_items_index = None self.recommend_index = None + self.max_norm = None self.approximate_similar_items = approximate_similar_items self.approximate_recommend = approximate_recommend @@ -243,7 +242,7 @@ def fit(self, Cui, show_progress=True): import annoy # train the model - super(AnnoyAlternatingLeastSquares, self).fit(Cui, show_progress) + super().fit(Cui, show_progress) # build up an Annoy Index with all the item_factors (for calculating # similar items) @@ -267,7 +266,7 @@ def fit(self, Cui, show_progress=True): def similar_items(self, itemid, N=10): if not self.approximate_similar_items: - return super(AnnoyAlternatingLeastSquares, self).similar_items(itemid, N) + return super().similar_items(itemid, N) neighbours, dist = self.similar_items_index.get_nns_by_item( itemid, N, search_k=self.search_k, include_distances=True @@ -285,7 +284,7 @@ def recommend( recalculate_user=False, ): if not self.approximate_recommend: - return super(AnnoyAlternatingLeastSquares, self).recommend( + return super().recommend( userid, user_items, N=N, @@ -353,18 +352,20 @@ class FaissAlternatingLeastSquares(AlternatingLeastSquares): def __init__( self, + *args, approximate_similar_items=True, approximate_recommend=True, nlist=400, nprobe=20, use_gpu=implicit.gpu.HAS_CUDA, random_state=None, - *args, **kwargs ): self.similar_items_index = None self.recommend_index = None + self.quantizer = None + self.gpu_resources = None self.approximate_similar_items = approximate_similar_items self.approximate_recommend = approximate_recommend @@ -373,15 +374,13 @@ def __init__( self.nlist = nlist self.nprobe = nprobe self.use_gpu = use_gpu - super(FaissAlternatingLeastSquares, self).__init__( - *args, random_state=random_state, **kwargs - ) + super().__init__(*args, random_state=random_state, **kwargs) def fit(self, Cui, show_progress=True): import faiss # train the model - super(FaissAlternatingLeastSquares, self).fit(Cui, show_progress) + super().fit(Cui, show_progress) self.quantizer = faiss.IndexFlat(self.factors) @@ -433,7 +432,7 @@ def fit(self, Cui, show_progress=True): def similar_items(self, itemid, N=10): if not self.approximate_similar_items or (self.use_gpu and N >= 1024): - return super(FaissAlternatingLeastSquares, self).similar_items(itemid, N) + return super().similar_items(itemid, N) factors = self.item_factors[itemid] factors /= numpy.linalg.norm(factors) @@ -452,7 +451,7 @@ def recommend( recalculate_user=False, ): if not self.approximate_recommend: - return super(FaissAlternatingLeastSquares, self).recommend( + return super().recommend( userid, user_items, N=N, @@ -474,7 +473,7 @@ def recommend( # the GPU variant of faiss doesn't support returning more than 1024 results. # fall back to the exact match when this happens if self.use_gpu and count >= 1024: - return super(FaissAlternatingLeastSquares, self).recommend( + return super().recommend( userid, user_items, N=N, diff --git a/implicit/bpr.py b/implicit/bpr.py index 80cc258c..342bee2f 100644 --- a/implicit/bpr.py +++ b/implicit/bpr.py @@ -60,14 +60,13 @@ def BayesianPersonalizedRanking( verify_negative_samples=verify_negative_samples, random_state=random_state, ) - else: - return implicit.cpu.bpr.BayesianPersonalizedRanking( - factors, - learning_rate, - regularization, - dtype=dtype, - num_threads=num_threads, - iterations=iterations, - verify_negative_samples=verify_negative_samples, - random_state=random_state, - ) + return implicit.cpu.bpr.BayesianPersonalizedRanking( + factors, + learning_rate, + regularization, + dtype=dtype, + num_threads=num_threads, + iterations=iterations, + verify_negative_samples=verify_negative_samples, + random_state=random_state, + ) diff --git a/implicit/cpu/als.py b/implicit/cpu/als.py index c2d15479..11bbbe64 100644 --- a/implicit/cpu/als.py +++ b/implicit/cpu/als.py @@ -69,7 +69,7 @@ def __init__( random_state=None, ): - super(AlternatingLeastSquares, self).__init__() + super().__init__() # parameters on how to factorize self.factors = factors @@ -260,15 +260,15 @@ def explain(self, userid, user_items, itemid, user_weights=None, N=10): total_score = 0.0 h = [] h_len = 0 - for itemid, confidence in nonzeros(user_items, userid): + for other_itemid, confidence in nonzeros(user_items, userid): if confidence < 0: continue - factor = self.item_factors[itemid] + factor = self.item_factors[other_itemid] # s_u^ij = (y_i^t W^u) y_j score = weighted_item.dot(factor) * confidence total_score += score - contribution = (score, itemid) + contribution = (score, other_itemid) if h_len < N: heapq.heappush(h, contribution) h_len += 1 @@ -384,7 +384,7 @@ def least_squares_cg(Cui, X, Y, regularization, num_threads=0, cg_steps=3): if rsold < 1e-20: continue - for it in range(cg_steps): + for _ in range(cg_steps): # calculate Ap = YtCuYp - without actually calculating YtCuY Ap = YtY.dot(p) for i, confidence in nonzeros(Cui, u): diff --git a/implicit/cpu/bpr.pyx b/implicit/cpu/bpr.pyx index fee32204..285abe9f 100644 --- a/implicit/cpu/bpr.pyx +++ b/implicit/cpu/bpr.pyx @@ -176,7 +176,7 @@ class BayesianPersonalizedRanking(MatrixFactorizationBase): # we accept num_threads = 0 as indicating to create as many threads as we have cores, # but in that case we need the number of cores, since we need to initialize RNG state per - # thread. Get the appropiate value back from openmp + # thread. Get the appropriate value back from openmp cdef int num_threads = self.num_threads if not num_threads: num_threads = multiprocessing.cpu_count() diff --git a/implicit/cpu/matrix_factorization_base.py b/implicit/cpu/matrix_factorization_base.py index d905bc2e..85724d93 100644 --- a/implicit/cpu/matrix_factorization_base.py +++ b/implicit/cpu/matrix_factorization_base.py @@ -120,18 +120,14 @@ def _user_factor(self, userid, user_items, recalculate_user=False): if recalculate_user: if np.isscalar(userid): return self.recalculate_user(userid, user_items) - else: - return np.stack([self.recalculate_user(i, user_items) for i in userid]) - + return np.stack([self.recalculate_user(i, user_items) for i in userid]) return self.user_factors[userid] def _item_factor(self, itemid, react_users, recalculate_item=False): if recalculate_item: if np.isscalar(itemid): return self.recalculate_item(itemid, react_users) - else: - return np.stack([self.recalculate_item(i, react_users) for i in itemid]) - + return np.stack([self.recalculate_item(i, react_users) for i in itemid]) return self.item_factors[itemid] def recalculate_user(self, userid, user_items): diff --git a/implicit/cpu/topk.pyx b/implicit/cpu/topk.pyx index 7b0442ff..98725792 100644 --- a/implicit/cpu/topk.pyx +++ b/implicit/cpu/topk.pyx @@ -20,7 +20,7 @@ def topk(items, query, int k, item_norms=None, filter_query_items=None, filter_i indices = np.zeros((query_rows, k), dtype="int32") distances = np.zeros((query_rows, k), dtype=query.dtype) - # TODO: figure out appropiate batch size from available memory + # TODO: figure out appropriate batch size from available memory cdef int batch_size = 100 # TODO cdef int batches = (query_rows / batch_size) diff --git a/implicit/datasets/lastfm.py b/implicit/datasets/lastfm.py index c9cc567f..4ad08a3a 100644 --- a/implicit/datasets/lastfm.py +++ b/implicit/datasets/lastfm.py @@ -36,7 +36,7 @@ def generate_dataset(filename, outputfilename): http://ocelma.net/MusicRecommendationDataset/lastfm-360K.html You shouldn't have to run this yourself, and can instead just download the - output using the 'get_lastfm' funciton./ + output using the 'get_lastfm' function./ Note there are some invalid entries in this dataset, running this function will clean it up so pandas can read it: diff --git a/implicit/datasets/million_song_dataset.py b/implicit/datasets/million_song_dataset.py index aecfdf9f..e0dc0493 100644 --- a/implicit/datasets/million_song_dataset.py +++ b/implicit/datasets/million_song_dataset.py @@ -54,7 +54,7 @@ def generate_dataset( https://labrosa.ee.columbia.edu/millionsong/pages/getting-dataset You shouldn't have to run this yourself, and can instead just download the - output using the 'get_msd_taste_profile' funciton + output using the 'get_msd_taste_profile' function """ data = _read_triplets_dataframe(triplets_filename) track_info = _join_summary_file(data, summary_filename) diff --git a/implicit/datasets/movielens.py b/implicit/datasets/movielens.py index a5cff616..26688659 100644 --- a/implicit/datasets/movielens.py +++ b/implicit/datasets/movielens.py @@ -30,7 +30,7 @@ def get_movielens(variant="20m"): A sparse matrix where the row is the movieId, the column is the userId and the value is the rating. """ - filename = "movielens_%s.hdf5" % variant + filename = f"movielens_{variant}.hdf5" path = os.path.join(_download.LOCAL_CACHE_DIR, filename) if not os.path.isfile(path): @@ -50,9 +50,9 @@ def generate_dataset(path, variant="20m", outputpath="."): https://grouplens.org/datasets/movielens/20m/ You shouldn't have to run this yourself, and can instead just download the - output using the 'get_movielens' funciton./ + output using the 'get_movielens' function./ """ - filename = os.path.join(outputpath, "movielens_%s.hdf5" % variant) + filename = os.path.join(outputpath, f"movielens_{variant}.hdf5") if variant == "20m": ratings, movies = _read_dataframes_20M(path) diff --git a/implicit/datasets/reddit.py b/implicit/datasets/reddit.py index 297026c3..f3483c0a 100644 --- a/implicit/datasets/reddit.py +++ b/implicit/datasets/reddit.py @@ -40,7 +40,7 @@ def generate_dataset(filename, outputfilename): https://www.reddit.com/r/redditdev/comments/dtg4j/want_to_help_reddit_build_a_recommender_a_public/ You shouldn't have to run this yourself, and can instead just download the - output using the 'get_reddit' funciton. + output using the 'get_reddit' function. """ data = _read_dataframe(filename) _hfd5_from_dataframe(data, outputfilename) diff --git a/implicit/gpu/als.py b/implicit/gpu/als.py index 0f380ead..b11df189 100644 --- a/implicit/gpu/als.py +++ b/implicit/gpu/als.py @@ -54,7 +54,7 @@ def __init__( if not implicit.gpu.HAS_CUDA: raise ValueError("No CUDA extension has been built, can't train on GPU.") - super(AlternatingLeastSquares, self).__init__() + super().__init__() # parameters on how to factorize self.factors = factors diff --git a/implicit/gpu/bpr.py b/implicit/gpu/bpr.py index 5e6a7322..9ab97451 100644 --- a/implicit/gpu/bpr.py +++ b/implicit/gpu/bpr.py @@ -54,7 +54,7 @@ def __init__( verify_negative_samples=True, random_state=None, ): - super(BayesianPersonalizedRanking, self).__init__() + super().__init__() if not implicit.gpu.HAS_CUDA: raise ValueError("No CUDA extension has been built, can't train on GPU.") @@ -142,10 +142,10 @@ def fit(self, user_items, show_progress=True): ) progress.update(1) total = len(user_items.data) - if total != 0 and total != skipped: + if total and total != skipped: progress.set_postfix( { - "train_auc": "%.2f%%" % (100.0 * correct / (total - skipped)), - "skipped": "%.2f%%" % (100.0 * skipped / total), + "train_auc": f"{100.0 * correct / (total - skipped):0.2f}%", + "skipped": f"{100.0 * skipped / total:0.2f}%", } ) diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index f895425c..500c2cd4 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -26,6 +26,8 @@ def __init__(self): self.user_factors = None self._item_norms = None self._user_norms = None + self._user_norms_host = None + self._item_norms_host = None self._knn = implicit.gpu.KnnQuery() def recommend( diff --git a/implicit/gpu/utils.cuh b/implicit/gpu/utils.cuh index e71d64f8..5adadad3 100644 --- a/implicit/gpu/utils.cuh +++ b/implicit/gpu/utils.cuh @@ -86,7 +86,7 @@ float dot(float a, float b, float * shared) { float val = a * b ; val = warp_reduce_sum(val); - // write out the partial reduction to shared memory if appropiate + // write out the partial reduction to shared memory if appropriate if (lane == 0) { shared[warp] = val; } diff --git a/implicit/lmf.pyx b/implicit/lmf.pyx index 51022243..279eceec 100644 --- a/implicit/lmf.pyx +++ b/implicit/lmf.pyx @@ -113,7 +113,7 @@ class LogisticMatrixFactorization(MatrixFactorizationBase): # TODO: Add GPU training if self.use_gpu: - raise NotImplementedError("GPU version of LMF is not implemeneted yet!") + raise NotImplementedError("GPU version of LMF is not implemented yet!") @cython.cdivision(True) @cython.boundscheck(False) diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 1a9e8b7d..482c6ccb 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -168,7 +168,7 @@ class BM25Recommender(ItemItemRecommender): """An Item-Item Recommender on BM25 distance between items""" def __init__(self, K=20, K1=1.2, B=0.75, num_threads=0): - super(BM25Recommender, self).__init__(K, num_threads) + super().__init__(K, num_threads) self.K1 = K1 self.B = B diff --git a/implicit/recommender_base.py b/implicit/recommender_base.py index 06b3ef82..c9868f88 100644 --- a/implicit/recommender_base.py +++ b/implicit/recommender_base.py @@ -7,7 +7,7 @@ class ModelFitError(Exception): pass -class RecommenderBase(object): +class RecommenderBase: """Defines the interface that all recommendations models here expose""" __metaclass__ = ABCMeta @@ -54,7 +54,7 @@ def recommend( The number of results to return filter_already_liked_items: bool, optional When true, don't return items present in the training set that were rated - by the specificed user. + by the specified user. filter_items : sequence of ints, optional List of extra item ids to filter out from the output recalculate_user : bool, optional @@ -93,7 +93,6 @@ def similar_users(self, userid, N=10): @abstractmethod def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): - """ Calculates a list of similar items diff --git a/implicit/utils.py b/implicit/utils.py index 8dedc94f..5c584fb4 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -17,7 +17,7 @@ def check_blas_config(): """checks to see if using OpenBlas/Intel MKL. If so, warn if the number of threads isn't set to 1 (causes severe perf issues when training - can be 10x slower)""" # don't warn repeatedly - global _checked_blas_config + global _checked_blas_config # pylint: disable=global-statement if _checked_blas_config: return _checked_blas_config = True diff --git a/setup.cfg b/setup.cfg index 64b9d377..4bb5ac10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,6 +22,11 @@ known_third_party = scipy,annoy,numpy,cython,pandas line_length = 100 skip = build,.eggs,.tox +[codespell] +skip = ./.git,./.github,./build,./dist,./docs/build,.*egg-info.*,*.csv,*.tsv +ignore-words-list = als,coo,nd,unparseable,compiletime + + [bumpversion:file:implicit/__init__.py] [bumpversion:file:setup.py] diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index 8c431d6f..29fb9ced 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -38,7 +38,7 @@ def test_recommend(self): model.fit(item_users, show_progress=False) for userid in range(50): - ids, scores = model.recommend(userid, user_items, N=1) + ids, _ = model.recommend(userid, user_items, N=1) self.assertEqual(len(ids), 1) # the top item recommended should be the same as the userid: @@ -49,12 +49,12 @@ def test_recommend(self): # try asking for more items than possible, # should return only the available items # https://github.com/benfred/implicit/issues/22 - ids, scores = model.recommend(0, user_items, N=10000) + ids, _ = model.recommend(0, user_items, N=10000) self.assertTrue(len(ids)) # filter recommended items using an additional filter list # https://github.com/benfred/implicit/issues/26 - ids, scores = model.recommend(0, user_items, N=1, filter_items=[0]) + ids, _ = model.recommend(0, user_items, N=1, filter_items=[0]) self.assertTrue(0 not in set(ids)) def test_recommend_batch(self): @@ -95,7 +95,7 @@ def test_recalculate_user(self): np.arange(50), user_items, N=1, recalculate_user=True ) except NotImplementedError: - # some models don't support recalculating user on the fly, and thats ok + # some models don't support recalculating user on the fly, and that's ok return for userid in range(item_users.shape[1]): @@ -195,7 +195,7 @@ def check_results(ids): ) check_results(ids) except NotImplementedError: - # some models don't support recalculating user on the fly, and thats ok + # some models don't support recalculating user on the fly, and that's ok pass def test_zero_length_row(self): @@ -236,7 +236,7 @@ def test_rank_items(self): for userid in range(50): selected_items = random.sample(range(50), 10) - ids, scores = model.recommend( + ids, _ = model.recommend( userid, user_items, items=selected_items, filter_already_liked_items=False ) @@ -262,7 +262,7 @@ def test_rank_items_batch(self): model.fit(item_users, show_progress=False) selected_items = np.arange(10) * 3 - ids, scores = model.recommend(np.arange(50), user_items, items=selected_items) + ids, _ = model.recommend(np.arange(50), user_items, items=selected_items) for userid in range(50): current_ids = ids[userid] From 34327b8757a15d100572c275e6b37b657a1bf48f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 1 Dec 2021 21:41:09 -0800 Subject: [PATCH 5/9] Add filtering options for similar_users and similar_items (#496) --- implicit/cpu/matrix_factorization_base.py | 59 +++++++++++++++++++---- implicit/gpu/knn.cu | 8 ++- implicit/gpu/matrix_factorization_base.py | 57 ++++++++++++++++++++-- implicit/nearest_neighbours.py | 36 +++++++++++--- implicit/recommender_base.py | 22 +++++++-- tests/recommender_base_test.py | 37 +++++++++++++- 6 files changed, 190 insertions(+), 29 deletions(-) diff --git a/implicit/cpu/matrix_factorization_base.py b/implicit/cpu/matrix_factorization_base.py index 85724d93..d8152867 100644 --- a/implicit/cpu/matrix_factorization_base.py +++ b/implicit/cpu/matrix_factorization_base.py @@ -53,7 +53,7 @@ def recommend( # check selected items are in the model if items.max() >= self.item_factors.shape[0] or items.min() < 0: - raise IndexError("Some itemids are not in the model") + raise IndexError("Some itemids in the items parameter in are not in the model") # get a CSR matrix of items to filter per-user filter_query_items = None @@ -136,19 +136,42 @@ def recalculate_user(self, userid, user_items): def recalculate_item(self, itemid, react_users): raise NotImplementedError("recalculate_item is not supported with this model") - def similar_users(self, userid, N=10): - factor = self.user_factors[userid] - factors = self.user_factors + def similar_users(self, userid, N=10, filter_users=None, users=None): + user_factors = self.user_factors norms = self.user_norms norm = norms[userid] - return self._get_similarity_score(factor, norm, factors, norms, N) + + # if we have an user list to restrict down to, we need to filter the user_factors + if users is not None: + if filter_users: + raise ValueError("Can't set both users and filter_users in similar_users call") + + users = np.array(users) + user_factors = user_factors[users] + norms = norms[users] + + # check selected items are in the model + if users.max() >= self.user_factors.shape[0] or users.min() < 0: + raise IndexError("Some userids in the users parameter are not in the model") + + factor = self.user_factors[userid] + ids, scores = self._get_similarity_score( + factor, norm, user_factors, norms, N, filter_items=filter_users + ) + if users is not None: + ids = users[ids] + + return ids, scores similar_users.__doc__ = RecommenderBase.similar_users.__doc__ - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): factor = self._item_factor(itemid, react_users, recalculate_item) factors = self.item_factors norms = self.item_norms + if recalculate_item: if np.isscalar(itemid): norm = np.linalg.norm(factor) @@ -159,12 +182,30 @@ def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): else: norm = norms[itemid] - return self._get_similarity_score(factor, norm, factors, norms, N) + # if we have an item list to restrict down to, we need to filter the item_factors + if items is not None: + if filter_items: + raise ValueError("Can't set both items and filter_items in similar_items call") + + items = np.array(items) + factors = factors[items] + norms = norms[items] + + # check selected items are in the model + if items.max() >= self.item_factors.shape[0] or items.min() < 0: + raise IndexError("Some itemids in the items parameter are not in the model") + + ids, scores = self._get_similarity_score( + factor, norm, factors, norms, N, filter_items=filter_items + ) + if items is not None: + ids = items[ids] + return ids, scores similar_items.__doc__ = RecommenderBase.similar_items.__doc__ - def _get_similarity_score(self, factor, norm, factors, norms, N): - ids, scores = topk(factors, factor, N, item_norms=norms) + def _get_similarity_score(self, factor, norm, factors, norms, N, filter_items=None): + ids, scores = topk(factors, factor, N, item_norms=norms, filter_items=filter_items) if np.isscalar(norm): ids, scores = ids[0], scores[0] scores /= norm diff --git a/implicit/gpu/knn.cu b/implicit/gpu/knn.cu index e48fa2d3..5c3d7608 100644 --- a/implicit/gpu/knn.cu +++ b/implicit/gpu/knn.cu @@ -162,9 +162,13 @@ void KnnQuery::topk(const Matrix & items, const Matrix & query, int k, auto count = thrust::make_counting_iterator(0); float * data = temp_distances.data; int * items = item_filter->data; - thrust::for_each(count, count + item_filter->size, + int items_size = item_filter->size; + int cols = temp_distances.cols; + thrust::for_each(count, count + items_size * temp_distances.rows, [=] __device__(int i) { - data[items[i]] = -FLT_MAX; + int col = items[i % items_size]; + int row = i / items_size; + data[row * cols + col] = -FLT_MAX; }); } diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index 500c2cd4..4dfef10c 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -56,7 +56,7 @@ def recommend( if items.max() >= self.item_factors.shape[0] or items.min() < 0: raise IndexError("Some itemids are not in the model") - if filter_items: + if filter_items is not None: filter_items = implicit.gpu.IntVector(np.array(filter_items, dtype="int32")) query_filter = None @@ -105,11 +105,33 @@ def item_norms(self): self._item_norms_host = self._item_norms.to_numpy().reshape(self._item_norms.shape[1]) return self._item_norms - def similar_users(self, userid, N=10): + def similar_users(self, userid, N=10, filter_users=None, users=None): + norms = self.user_norms + user_factors = self.user_factors + if users is not None: + if filter_users: + raise ValueError("Can't set both users and filter_users in similar_users call") + + users = np.array(users) + user_factors = user_factors[users] + + # TODO: we should be able to do this all on the GPU + norms = implicit.gpu.Matrix(self._user_norms_host[users].reshape(1, len(users))) + + # check selected items are in the model + if users.max() >= self.user_factors.shape[0] or users.min() < 0: + raise IndexError("Some userids in the users parameter are not in the model") + + if filter_users is not None: + filter_users = implicit.gpu.IntVector(np.array(filter_users, dtype="int32")) + ids, scores = self._knn.topk( - self.user_factors, self.user_factors[userid], N, self.user_norms + user_factors, self.user_factors[userid], N, norms, item_filter=filter_users ) + if users is not None: + ids = users[ids] + user_norms = self._user_norms_host[userid] if np.isscalar(userid): ids, scores = ids[0], scores[0] @@ -120,13 +142,38 @@ def similar_users(self, userid, N=10): similar_users.__doc__ = RecommenderBase.similar_users.__doc__ - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): if recalculate_item: raise NotImplementedError("recalculate_item isn't support on GPU yet") + + item_factors = self.item_factors + norms = self.item_norms + if items is not None: + if filter_items: + raise ValueError("Can't set both items and filter_items in similar_items call") + + items = np.array(items) + + # TODO: we should be able to do this all on the GPU + norms = implicit.gpu.Matrix(self._item_norms_host[items].reshape(1, len(items))) + item_factors = item_factors[items] + + # check selected items are in the model + if items.max() >= self.item_factors.shape[0] or items.min() < 0: + raise IndexError("Some itemids are not in the model") + + if filter_items is not None: + filter_items = implicit.gpu.IntVector(np.array(filter_items, dtype="int32")) + ids, scores = self._knn.topk( - self.item_factors, self.item_factors[itemid], N, self.item_norms + item_factors, self.item_factors[itemid], N, norms, item_filter=filter_items ) + if items is not None: + ids = items[ids] + item_norms = self._item_norms_host[itemid] if np.isscalar(itemid): ids, scores = ids[0], scores[0] diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 482c6ccb..5556c602 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -59,7 +59,7 @@ def recommend( if userid >= user_items.shape[0]: raise ValueError("userid is out of bounds of the user_items matrix") - if filter_items and items: + if filter_items is not None and items is not None: raise ValueError("Can't specify both filter_items and items") if filter_items is not None: @@ -96,22 +96,43 @@ def recommend( return ids, scores - def similar_users(self, userid, N=10): - raise NotImplementedError("Not implemented Yet") + def similar_users(self, userid, N=10, filter_users=None, users=None): + raise NotImplementedError("similar_users isn't implemented for item-item recommenders") - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): """Returns a list of the most similar other items""" if recalculate_item: raise NotImplementedError("Recalculate_item isn't implemented") + print("N", N) if not np.isscalar(itemid): - return _batch(self.similar_items, itemid, N=N) + return _batch(self.similar_items, itemid, N=N, filter_items=filter_items, items=items) + + if filter_items is not None and items is not None: + raise ValueError("Can't specify both filter_items and items") if itemid >= self.similarity.shape[0]: return np.array([]), np.array([]) ids = self.similarity[itemid].indices scores = self.similarity[itemid].data + + if filter_items is not None: + mask = np.in1d(ids, filter_items, invert=True) + ids, scores = ids[mask], scores[mask] + + elif items is not None: + mask = np.in1d(ids, items) + ids, scores = ids[mask], scores[mask] + + # returned items should be equal to input selected items + missing = items[np.in1d(items, ids, invert=True)] + if missing.size: + ids = np.append(ids, missing) + scores = np.append(scores, np.full(missing.size, -np.finfo(scores.dtype).max)) + best = np.argsort(scores)[::-1][:N] return ids[best], scores[best] @@ -226,8 +247,9 @@ def _batch(func, ids, *args, N=10, **kwargs): batch_ids, batch_scores = func(idx, *args, N=N, **kwargs) # pad out to N items if we're returned fewer - missing_items = len(batch_ids) - N - if missing_items: + missing_items = N - len(batch_ids) + print("i", i, "idx", idx, " missing ", missing_items) + if missing_items > 0: batch_ids = np.append(batch_ids, np.full(missing_items, -1)) batch_scores = np.append( batch_scores, np.full(missing_items, -np.finfo(np.float32).max) diff --git a/implicit/recommender_base.py b/implicit/recommender_base.py index c9868f88..099941ee 100644 --- a/implicit/recommender_base.py +++ b/implicit/recommender_base.py @@ -74,25 +74,32 @@ def recommend( """ @abstractmethod - def similar_users(self, userid, N=10): + def similar_users(self, userid, N=10, filter_users=None, users=None): """ - Calculates a list of similar users + Calculates the most similar users for a userid or array of userids Parameters ---------- userid : Union[int, array_like] - The userid or an array of userids to retrieve similar users for + The userid or an array of userids to retrieve similar users for. N : int, optional The number of similar users to return + filter_users: array_like, optional + An array of user ids to filter out from the results being returned + users: array_like, optional + An array of user ids to include in the output. If not set all users in the training + set will be included. Cannot be used with the filter_users options Returns ------- tuple - Tuple of (itemids, scores) arrays + Tuple of (userids, scores) arrays """ @abstractmethod - def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): """ Calculates a list of similar items @@ -108,6 +115,11 @@ def similar_items(self, itemid, N=10, react_users=None, recalculate_item=False): recalculate_item : bool, optional When true, don't rely on stored item state and instead recalculate from the passed in react_users + filter_items: array_like, optional + An array of item ids to filter out from the results being returned + items: array_like, optional + An array of item ids to include in the output. If not set all items in the training + set will be included. Cannot be used with the filter_items options Returns ------- diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index 29fb9ced..aeb4b9a2 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -23,7 +23,7 @@ def get_checker_board(X): return csr_matrix(ret - np.eye(X)) -class RecommenderBaseTestMixin(object): +class RecommenderBaseTestMixin: """Mixin to test a bunch of common functionality in models deriving from RecommenderBase""" @@ -163,6 +163,25 @@ def test_similar_users_batch(self): for r in ids[userid]: self.assertEqual(r % 2, userid % 2) + def test_similar_users_filter(self): + model = self._get_model() + # calculating similar users in nearest-neighbours is not implemented yet + if isinstance(model, ItemItemRecommender): + return + + model.fit(get_checker_board(256), show_progress=False) + userids = np.arange(50) + + ids, _ = model.similar_users(userids, N=10, filter_users=np.arange(52) * 5) + for userid in userids: + for r in ids[userid]: + self.assertTrue(r % 5 != 0) + + selected = np.arange(10) + ids, _ = model.similar_users(userids, N=10, users=selected) + for userid in userids: + self.assertEqual(set(ids[userid]), set(selected)) + def test_similar_items(self): model = self._get_model() model.fit(get_checker_board(256), show_progress=False) @@ -198,6 +217,22 @@ def check_results(ids): # some models don't support recalculating user on the fly, and that's ok pass + def test_similar_items_filter(self): + model = self._get_model() + + model.fit(get_checker_board(256), show_progress=False) + itemids = np.arange(50) + + ids, _ = model.similar_items(itemids, N=10, filter_items=np.arange(52) * 5) + for itemid in itemids: + for r in ids[itemid]: + self.assertTrue(r % 5 != 0) + + selected = np.arange(10) + ids, _ = model.similar_items(itemids, N=10, items=selected) + for itemid in itemids: + self.assertEqual(set(ids[itemid]), set(selected)) + def test_zero_length_row(self): # get a matrix where a row/column is 0 item_users = get_checker_board(50).todense() From 8f5c20b9c095debd2c7514c68e5217f48ac51bae Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 30 Dec 2021 19:38:43 -0800 Subject: [PATCH 6/9] Fix ANN models to work with latest API changes (#501) --- .pre-commit-config.yaml | 2 +- implicit/approximate_als.py | 311 ++++++++++++++++++++++++++------- implicit/nearest_neighbours.py | 38 +--- implicit/utils.py | 23 +++ tests/approximate_als_test.py | 38 +++- tests/recommender_base_test.py | 31 +++- 6 files changed, 338 insertions(+), 105 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ccb0b34e..a21f8009 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: isort additional_dependencies: [toml] - repo: https://github.com/python/black - rev: 21.11b1 + rev: 21.12b0 hooks: - id: black - repo: https://github.com/pycqa/flake8 diff --git a/implicit/approximate_als.py b/implicit/approximate_als.py index e18c207e..f185296d 100644 --- a/implicit/approximate_als.py +++ b/implicit/approximate_als.py @@ -3,14 +3,15 @@ See http://www.benfrederickson.com/approximate-nearest-neighbours-for-recommender-systems/ """ -import itertools import logging -import numpy +import numpy as np import implicit.gpu from implicit.cpu.als import AlternatingLeastSquares +from .utils import _batch_call + log = logging.getLogger("implicit") @@ -25,13 +26,13 @@ def augment_inner_product_matrix(factors): Basically this involves transforming each feature vector so that they have the same norm, which means the cosine of this transformed vector is proportional to the dot product (if the other vector in the cosine has a 0 in the extra dimension).""" - norms = numpy.linalg.norm(factors, axis=1) + norms = np.linalg.norm(factors, axis=1) max_norm = norms.max() # add an extra dimension so that the norm of each row is the same # (max_norm) - extra_dimension = numpy.sqrt(max_norm ** 2 - norms ** 2) - return max_norm, numpy.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1) + extra_dimension = np.sqrt(max_norm ** 2 - norms ** 2) + return max_norm, np.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1) class NMSLibAlternatingLeastSquares(AlternatingLeastSquares): @@ -114,11 +115,11 @@ def fit(self, Cui, show_progress=True): # there are some numerical instability issues here with # building a cosine index with vectors with 0 norms, hack around this # by just not indexing them - norms = numpy.linalg.norm(self.item_factors, axis=1) - ids = numpy.arange(self.item_factors.shape[0]) + norms = np.linalg.norm(self.item_factors, axis=1) + ids = np.arange(self.item_factors.shape[0]) # delete zero valued rows from the matrix - item_factors = numpy.delete(self.item_factors, ids[norms == 0], axis=0) + item_factors = np.delete(self.item_factors, ids[norms == 0], axis=0) ids = ids[norms != 0] self.similar_items_index.addDataPointBatch(item_factors, ids=ids) @@ -135,12 +136,39 @@ def fit(self, Cui, show_progress=True): self.recommend_index.createIndex(self.index_params, print_progress=show_progress) self.recommend_index.setQueryTimeParams(self.query_params) - def similar_items(self, itemid, N=10): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): if not self.approximate_similar_items: - return super().similar_items(itemid, N) + return super().similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) + + if items is not None: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") - neighbours, distances = self.similar_items_index.knnQuery(self.item_factors[itemid], N) - return zip(neighbours, 1.0 - distances) + factors = self._item_factor(itemid, react_users, recalculate_item) + count = N + if filter_items is not None: + count += len(filter_items) + + if np.isscalar(itemid): + ids, scores = self.similar_items_index.knnQuery(factors, count) + else: + results = self.similar_items_index.knnQueryBatch(factors, count) + ids = np.stack([result[0] for result in results]) + scores = np.stack([result[1] for result in results]) + + scores = 1.0 - scores + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, scores def recommend( self, @@ -150,35 +178,60 @@ def recommend( filter_already_liked_items=True, filter_items=None, recalculate_user=False, + items=None, ): + if items and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + if not self.approximate_recommend: return super().recommend( userid, user_items, N=N, + filter_already_liked_items=filter_already_liked_items, filter_items=filter_items, recalculate_user=recalculate_user, + items=items, + ) + + # batch computation is hard here, fallback to looping over items + if not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, ) user = self._user_factor(userid, user_items, recalculate_user) # calculate the top N items, removing the users own liked items from # the results - liked = set() - if filter_already_liked_items: - liked.update(user_items[userid].indices) + count = N if filter_items: - liked.update(filter_items) - count = N + len(liked) + count += len(filter_items) + filter_items = np.array(filter_items) + + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) - query = numpy.append(user, 0) - ids, dist = self.recommend_index.knnQuery(query, count) + query = np.append(user, 0) + ids, scores = self.recommend_index.knnQuery(query, count) + scaling = self.max_norm * np.linalg.norm(query) + scores = scaling * (1.0 - (scores)) - # convert the distances from euclidean to cosine distance, - # and then rescale the cosine distance to go back to inner product - scaling = self.max_norm * numpy.linalg.norm(query) - dist = scaling * (1.0 - dist) - return list(itertools.islice((rec for rec in zip(ids, dist) if rec[0] not in liked), N)) + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) + + return ids, scores class AnnoyAlternatingLeastSquares(AlternatingLeastSquares): @@ -264,15 +317,48 @@ def fit(self, Cui, show_progress=True): self.recommend_index.add_item(i, row) self.recommend_index.build(self.n_trees) - def similar_items(self, itemid, N=10): + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): + if items is not None and self.approximate_similar_items: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") + + count = N + if filter_items is not None: + count += len(filter_items) + if not self.approximate_similar_items: - return super().similar_items(itemid, N) + return super().similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) - neighbours, dist = self.similar_items_index.get_nns_by_item( - itemid, N, search_k=self.search_k, include_distances=True + # annoy doesn't have a batch mode we can use + if not np.isscalar(itemid): + return _batch_call( + self.similar_items, + itemid, + N=N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + ) + + factor = self._item_factor(itemid, react_users, recalculate_item) + + ids, scores = self.similar_items_index.get_nns_by_vector( + factor, N, search_k=self.search_k, include_distances=True ) - # transform distances back to cosine from euclidean distance - return zip(neighbours, 1 - (numpy.array(dist) ** 2) / 2) + ids, scores = np.array(ids), np.array(scores) + + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, 1 - (scores ** 2) / 2 def recommend( self, @@ -282,37 +368,64 @@ def recommend( filter_already_liked_items=True, filter_items=None, recalculate_user=False, + items=None, ): + if items and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + if not self.approximate_recommend: return super().recommend( userid, user_items, N=N, + filter_already_liked_items=filter_already_liked_items, filter_items=filter_items, recalculate_user=recalculate_user, + items=items, ) + # batch computation isn't supported by annoy, fallback to looping over items + if not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) user = self._user_factor(userid, user_items, recalculate_user) # calculate the top N items, removing the users own liked items from # the results - liked = set() - if filter_already_liked_items: - liked.update(user_items[userid].indices) + count = N if filter_items: - liked.update(filter_items) - count = N + len(liked) + count += len(filter_items) + filter_items = np.array(filter_items) - query = numpy.append(user, 0) - ids, dist = self.recommend_index.get_nns_by_vector( + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) + + query = np.append(user, 0) + ids, scores = self.recommend_index.get_nns_by_vector( query, count, include_distances=True, search_k=self.search_k ) + ids, scores = np.array(ids), np.array(scores) + + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) # convert the distances from euclidean to cosine distance, # and then rescale the cosine distance to go back to inner product - scaling = self.max_norm * numpy.linalg.norm(query) - dist = scaling * (1 - (numpy.array(dist) ** 2) / 2) - return list(itertools.islice((rec for rec in zip(ids, dist) if rec[0] not in liked), N)) + scaling = self.max_norm * np.linalg.norm(query) + scores = scaling * (1 - (scores ** 2) / 2) + return ids, scores class FaissAlternatingLeastSquares(AlternatingLeastSquares): @@ -412,7 +525,7 @@ def fit(self, Cui, show_progress=True): # likewise build up cosine index for similar_items, using an inner product # index on normalized vectors` - norms = numpy.linalg.norm(item_factors, axis=1) + norms = np.linalg.norm(item_factors, axis=1) norms[norms == 0] = 1e-10 normalized = (item_factors.T / norms).T.astype("float32") @@ -430,16 +543,43 @@ def fit(self, Cui, show_progress=True): index.nprobe = self.nprobe self.similar_items_index = index - def similar_items(self, itemid, N=10): - if not self.approximate_similar_items or (self.use_gpu and N >= 1024): - return super().similar_items(itemid, N) + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): + if items is not None and self.approximate_similar_items: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") + + count = N + if filter_items is not None: + count += len(filter_items) + + if not self.approximate_similar_items or (self.use_gpu and count >= 1024): + return super().similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) - factors = self.item_factors[itemid] - factors /= numpy.linalg.norm(factors) - (dist,), (ids,) = self.similar_items_index.search( - factors.reshape(1, -1).astype("float32"), N - ) - return zip(ids, dist) + factors = self._item_factor(itemid, react_users, recalculate_item) + + if np.isscalar(itemid): + factors /= np.linalg.norm(factors) + factors = factors.reshape(1, -1) + else: + factors /= np.linalg.norm(factors, axis=1)[:, None] + + scores, ids = self.similar_items_index.search(factors.astype("float32"), count) + + if np.isscalar(itemid): + ids, scores = ids[0], scores[0] + + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, scores def recommend( self, @@ -449,26 +589,51 @@ def recommend( filter_already_liked_items=True, filter_items=None, recalculate_user=False, + items=None, ): + if items and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + if not self.approximate_recommend: return super().recommend( userid, user_items, N=N, + filter_already_liked_items=filter_already_liked_items, filter_items=filter_items, recalculate_user=recalculate_user, + items=items, + ) + + # batch computation is tricky with filter_already_liked_items (requires querying a + # different number of rows per user). Instead just fallback to a faiss query per user + if filter_already_liked_items and not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, ) user = self._user_factor(userid, user_items, recalculate_user) # calculate the top N items, removing the users own liked items from # the results - liked = set() - if filter_already_liked_items: - liked.update(user_items[userid].indices) + count = N if filter_items: - liked.update(filter_items) - count = N + len(liked) + count += len(filter_items) + filter_items = np.array(filter_items) + + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) # the GPU variant of faiss doesn't support returning more than 1024 results. # fall back to the exact match when this happens @@ -481,11 +646,33 @@ def recommend( recalculate_user=recalculate_user, ) - # faiss expects multiple queries - convert query to a matrix - # and results back to single vectors - query = user.reshape(1, -1).astype("float32") - (dist,), (ids,) = self.recommend_index.search(query, count) + if np.isscalar(userid): + query = user.reshape(1, -1).astype("float32") + else: + query = user.astype("float32") - # convert the distances from euclidean to cosine distance, - # and then rescale the cosine distance to go back to inner product - return list(itertools.islice((rec for rec in zip(ids, dist) if rec[0] not in liked), N)) + scores, ids = self.recommend_index.search(query, count) + + if np.isscalar(userid): + ids, scores = ids[0], scores[0] + + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) + + return ids, scores + + +def _filter_items_from_results(queryid, ids, scores, filter_items, N): + if np.isscalar(queryid): + mask = np.in1d(ids, filter_items, invert=True) + ids, scores = ids[mask][:N], scores[mask][:N] + else: + rows = len(queryid) + filtered_scores = np.zeros((rows, N), dtype=scores.dtype) + filtered_ids = np.zeros((rows, N), dtype=ids.dtype) + for row in range(rows): + mask = np.in1d(ids[row], filter_items, invert=True) + filtered_ids[row] = ids[row][mask][:N] + filtered_scores[row] = scores[row][mask][:N] + ids, scores = filtered_ids, filtered_scores + return ids, scores diff --git a/implicit/nearest_neighbours.py b/implicit/nearest_neighbours.py index 5556c602..8542af8f 100644 --- a/implicit/nearest_neighbours.py +++ b/implicit/nearest_neighbours.py @@ -4,6 +4,7 @@ from ._nearest_neighbours import NearestNeighboursScorer, all_pairs_knn from .recommender_base import RecommenderBase +from .utils import _batch_call class ItemItemRecommender(RecommenderBase): @@ -45,7 +46,7 @@ def recommend( ): """returns the best N recommendations for a user given its id""" if not np.isscalar(userid): - return _batch( + return _batch_call( self.recommend, userid, user_items=user_items, @@ -106,9 +107,10 @@ def similar_items( if recalculate_item: raise NotImplementedError("Recalculate_item isn't implemented") - print("N", N) if not np.isscalar(itemid): - return _batch(self.similar_items, itemid, N=N, filter_items=filter_items, items=items) + return _batch_call( + self.similar_items, itemid, N=N, filter_items=filter_items, items=items + ) if filter_items is not None and items is not None: raise ValueError("Can't specify both filter_items and items") @@ -174,14 +176,14 @@ class CosineRecommender(ItemItemRecommender): def fit(self, counts, show_progress=True): # cosine distance is just the dot-product of a normalized matrix - ItemItemRecommender.fit(self, normalize(counts), show_progress) + ItemItemRecommender.fit(self, normalize(counts.T).T, show_progress) class TFIDFRecommender(ItemItemRecommender): """An Item-Item Recommender on TF-IDF distances between items""" def fit(self, counts, show_progress=True): - weighted = normalize(tfidf_weight(counts)) + weighted = normalize(tfidf_weight(counts.T)).T ItemItemRecommender.fit(self, weighted, show_progress) @@ -194,7 +196,7 @@ def __init__(self, K=20, K1=1.2, B=0.75, num_threads=0): self.B = B def fit(self, counts, show_progress=True): - weighted = bm25_weight(counts, self.K1, self.B) + weighted = bm25_weight(counts.T, self.K1, self.B).T ItemItemRecommender.fit(self, weighted, show_progress) @@ -235,27 +237,3 @@ def bm25_weight(X, K1=100, B=0.8): # weight matrix rows by bm25 X.data = X.data * (K1 + 1.0) / (K1 * length_norm[X.row] + X.data) * idf[X.col] return X - - -def _batch(func, ids, *args, N=10, **kwargs): - # we're running in batch mode, just loop over each item and call the scalar version of the - # function - output_ids = np.zeros((len(ids), N), dtype=np.int32) - output_scores = np.zeros((len(ids), N), dtype=np.float32) - - for i, idx in enumerate(ids): - batch_ids, batch_scores = func(idx, *args, N=N, **kwargs) - - # pad out to N items if we're returned fewer - missing_items = N - len(batch_ids) - print("i", i, "idx", idx, " missing ", missing_items) - if missing_items > 0: - batch_ids = np.append(batch_ids, np.full(missing_items, -1)) - batch_scores = np.append( - batch_scores, np.full(missing_items, -np.finfo(np.float32).max) - ) - - output_ids[i] = batch_ids[:N] - output_scores[i] = batch_scores[:N] - - return output_ids, output_scores diff --git a/implicit/utils.py b/implicit/utils.py index 5c584fb4..a28c6d96 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -53,3 +53,26 @@ def check_random_state(random_state): # otherwise try to initialize a new one, and let it fail through # on the numpy side if it doesn't work return np.random.RandomState(random_state) + + +def _batch_call(func, ids, *args, N=10, **kwargs): + # we're running in batch mode, just loop over each item and call the scalar version of the + # function + output_ids = np.zeros((len(ids), N), dtype=np.int32) + output_scores = np.zeros((len(ids), N), dtype=np.float32) + + for i, idx in enumerate(ids): + batch_ids, batch_scores = func(idx, *args, N=N, **kwargs) + + # pad out to N items if we're returned fewer + missing_items = N - len(batch_ids) + if missing_items > 0: + batch_ids = np.append(batch_ids, np.full(missing_items, -1)) + batch_scores = np.append( + batch_scores, np.full(missing_items, -np.finfo(np.float32).max) + ) + + output_ids[i] = batch_ids[:N] + output_scores[i] = batch_scores[:N] + + return output_ids, output_scores diff --git a/tests/approximate_als_test.py b/tests/approximate_als_test.py index ae3b0c3b..6d18be88 100644 --- a/tests/approximate_als_test.py +++ b/tests/approximate_als_test.py @@ -13,48 +13,66 @@ # don't require annoy/faiss/nmslib to be installed try: - import annoy # noqa + import annoy # noqa pylint: disable=unused-import class AnnoyALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): - return AnnoyAlternatingLeastSquares(factors=2, regularization=0, random_state=23) + return AnnoyAlternatingLeastSquares(factors=32, regularization=0, random_state=23) def test_pickle(self): # pickle isn't supported on annoy indices pass + def test_rank_items(self): + pass + + def test_rank_items_batch(self): + pass + except ImportError: pass try: - import nmslib # noqa + import nmslib # noqa pylint: disable=unused-import class NMSLibALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): return NMSLibAlternatingLeastSquares( - factors=2, regularization=0, index_params={"post": 2}, random_state=23 + factors=32, regularization=0, index_params={"post": 2}, random_state=23 ) def test_pickle(self): # pickle isn't supported on nmslib indices pass + def test_rank_items(self): + pass + + def test_rank_items_batch(self): + pass + except ImportError: pass try: - import faiss # noqa + import faiss # noqa pylint: disable=unused-import class FaissALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): return FaissAlternatingLeastSquares( - nlist=1, nprobe=1, factors=2, regularization=0, use_gpu=False, random_state=23 + nlist=1, nprobe=1, factors=32, regularization=0, use_gpu=False, random_state=23 ) def test_pickle(self): # pickle isn't supported on faiss indices pass + def test_rank_items(self): + pass + + def test_rank_items_batch(self): + pass + if HAS_CUDA: class FaissALSGPUTest(unittest.TestCase, RecommenderBaseTestMixin): @@ -77,7 +95,7 @@ def test_similar_items(self): # this causes the test_similar_items call to fail if we set regularization to 0 self.__regularization = 1.0 try: - super(FaissALSGPUTest, self).test_similar_items() + super().test_similar_items() finally: self.__regularization = 0.0 @@ -98,6 +116,12 @@ def test_pickle(self): # pickle isn't supported on faiss indices pass + def test_rank_items(self): + pass + + def test_rank_items_batch(self): + pass + except ImportError: pass diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index aeb4b9a2..0000f2f8 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -63,7 +63,7 @@ def test_recommend_batch(self): model = self._get_model() model.fit(user_items, show_progress=False) - ids, _ = model.recommend(np.arange(50), user_items, N=1) + ids, scores = model.recommend(np.arange(50), user_items, N=1) for userid in range(50): assert len(ids[userid]) == 1 @@ -72,6 +72,11 @@ def test_recommend_batch(self): # all the other similar users assert ids[userid][0] == userid + # make sure the batch recommend results match those for a single user + ids_user, scores_user = model.recommend(userid, user_items, N=1) + assert np.allclose(ids_user, ids[userid]) + assert np.allclose(scores_user, scores[userid]) + userids = np.array([2, 3, 4]) ids, _ = model.recommend(userids, user_items, N=1) @@ -83,6 +88,18 @@ def test_recommend_batch(self): for i, _ in enumerate(userids): assert 0 not in ids[i] + # also make sure the results when filter_already_liked_items=False match in batch vs + # scalar mode + ids, scores = model.recommend( + np.arange(50), user_items, N=5, filter_already_liked_items=False + ) + for userid in range(50): + ids_user, scores_user = model.recommend( + userid, user_items, N=5, filter_already_liked_items=False + ) + assert np.allclose(scores_user, scores[userid]) + assert np.allclose(ids_user, ids[userid]) + def test_recalculate_user(self): item_users = get_checker_board(50) user_items = item_users.T.tocsr() @@ -228,10 +245,14 @@ def test_similar_items_filter(self): for r in ids[itemid]: self.assertTrue(r % 5 != 0) - selected = np.arange(10) - ids, _ = model.similar_items(itemids, N=10, items=selected) - for itemid in itemids: - self.assertEqual(set(ids[itemid]), set(selected)) + try: + selected = np.arange(10) + ids, _ = model.similar_items(itemids, N=10, items=selected) + for itemid in itemids: + self.assertEqual(set(ids[itemid]), set(selected)) + except NotImplementedError: + # some models don't support a 'items' filter on the similar_items call + pass def test_zero_length_row(self): # get a matrix where a row/column is 0 From fb11321e58135d859e6cfd66266a4a99bd126358 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 31 Dec 2021 13:19:19 -0800 Subject: [PATCH 7/9] Approximate nearest neighbour for BPR/LMF and GPU models (#502) Approximate nearest neighbours used to only work for the ALS mode on the CPU. This change makes it so that we can compose ANN methods with any matrix factorization model (including BPR/LMF) and also use the GPU MF models as well. Currently this provides the same api in implicit/approximate_als.py for backwards compatibility - but this may be removed at a future date. Closes #487 --- .github/workflows/build.yml | 3 + .pylintrc | 7 +- implicit/ann/__init__.py | 0 implicit/ann/annoy.py | 232 +++++++++++ implicit/ann/faiss.py | 276 ++++++++++++ implicit/ann/nmslib.py | 235 +++++++++++ implicit/approximate_als.py | 738 +++------------------------------ implicit/utils.py | 36 ++ setup.py | 2 +- tests/approximate_als_test.py | 55 ++- tests/recommender_base_test.py | 45 +- 11 files changed, 918 insertions(+), 711 deletions(-) create mode 100644 implicit/ann/__init__.py create mode 100644 implicit/ann/annoy.py create mode 100644 implicit/ann/faiss.py create mode 100644 implicit/ann/nmslib.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index caaf2876..b38592ed 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,6 +25,9 @@ jobs: python -m pip install --upgrade pip pip install flake8 isort cpplint black pytest codespell h5py pylint pip install -r requirements.txt + - name: Install ANN Libraries + run: pip install annoy nmslib + if: runner.os == 'Linux' - name: Lint with flake8 run: | # stop the build if there are Python syntax errors or undefined names diff --git a/.pylintrc b/.pylintrc index 9a140043..5b43ad80 100644 --- a/.pylintrc +++ b/.pylintrc @@ -1,6 +1,8 @@ [MASTER] -extension-pkg-whitelist=implicit.cpu._als,implicit._nearest_neighbours,implicit.gpu._cuda,implicit.cpu.bpr,implicit.cpu.topk,numpy.random.mtrand +ignore-patterns=setup.py + +extension-pkg-whitelist=implicit.cpu._als,implicit._nearest_neighbours,implicit.gpu._cuda,implicit.cpu.bpr,implicit.cpu.topk,numpy.random.mtrand,nmslib,faiss [MESSAGES CONTROL] disable=fixme, @@ -39,8 +41,9 @@ disable=fixme, no-name-in-module, arguments-renamed, import-self, + protected-access, [SIMILARITIES] -min-similarity-lines=16 +min-similarity-lines=50 ignore-docstrings=yes ignore-imports=yes diff --git a/implicit/ann/__init__.py b/implicit/ann/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/implicit/ann/annoy.py b/implicit/ann/annoy.py new file mode 100644 index 00000000..5deda717 --- /dev/null +++ b/implicit/ann/annoy.py @@ -0,0 +1,232 @@ +import logging + +import annoy +import numpy as np + +import implicit.gpu +from implicit.recommender_base import RecommenderBase +from implicit.utils import _batch_call, _filter_items_from_results, augment_inner_product_matrix + +log = logging.getLogger("implicit") + + +class AnnoyModel(RecommenderBase): + + """Speeds up inference calls to MatrixFactorization models by using an + `Annoy `_ index to calculate similar items and + recommend items. + + Parameters + ---------- + model : MatrixFactorizationBase + A matrix factorization model to use for the factors + n_trees : int, optional + The number of trees to use when building the Annoy index. More trees gives higher precision + when querying. + search_k : int, optional + Provides a way to search more trees at runtime, giving the ability to have more accurate + results at the cost of taking more time. + approximate_similar_items : bool, optional + whether or not to build an Annoy index for computing similar_items + approximate_recommend : bool, optional + whether or not to build an Annoy index for the recommend call + + Attributes + ---------- + similar_items_index : annoy.AnnoyIndex + Annoy index for looking up similar items in the cosine space formed by the latent + item_factors + + recommend_index : annoy.AnnoyIndex + Annoy index for looking up similar items in the inner product space formed by the latent + item_factors + """ + + def __init__( + self, + model, + approximate_similar_items=True, + approximate_recommend=True, + n_trees=50, + search_k=-1, + ): + self.model = model + + self.similar_items_index = None + self.recommend_index = None + self.max_norm = None + + self.approximate_similar_items = approximate_similar_items + self.approximate_recommend = approximate_recommend + + self.n_trees = n_trees + self.search_k = search_k + + def fit(self, Cui, show_progress=True): + # train the model + self.model.fit(Cui, show_progress) + + item_factors = self.model.item_factors + if implicit.gpu.HAS_CUDA and isinstance(item_factors, implicit.gpu.Matrix): + item_factors = item_factors.to_numpy() + item_factors = item_factors.astype("float32") + + # build up an Annoy Index with all the item_factors (for calculating + # similar items) + if self.approximate_similar_items: + log.debug("Building annoy similar items index") + + self.similar_items_index = annoy.AnnoyIndex(item_factors.shape[1], "angular") + for i, row in enumerate(item_factors): + self.similar_items_index.add_item(i, row) + self.similar_items_index.build(self.n_trees) + + # build up a separate index for the inner product (for recommend + # methods) + if self.approximate_recommend: + log.debug("Building annoy recommendation index") + self.max_norm, extra = augment_inner_product_matrix(item_factors) + self.recommend_index = annoy.AnnoyIndex(extra.shape[1], "angular") + for i, row in enumerate(extra): + self.recommend_index.add_item(i, row) + self.recommend_index.build(self.n_trees) + + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): + if items is not None and self.approximate_similar_items: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") + + count = N + if filter_items is not None: + count += len(filter_items) + + if not self.approximate_similar_items: + return self.model.similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) + + # annoy doesn't have a batch mode we can use + if not np.isscalar(itemid): + return _batch_call( + self.similar_items, + itemid, + N=N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + ) + + # support recalculate_item if possible. TODO: refactor this + if hasattr(self.model, "_item_factor"): + factor = self.model._item_factor( + itemid, react_users, recalculate_item + ) # pylint: disable=protected-access + elif recalculate_item: + raise NotImplementedError(f"recalculate_item isn't supported with {self.model}") + else: + factor = self.model.item_factors[itemid] + if implicit.gpu.HAS_CUDA and isinstance(factor, implicit.gpu.Matrix): + factor = factor.to_numpy() + + if len(factor.shape) != 1: + factor = factor.squeeze() + + ids, scores = self.similar_items_index.get_nns_by_vector( + factor, N, search_k=self.search_k, include_distances=True + ) + ids, scores = np.array(ids), np.array(scores) + + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, 1 - (scores ** 2) / 2 + + def recommend( + self, + userid, + user_items, + N=10, + filter_already_liked_items=True, + filter_items=None, + recalculate_user=False, + items=None, + ): + if items is not None and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + + if not self.approximate_recommend: + return self.model.recommend( + userid, + user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + # batch computation isn't supported by annoy, fallback to looping over items + if not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + # support recalculate_user if possible (TODO: come back to this since its a bit of a hack) + if hasattr(self.model, "+_user_factor"): + user = self.model._user_factor( + userid, user_items, recalculate_user + ) # pylint: disable=protected-access + elif recalculate_user: + raise NotImplementedError(f"recalculate_user isn't supported with {self.model}") + else: + user = self.model.user_factors[userid] + if implicit.gpu.HAS_CUDA and isinstance(user, implicit.gpu.Matrix): + user = user.to_numpy() + + # calculate the top N items, removing the users own liked items from + # the results + count = N + if filter_items: + count += len(filter_items) + filter_items = np.array(filter_items) + + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) + + query = np.append(user, 0) + ids, scores = self.recommend_index.get_nns_by_vector( + query, count, include_distances=True, search_k=self.search_k + ) + ids, scores = np.array(ids), np.array(scores) + + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) + + # convert the distances from euclidean to cosine distance, + # and then rescale the cosine distance to go back to inner product + scaling = self.max_norm * np.linalg.norm(query) + scores = scaling * (1 - (scores ** 2) / 2) + return ids, scores + + def similar_users(self, userid, N=10, filter_users=None, users=None): + raise NotImplementedError( + "similar_users isn't implemented with Annoy yet. (note: you can call " + " self.model.similar_models to get the same functionality on the inner model class)" + ) diff --git a/implicit/ann/faiss.py b/implicit/ann/faiss.py new file mode 100644 index 00000000..8c9cb3ae --- /dev/null +++ b/implicit/ann/faiss.py @@ -0,0 +1,276 @@ +import logging +import warnings + +import faiss +import numpy as np + +import implicit.gpu +from implicit.recommender_base import RecommenderBase +from implicit.utils import _batch_call, _filter_items_from_results + +log = logging.getLogger("implicit") + + +# pylint: disable=no-value-for-parameter + + +class FaissModel(RecommenderBase): + """ + Speeds up inference calls to MatrixFactorization models by using + `Faiss `_ to create approximate nearest neighbours + indices of the latent factors. + + Parameters + ---------- + model : MatrixFactorizationBase + A matrix factorization model to use for the factors + nlist : int, optional + The number of cells to use when building the Faiss index. + nprobe : int, optional + The number of cells to visit to perform a search. + use_gpu : bool, optional + Whether or not to enable run Faiss on the GPU. Requires faiss to have been + built with GPU support. + approximate_similar_items : bool, optional + whether or not to build an Faiss index for computing similar_items + approximate_recommend : bool, optional + whether or not to build an Faiss index for the recommend call + + Attributes + ---------- + similar_items_index : faiss.IndexIVFFlat + Faiss index for looking up similar items in the cosine space formed by the latent + item_factors + + recommend_index : faiss.IndexIVFFlat + Faiss index for looking up similar items in the inner product space formed by the latent + item_factors + """ + + def __init__( + self, + model, + approximate_similar_items=True, + approximate_recommend=True, + nlist=400, + nprobe=20, + use_gpu=implicit.gpu.HAS_CUDA, + ): + self.model = model + self.similar_items_index = None + self.recommend_index = None + self.quantizer = None + self.gpu_resources = None + self.factors = None + + self.approximate_similar_items = approximate_similar_items + self.approximate_recommend = approximate_recommend + + # hyper-parameters for FAISS + self.nlist = nlist + self.nprobe = nprobe + self.use_gpu = use_gpu + super().__init__() + + def fit(self, Cui, show_progress=True): + self.model.fit(Cui, show_progress) + + item_factors = self.model.item_factors + if implicit.gpu.HAS_CUDA and isinstance(item_factors, implicit.gpu.Matrix): + item_factors = item_factors.to_numpy() + item_factors = item_factors.astype("float32") + + self.factors = item_factors.shape[1] + + self.quantizer = faiss.IndexFlat(self.factors) + + if self.use_gpu: + self.gpu_resources = faiss.StandardGpuResources() + + if self.approximate_recommend: + log.debug("Building faiss recommendation index") + + # build up a inner product index here + if self.use_gpu: + index = faiss.GpuIndexIVFFlat( + self.gpu_resources, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT + ) + else: + index = faiss.IndexIVFFlat( + self.quantizer, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT + ) + + index.train(item_factors) + index.add(item_factors) + index.nprobe = self.nprobe + self.recommend_index = index + + if self.approximate_similar_items: + log.debug("Building faiss similar items index") + + # likewise build up cosine index for similar_items, using an inner product + # index on normalized vectors` + norms = np.linalg.norm(item_factors, axis=1) + norms[norms == 0] = 1e-10 + + normalized = (item_factors.T / norms).T.astype("float32") + if self.use_gpu: + index = faiss.GpuIndexIVFFlat( + self.gpu_resources, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT + ) + else: + index = faiss.IndexIVFFlat( + self.quantizer, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT + ) + + index.train(normalized) + index.add(normalized) + index.nprobe = self.nprobe + self.similar_items_index = index + + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): + if items is not None and self.approximate_similar_items: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") + + count = N + if filter_items is not None: + count += len(filter_items) + + if not self.approximate_similar_items or (self.use_gpu and count >= 1024): + return self.model.similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) + + # support recalculate_item if possible. TODO: refactor this + if hasattr(self.model, "_item_factor"): + factors = self.model._item_factor( + itemid, react_users, recalculate_item + ) # pylint: disable=protected-access + elif recalculate_item: + raise NotImplementedError(f"recalculate_item isn't supported with {self.model}") + else: + factors = self.model.item_factors[itemid] + if implicit.gpu.HAS_CUDA and isinstance(factors, implicit.gpu.Matrix): + factors = factors.to_numpy() + + if np.isscalar(itemid): + factors /= np.linalg.norm(factors) + factors = factors.reshape(1, -1) + else: + factors /= np.linalg.norm(factors, axis=1)[:, None] + + scores, ids = self.similar_items_index.search(factors.astype("float32"), count) + + if np.isscalar(itemid): + ids, scores = ids[0], scores[0] + + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, scores + + def recommend( + self, + userid, + user_items, + N=10, + filter_already_liked_items=True, + filter_items=None, + recalculate_user=False, + items=None, + ): + if items is not None and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + + # batch computation is tricky with filter_already_liked_items (requires querying a + # different number of rows per user). Instead just fallback to a faiss query per user + if filter_already_liked_items and not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + if not self.approximate_recommend: + warnings.warning("Calling recommend on a FaissModel with approximate_recommend=False") + return self.model.recommend( + userid, + user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + # support recalculate_user if possible (TODO: come back to this since its a bit of a hack) + if hasattr(self.model, "+_user_factor"): + user = self.model._user_factor( + userid, user_items, recalculate_user + ) # pylint: disable=protected-access + elif recalculate_user: + raise NotImplementedError(f"recalculate_user isn't supported with {self.model}") + else: + user = self.model.user_factors[userid] + if implicit.gpu.HAS_CUDA and isinstance(user, implicit.gpu.Matrix): + user = user.to_numpy() + + # calculate the top N items, removing the users own liked items from + # the results + count = N + if filter_items: + count += len(filter_items) + filter_items = np.array(filter_items) + + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) + + # the GPU variant of faiss doesn't support returning more than 1024 results. + # fall back to the exact match when this happens + if self.use_gpu and count >= 1024: + return self.model.recommend( + userid, + user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + if np.isscalar(userid): + query = user.reshape(1, -1).astype("float32") + else: + query = user.astype("float32") + + scores, ids = self.recommend_index.search(query, count) + + if np.isscalar(userid): + ids, scores = ids[0], scores[0] + + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) + + return ids, scores + + def similar_users(self, userid, N=10, filter_users=None, users=None): + raise NotImplementedError( + "similar_users isn't implemented with Faiss yet. (note: you can call " + " self.model.similar_models to get the same functionality on the inner model class)" + ) diff --git a/implicit/ann/nmslib.py b/implicit/ann/nmslib.py new file mode 100644 index 00000000..7b0139b0 --- /dev/null +++ b/implicit/ann/nmslib.py @@ -0,0 +1,235 @@ +import logging + +import nmslib +import numpy as np + +import implicit.gpu +from implicit.recommender_base import RecommenderBase +from implicit.utils import _batch_call, _filter_items_from_results, augment_inner_product_matrix + +log = logging.getLogger("implicit") + + +class NMSLibModel(RecommenderBase): + + """Speeds up inference calls to MatrixFactorization models by using + `NMSLib `_ to create approximate nearest neighbours + indices of the latent factors. + + Parameters + ---------- + model : MatrixFactorizationBase + A matrix factorization model to use for the factors + method : str, optional + The NMSLib method to use + index_params: dict, optional + Optional params to send to the createIndex call in NMSLib + query_params: dict, optional + Optional query time params for the NMSLib 'setQueryTimeParams' call + approximate_similar_items : bool, optional + whether or not to build an NMSLIB index for computing similar_items + approximate_recommend : bool, optional + whether or not to build an NMSLIB index for the recommend call + + Attributes + ---------- + similar_items_index : nmslib.FloatIndex + NMSLib index for looking up similar items in the cosine space formed by the latent + item_factors + + recommend_index : nmslib.FloatIndex + NMSLib index for looking up similar items in the inner product space formed by the latent + item_factors + """ + + def __init__( + self, + model, + approximate_similar_items=True, + approximate_recommend=True, + method="hnsw", + index_params=None, + query_params=None, + **kwargs, + ): + self.model = model + if index_params is None: + index_params = {"M": 16, "post": 0, "efConstruction": 400} + if query_params is None: + query_params = {"ef": 90} + + self.similar_items_index = None + self.recommend_index = None + + self.approximate_similar_items = approximate_similar_items + self.approximate_recommend = approximate_recommend + self.method = method + + self.index_params = index_params + self.query_params = query_params + + self.max_norm = None + + def fit(self, Cui, show_progress=True): + # nmslib can be a little chatty when first imported, disable some of + # the logging + logging.getLogger("nmslib").setLevel(logging.WARNING) + + # train the model + self.model.fit(Cui, show_progress) + item_factors = self.model.item_factors + if implicit.gpu.HAS_CUDA and isinstance(item_factors, implicit.gpu.Matrix): + item_factors = item_factors.to_numpy() + + # create index for similar_items + if self.approximate_similar_items: + log.debug("Building nmslib similar items index") + self.similar_items_index = nmslib.init(method=self.method, space="cosinesimil") + + # there are some numerical instability issues here with + # building a cosine index with vectors with 0 norms, hack around this + # by just not indexing them + norms = np.linalg.norm(item_factors, axis=1) + ids = np.arange(item_factors.shape[0]) + + # delete zero valued rows from the matrix + nonzero_item_factors = np.delete(item_factors, ids[norms == 0], axis=0) + ids = ids[norms != 0] + + self.similar_items_index.addDataPointBatch(nonzero_item_factors, ids=ids) + self.similar_items_index.createIndex(self.index_params, print_progress=show_progress) + self.similar_items_index.setQueryTimeParams(self.query_params) + + # build up a separate index for the inner product (for recommend + # methods) + if self.approximate_recommend: + log.debug("Building nmslib recommendation index") + self.max_norm, extra = augment_inner_product_matrix(item_factors) + self.recommend_index = nmslib.init(method="hnsw", space="cosinesimil") + self.recommend_index.addDataPointBatch(extra) + self.recommend_index.createIndex(self.index_params, print_progress=show_progress) + self.recommend_index.setQueryTimeParams(self.query_params) + + def similar_items( + self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None + ): + if not self.approximate_similar_items: + return self.model.similar_items( + itemid, + N, + react_users=react_users, + recalculate_item=recalculate_item, + filter_items=filter_items, + items=items, + ) + + if items is not None: + raise NotImplementedError("using an items filter isn't supported with ANN lookup") + + # support recalculate_item if possible. TODO: refactor this + if hasattr(self.model, "_item_factor"): + factors = self.model._item_factor( + itemid, react_users, recalculate_item + ) # pylint: disable=protected-access + elif recalculate_item: + raise NotImplementedError(f"recalculate_item isn't supported with {self.model}") + else: + factors = self.model.item_factors[itemid] + if implicit.gpu.HAS_CUDA and isinstance(factors, implicit.gpu.Matrix): + factors = factors.to_numpy() + + count = N + if filter_items is not None: + count += len(filter_items) + + if np.isscalar(itemid): + ids, scores = self.similar_items_index.knnQuery(factors, count) + else: + results = self.similar_items_index.knnQueryBatch(factors, count) + ids = np.stack([result[0] for result in results]) + scores = np.stack([result[1] for result in results]) + + scores = 1.0 - scores + if filter_items is not None: + ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) + + return ids, scores + + def recommend( + self, + userid, + user_items, + N=10, + filter_already_liked_items=True, + filter_items=None, + recalculate_user=False, + items=None, + ): + if items is not None and self.approximate_recommend: + raise NotImplementedError("using a 'items' list with ANN search isn't supported") + + if not self.approximate_recommend: + return self.model.recommend( + userid, + user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + # batch computation is hard here, fallback to looping over items + if not np.isscalar(userid): + return _batch_call( + self.recommend, + userid, + user_items=user_items, + N=N, + filter_already_liked_items=filter_already_liked_items, + filter_items=filter_items, + recalculate_user=recalculate_user, + items=items, + ) + + # support recalculate_user if possible (TODO: come back to this since its a bit of a hack) + if hasattr(self.model, "+_user_factor"): + user = self.model._user_factor( + userid, user_items, recalculate_user + ) # pylint: disable=protected-access + elif recalculate_user: + raise NotImplementedError(f"recalculate_user isn't supported with {self.model}") + else: + user = self.model.user_factors[userid] + if implicit.gpu.HAS_CUDA and isinstance(user, implicit.gpu.Matrix): + user = user.to_numpy() + + # calculate the top N items, removing the users own liked items from + # the results + count = N + if filter_items: + count += len(filter_items) + filter_items = np.array(filter_items) + + if filter_already_liked_items: + user_likes = user_items[userid].indices + filter_items = ( + np.append(filter_items, user_likes) if filter_items is not None else user_likes + ) + count += len(user_likes) + + query = np.append(user, 0) + ids, scores = self.recommend_index.knnQuery(query, count) + scaling = self.max_norm * np.linalg.norm(query) + scores = scaling * (1.0 - (scores)) + + if filter_items is not None: + ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) + + return ids, scores + + def similar_users(self, userid, N=10, filter_users=None, users=None): + raise NotImplementedError( + "similar_users isn't implemented with NMSLib yet. (note: you can call " + " self.model.similar_models to get the same functionality on the inner model class)" + ) diff --git a/implicit/approximate_als.py b/implicit/approximate_als.py index f185296d..f3e3a04f 100644 --- a/implicit/approximate_als.py +++ b/implicit/approximate_als.py @@ -3,676 +3,74 @@ See http://www.benfrederickson.com/approximate-nearest-neighbours-for-recommender-systems/ """ -import logging - -import numpy as np - import implicit.gpu -from implicit.cpu.als import AlternatingLeastSquares - -from .utils import _batch_call - -log = logging.getLogger("implicit") - - -def augment_inner_product_matrix(factors): - """This function transforms a factor matrix such that an angular nearest neighbours search - will return top related items of the inner product. - - This involves transforming each row by adding one extra dimension as suggested in the paper: - "Speeding Up the Xbox Recommender System Using a Euclidean Transformation for Inner-Product - Spaces" https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf - - Basically this involves transforming each feature vector so that they have the same norm, which - means the cosine of this transformed vector is proportional to the dot product (if the other - vector in the cosine has a 0 in the extra dimension).""" - norms = np.linalg.norm(factors, axis=1) - max_norm = norms.max() - - # add an extra dimension so that the norm of each row is the same - # (max_norm) - extra_dimension = np.sqrt(max_norm ** 2 - norms ** 2) - return max_norm, np.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1) - - -class NMSLibAlternatingLeastSquares(AlternatingLeastSquares): - - """Speeds up the base :class:`~implicit.als.AlternatingLeastSquares` model by using - `NMSLib `_ to create approximate nearest neighbours - indices of the latent factors. - - Parameters - ---------- - method : str, optional - The NMSLib method to use - index_params: dict, optional - Optional params to send to the createIndex call in NMSLib - query_params: dict, optional - Optional query time params for the NMSLib 'setQueryTimeParams' call - approximate_similar_items : bool, optional - whether or not to build an NMSLIB index for computing similar_items - approximate_recommend : bool, optional - whether or not to build an NMSLIB index for the recommend call - random_state : int, RandomState or None, optional - The random state for seeding the initial item and user factors. - Default is None. - - Attributes - ---------- - similar_items_index : nmslib.FloatIndex - NMSLib index for looking up similar items in the cosine space formed by the latent - item_factors - - recommend_index : nmslib.FloatIndex - NMSLib index for looking up similar items in the inner product space formed by the latent - item_factors - """ - - def __init__( - self, - *args, - approximate_similar_items=True, - approximate_recommend=True, - method="hnsw", - index_params=None, - query_params=None, - random_state=None, - **kwargs - ): - if index_params is None: - index_params = {"M": 16, "post": 0, "efConstruction": 400} - if query_params is None: - query_params = {"ef": 90} - - self.similar_items_index = None - self.recommend_index = None - - self.approximate_similar_items = approximate_similar_items - self.approximate_recommend = approximate_recommend - self.method = method - - self.index_params = index_params - self.query_params = query_params - - self.max_norm = None - - super().__init__(*args, random_state=random_state, **kwargs) - - def fit(self, Cui, show_progress=True): - # nmslib can be a little chatty when first imported, disable some of - # the logging - logging.getLogger("nmslib").setLevel(logging.WARNING) - import nmslib - - # train the model - super().fit(Cui, show_progress) - - # create index for similar_items - if self.approximate_similar_items: - log.debug("Building nmslib similar items index") - self.similar_items_index = nmslib.init(method=self.method, space="cosinesimil") - - # there are some numerical instability issues here with - # building a cosine index with vectors with 0 norms, hack around this - # by just not indexing them - norms = np.linalg.norm(self.item_factors, axis=1) - ids = np.arange(self.item_factors.shape[0]) - - # delete zero valued rows from the matrix - item_factors = np.delete(self.item_factors, ids[norms == 0], axis=0) - ids = ids[norms != 0] - - self.similar_items_index.addDataPointBatch(item_factors, ids=ids) - self.similar_items_index.createIndex(self.index_params, print_progress=show_progress) - self.similar_items_index.setQueryTimeParams(self.query_params) - - # build up a separate index for the inner product (for recommend - # methods) - if self.approximate_recommend: - log.debug("Building nmslib recommendation index") - self.max_norm, extra = augment_inner_product_matrix(self.item_factors) - self.recommend_index = nmslib.init(method="hnsw", space="cosinesimil") - self.recommend_index.addDataPointBatch(extra) - self.recommend_index.createIndex(self.index_params, print_progress=show_progress) - self.recommend_index.setQueryTimeParams(self.query_params) - - def similar_items( - self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None - ): - if not self.approximate_similar_items: - return super().similar_items( - itemid, - N, - react_users=react_users, - recalculate_item=recalculate_item, - filter_items=filter_items, - items=items, - ) - - if items is not None: - raise NotImplementedError("using an items filter isn't supported with ANN lookup") - - factors = self._item_factor(itemid, react_users, recalculate_item) - count = N - if filter_items is not None: - count += len(filter_items) - - if np.isscalar(itemid): - ids, scores = self.similar_items_index.knnQuery(factors, count) - else: - results = self.similar_items_index.knnQueryBatch(factors, count) - ids = np.stack([result[0] for result in results]) - scores = np.stack([result[1] for result in results]) - - scores = 1.0 - scores - if filter_items is not None: - ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) - - return ids, scores - - def recommend( - self, - userid, - user_items, - N=10, - filter_already_liked_items=True, - filter_items=None, - recalculate_user=False, - items=None, - ): - if items and self.approximate_recommend: - raise NotImplementedError("using a 'items' list with ANN search isn't supported") - - if not self.approximate_recommend: - return super().recommend( - userid, - user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - - # batch computation is hard here, fallback to looping over items - if not np.isscalar(userid): - return _batch_call( - self.recommend, - userid, - user_items=user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - - user = self._user_factor(userid, user_items, recalculate_user) - - # calculate the top N items, removing the users own liked items from - # the results - count = N - if filter_items: - count += len(filter_items) - filter_items = np.array(filter_items) - - if filter_already_liked_items: - user_likes = user_items[userid].indices - filter_items = ( - np.append(filter_items, user_likes) if filter_items is not None else user_likes - ) - count += len(user_likes) - - query = np.append(user, 0) - ids, scores = self.recommend_index.knnQuery(query, count) - scaling = self.max_norm * np.linalg.norm(query) - scores = scaling * (1.0 - (scores)) - - if filter_items is not None: - ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) - - return ids, scores - - -class AnnoyAlternatingLeastSquares(AlternatingLeastSquares): - - """A version of the :class:`~implicit.als.AlternatingLeastSquares` model that uses an - `Annoy `_ index to calculate similar items and - recommend items. - - Parameters - ---------- - n_trees : int, optional - The number of trees to use when building the Annoy index. More trees gives higher precision - when querying. - search_k : int, optional - Provides a way to search more trees at runtime, giving the ability to have more accurate - results at the cost of taking more time. - approximate_similar_items : bool, optional - whether or not to build an Annoy index for computing similar_items - approximate_recommend : bool, optional - whether or not to build an Annoy index for the recommend call - random_state : int, RandomState or None, optional - The random state for seeding the initial item and user factors. - Default is None. - - Attributes - ---------- - similar_items_index : annoy.AnnoyIndex - Annoy index for looking up similar items in the cosine space formed by the latent - item_factors - - recommend_index : annoy.AnnoyIndex - Annoy index for looking up similar items in the inner product space formed by the latent - item_factors - """ - - def __init__( - self, - *args, - approximate_similar_items=True, - approximate_recommend=True, - n_trees=50, - search_k=-1, - random_state=None, - **kwargs - ): - - super().__init__(*args, random_state=random_state, **kwargs) - - self.similar_items_index = None - self.recommend_index = None - self.max_norm = None - - self.approximate_similar_items = approximate_similar_items - self.approximate_recommend = approximate_recommend - - self.n_trees = n_trees - self.search_k = search_k - - def fit(self, Cui, show_progress=True): - # delay loading the annoy library in case its not installed here - import annoy - - # train the model - super().fit(Cui, show_progress) - - # build up an Annoy Index with all the item_factors (for calculating - # similar items) - if self.approximate_similar_items: - log.debug("Building annoy similar items index") - - self.similar_items_index = annoy.AnnoyIndex(self.item_factors.shape[1], "angular") - for i, row in enumerate(self.item_factors): - self.similar_items_index.add_item(i, row) - self.similar_items_index.build(self.n_trees) - - # build up a separate index for the inner product (for recommend - # methods) - if self.approximate_recommend: - log.debug("Building annoy recommendation index") - self.max_norm, extra = augment_inner_product_matrix(self.item_factors) - self.recommend_index = annoy.AnnoyIndex(extra.shape[1], "angular") - for i, row in enumerate(extra): - self.recommend_index.add_item(i, row) - self.recommend_index.build(self.n_trees) - - def similar_items( - self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None - ): - if items is not None and self.approximate_similar_items: - raise NotImplementedError("using an items filter isn't supported with ANN lookup") - - count = N - if filter_items is not None: - count += len(filter_items) - - if not self.approximate_similar_items: - return super().similar_items( - itemid, - N, - react_users=react_users, - recalculate_item=recalculate_item, - filter_items=filter_items, - items=items, - ) - - # annoy doesn't have a batch mode we can use - if not np.isscalar(itemid): - return _batch_call( - self.similar_items, - itemid, - N=N, - react_users=react_users, - recalculate_item=recalculate_item, - filter_items=filter_items, - ) - - factor = self._item_factor(itemid, react_users, recalculate_item) - - ids, scores = self.similar_items_index.get_nns_by_vector( - factor, N, search_k=self.search_k, include_distances=True - ) - ids, scores = np.array(ids), np.array(scores) - - if filter_items is not None: - ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) - - return ids, 1 - (scores ** 2) / 2 - - def recommend( - self, - userid, - user_items, - N=10, - filter_already_liked_items=True, - filter_items=None, - recalculate_user=False, - items=None, - ): - if items and self.approximate_recommend: - raise NotImplementedError("using a 'items' list with ANN search isn't supported") - - if not self.approximate_recommend: - return super().recommend( - userid, - user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - - # batch computation isn't supported by annoy, fallback to looping over items - if not np.isscalar(userid): - return _batch_call( - self.recommend, - userid, - user_items=user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - user = self._user_factor(userid, user_items, recalculate_user) - - # calculate the top N items, removing the users own liked items from - # the results - count = N - if filter_items: - count += len(filter_items) - filter_items = np.array(filter_items) - - if filter_already_liked_items: - user_likes = user_items[userid].indices - filter_items = ( - np.append(filter_items, user_likes) if filter_items is not None else user_likes - ) - count += len(user_likes) - - query = np.append(user, 0) - ids, scores = self.recommend_index.get_nns_by_vector( - query, count, include_distances=True, search_k=self.search_k - ) - ids, scores = np.array(ids), np.array(scores) - - if filter_items is not None: - ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) - - # convert the distances from euclidean to cosine distance, - # and then rescale the cosine distance to go back to inner product - scaling = self.max_norm * np.linalg.norm(query) - scores = scaling * (1 - (scores ** 2) / 2) - return ids, scores - - -class FaissAlternatingLeastSquares(AlternatingLeastSquares): - - """Speeds up the base :class:`~implicit.als.AlternatingLeastSquares` model by using - `Faiss `_ to create approximate nearest neighbours - indices of the latent factors. - - - Parameters - ---------- - nlist : int, optional - The number of cells to use when building the Faiss index. - nprobe : int, optional - The number of cells to visit to perform a search. - use_gpu : bool, optional - Whether or not to enable run Faiss on the GPU. Requires faiss to have been - built with GPU support. - approximate_similar_items : bool, optional - whether or not to build an Faiss index for computing similar_items - approximate_recommend : bool, optional - whether or not to build an Faiss index for the recommend call - random_state : int, RandomState or None, optional - The random state for seeding the initial item and user factors. - Default is None. - - Attributes - ---------- - similar_items_index : faiss.IndexIVFFlat - Faiss index for looking up similar items in the cosine space formed by the latent - item_factors - - recommend_index : faiss.IndexIVFFlat - Faiss index for looking up similar items in the inner product space formed by the latent - item_factors - """ - - def __init__( - self, - *args, - approximate_similar_items=True, - approximate_recommend=True, - nlist=400, - nprobe=20, - use_gpu=implicit.gpu.HAS_CUDA, - random_state=None, - **kwargs - ): - - self.similar_items_index = None - self.recommend_index = None - self.quantizer = None - self.gpu_resources = None - - self.approximate_similar_items = approximate_similar_items - self.approximate_recommend = approximate_recommend - - # hyper-parameters for FAISS - self.nlist = nlist - self.nprobe = nprobe - self.use_gpu = use_gpu - super().__init__(*args, random_state=random_state, **kwargs) - - def fit(self, Cui, show_progress=True): - import faiss - - # train the model - super().fit(Cui, show_progress) - - self.quantizer = faiss.IndexFlat(self.factors) - - if self.use_gpu: - self.gpu_resources = faiss.StandardGpuResources() - - item_factors = self.item_factors.astype("float32") - - if self.approximate_recommend: - log.debug("Building faiss recommendation index") - - # build up a inner product index here - if self.use_gpu: - index = faiss.GpuIndexIVFFlat( - self.gpu_resources, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT - ) - else: - index = faiss.IndexIVFFlat( - self.quantizer, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT - ) - - index.train(item_factors) - index.add(item_factors) - index.nprobe = self.nprobe - self.recommend_index = index - - if self.approximate_similar_items: - log.debug("Building faiss similar items index") - - # likewise build up cosine index for similar_items, using an inner product - # index on normalized vectors` - norms = np.linalg.norm(item_factors, axis=1) - norms[norms == 0] = 1e-10 - - normalized = (item_factors.T / norms).T.astype("float32") - if self.use_gpu: - index = faiss.GpuIndexIVFFlat( - self.gpu_resources, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT - ) - else: - index = faiss.IndexIVFFlat( - self.quantizer, self.factors, self.nlist, faiss.METRIC_INNER_PRODUCT - ) - - index.train(normalized) - index.add(normalized) - index.nprobe = self.nprobe - self.similar_items_index = index - - def similar_items( - self, itemid, N=10, react_users=None, recalculate_item=False, filter_items=None, items=None - ): - if items is not None and self.approximate_similar_items: - raise NotImplementedError("using an items filter isn't supported with ANN lookup") - - count = N - if filter_items is not None: - count += len(filter_items) - - if not self.approximate_similar_items or (self.use_gpu and count >= 1024): - return super().similar_items( - itemid, - N, - react_users=react_users, - recalculate_item=recalculate_item, - filter_items=filter_items, - items=items, - ) - - factors = self._item_factor(itemid, react_users, recalculate_item) - - if np.isscalar(itemid): - factors /= np.linalg.norm(factors) - factors = factors.reshape(1, -1) - else: - factors /= np.linalg.norm(factors, axis=1)[:, None] - - scores, ids = self.similar_items_index.search(factors.astype("float32"), count) - - if np.isscalar(itemid): - ids, scores = ids[0], scores[0] - - if filter_items is not None: - ids, scores = _filter_items_from_results(itemid, ids, scores, filter_items, N) - - return ids, scores - - def recommend( - self, - userid, - user_items, - N=10, - filter_already_liked_items=True, - filter_items=None, - recalculate_user=False, - items=None, - ): - if items and self.approximate_recommend: - raise NotImplementedError("using a 'items' list with ANN search isn't supported") - - if not self.approximate_recommend: - return super().recommend( - userid, - user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - - # batch computation is tricky with filter_already_liked_items (requires querying a - # different number of rows per user). Instead just fallback to a faiss query per user - if filter_already_liked_items and not np.isscalar(userid): - return _batch_call( - self.recommend, - userid, - user_items=user_items, - N=N, - filter_already_liked_items=filter_already_liked_items, - filter_items=filter_items, - recalculate_user=recalculate_user, - items=items, - ) - - user = self._user_factor(userid, user_items, recalculate_user) - - # calculate the top N items, removing the users own liked items from - # the results - count = N - if filter_items: - count += len(filter_items) - filter_items = np.array(filter_items) - - if filter_already_liked_items: - user_likes = user_items[userid].indices - filter_items = ( - np.append(filter_items, user_likes) if filter_items is not None else user_likes - ) - count += len(user_likes) - - # the GPU variant of faiss doesn't support returning more than 1024 results. - # fall back to the exact match when this happens - if self.use_gpu and count >= 1024: - return super().recommend( - userid, - user_items, - N=N, - filter_items=filter_items, - recalculate_user=recalculate_user, - ) - - if np.isscalar(userid): - query = user.reshape(1, -1).astype("float32") - else: - query = user.astype("float32") - - scores, ids = self.recommend_index.search(query, count) - - if np.isscalar(userid): - ids, scores = ids[0], scores[0] - - if filter_items is not None: - ids, scores = _filter_items_from_results(userid, ids, scores, filter_items, N) - - return ids, scores -def _filter_items_from_results(queryid, ids, scores, filter_items, N): - if np.isscalar(queryid): - mask = np.in1d(ids, filter_items, invert=True) - ids, scores = ids[mask][:N], scores[mask][:N] - else: - rows = len(queryid) - filtered_scores = np.zeros((rows, N), dtype=scores.dtype) - filtered_ids = np.zeros((rows, N), dtype=ids.dtype) - for row in range(rows): - mask = np.in1d(ids[row], filter_items, invert=True) - filtered_ids[row] = ids[row][mask][:N] - filtered_scores[row] = scores[row][mask][:N] - ids, scores = filtered_ids, filtered_scores - return ids, scores +def NMSLibAlternatingLeastSquares( + *args, + approximate_similar_items=True, + approximate_recommend=True, + method="hnsw", + index_params=None, + query_params=None, + use_gpu=implicit.gpu.HAS_CUDA, + **kwargs +): + # delay importing here in case nmslib isn't installed + from implicit.ann.nmslib import NMSLibModel + + # note that we're using the factory function here to instantiate a CPU/GPU model as appropriate + als_model = implicit.als.AlternatingLeastSquares(*args, use_gpu=use_gpu, **kwargs) + return NMSLibModel( + als_model, + approximate_similar_items=approximate_similar_items, + approximate_recommend=approximate_recommend, + method=method, + index_params=index_params, + query_params=query_params, + ) + + +def AnnoyAlternatingLeastSquares( + *args, + approximate_similar_items=True, + approximate_recommend=True, + n_trees=50, + search_k=-1, + use_gpu=implicit.gpu.HAS_CUDA, + **kwargs +): + als_model = implicit.als.AlternatingLeastSquares(*args, use_gpu=use_gpu, **kwargs) + from implicit.ann.annoy import AnnoyModel + + return AnnoyModel( + als_model, + approximate_similar_items=approximate_similar_items, + approximate_recommend=approximate_recommend, + n_trees=n_trees, + search_k=search_k, + ) + + +def FaissAlternatingLeastSquares( + *args, + approximate_similar_items=True, + approximate_recommend=True, + nlist=400, + nprobe=20, + use_gpu=implicit.gpu.HAS_CUDA, + **kwargs +): + # note that we're using the factory function here to instantiate a CPU/GPU model as appropriate + als_model = implicit.als.AlternatingLeastSquares(*args, use_gpu=use_gpu, **kwargs) + + from implicit.ann.faiss import FaissModel + + return FaissModel( + als_model, + approximate_similar_items=approximate_similar_items, + approximate_recommend=approximate_recommend, + nlist=nlist, + nprobe=nprobe, + use_gpu=use_gpu, + ) diff --git a/implicit/utils.py b/implicit/utils.py index a28c6d96..aae68196 100644 --- a/implicit/utils.py +++ b/implicit/utils.py @@ -55,6 +55,26 @@ def check_random_state(random_state): return np.random.RandomState(random_state) +def augment_inner_product_matrix(factors): + """This function transforms a factor matrix such that an angular nearest neighbours search + will return top related items of the inner product. + + This involves transforming each row by adding one extra dimension as suggested in the paper: + "Speeding Up the Xbox Recommender System Using a Euclidean Transformation for Inner-Product + Spaces" https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf + + Basically this involves transforming each feature vector so that they have the same norm, which + means the cosine of this transformed vector is proportional to the dot product (if the other + vector in the cosine has a 0 in the extra dimension).""" + norms = np.linalg.norm(factors, axis=1) + max_norm = norms.max() + + # add an extra dimension so that the norm of each row is the same + # (max_norm) + extra_dimension = np.sqrt(max_norm ** 2 - norms ** 2) + return max_norm, np.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1) + + def _batch_call(func, ids, *args, N=10, **kwargs): # we're running in batch mode, just loop over each item and call the scalar version of the # function @@ -76,3 +96,19 @@ def _batch_call(func, ids, *args, N=10, **kwargs): output_scores[i] = batch_scores[:N] return output_ids, output_scores + + +def _filter_items_from_results(queryid, ids, scores, filter_items, N): + if np.isscalar(queryid): + mask = np.in1d(ids, filter_items, invert=True) + ids, scores = ids[mask][:N], scores[mask][:N] + else: + rows = len(queryid) + filtered_scores = np.zeros((rows, N), dtype=scores.dtype) + filtered_ids = np.zeros((rows, N), dtype=ids.dtype) + for row in range(rows): + mask = np.in1d(ids[row], filter_items, invert=True) + filtered_ids[row] = ids[row][mask][:N] + filtered_scores[row] = scores[row][mask][:N] + ids, scores = filtered_ids, filtered_scores + return ids, scores diff --git a/setup.py b/setup.py index c1346d3a..ac471beb 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ import sys try: - import numpy.distutils + import numpy.distutils # noqa except ImportError: pass diff --git a/tests/approximate_als_test.py b/tests/approximate_als_test.py index 6d18be88..211b9e9f 100644 --- a/tests/approximate_als_test.py +++ b/tests/approximate_als_test.py @@ -17,17 +17,25 @@ class AnnoyALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): - return AnnoyAlternatingLeastSquares(factors=32, regularization=0, random_state=23) + return AnnoyAlternatingLeastSquares( + factors=32, regularization=0, random_state=23, use_gpu=False + ) def test_pickle(self): # pickle isn't supported on annoy indices pass - def test_rank_items(self): - pass + if HAS_CUDA: - def test_rank_items_batch(self): - pass + class AnnoyALSGPUTest(unittest.TestCase, RecommenderBaseTestMixin): + def _get_model(self): + return AnnoyAlternatingLeastSquares( + factors=32, regularization=0, random_state=23, use_gpu=True + ) + + def test_pickle(self): + # pickle isn't supported on annoy indices + pass except ImportError: pass @@ -38,18 +46,33 @@ def test_rank_items_batch(self): class NMSLibALSTest(unittest.TestCase, RecommenderBaseTestMixin): def _get_model(self): return NMSLibAlternatingLeastSquares( - factors=32, regularization=0, index_params={"post": 2}, random_state=23 + factors=32, + regularization=0, + index_params={"post": 2}, + random_state=23, + use_gpu=False, ) def test_pickle(self): # pickle isn't supported on nmslib indices pass - def test_rank_items(self): - pass + if HAS_CUDA: + # nmslib doesn't support querying on the gpu, but we should be able to still use a GPU als + # model with the nmslib index + class NMSLibALSGPUTest(unittest.TestCase, RecommenderBaseTestMixin): + def _get_model(self): + return NMSLibAlternatingLeastSquares( + factors=32, + regularization=0, + index_params={"post": 2}, + random_state=23, + use_gpu=True, + ) - def test_rank_items_batch(self): - pass + def test_pickle(self): + # pickle isn't supported on nmslib indices + pass except ImportError: pass @@ -67,12 +90,6 @@ def test_pickle(self): # pickle isn't supported on faiss indices pass - def test_rank_items(self): - pass - - def test_rank_items_batch(self): - pass - if HAS_CUDA: class FaissALSGPUTest(unittest.TestCase, RecommenderBaseTestMixin): @@ -116,12 +133,6 @@ def test_pickle(self): # pickle isn't supported on faiss indices pass - def test_rank_items(self): - pass - - def test_rank_items_batch(self): - pass - except ImportError: pass diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index 0000f2f8..2eedf7c7 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -151,23 +151,26 @@ def test_evaluation(self): def test_similar_users(self): model = self._get_model() - # calculating similar users in nearest-neighbours is not implemented yet - if isinstance(model, ItemItemRecommender): - return model.fit(get_checker_board(50), show_progress=False) - for userid in range(50): - ids, _ = model.similar_users(userid, N=10) - for r in ids: - self.assertEqual(r % 2, userid % 2) + + try: + for userid in range(50): + ids, _ = model.similar_users(userid, N=10) + for r in ids: + self.assertEqual(r % 2, userid % 2) + except NotImplementedError: + pass def test_similar_users_batch(self): model = self._get_model() - # calculating similar users in nearest-neighbours is not implemented yet - if isinstance(model, ItemItemRecommender): - return model.fit(get_checker_board(256), show_progress=False) userids = np.arange(50) - ids, scores = model.similar_users(userids, N=10) + + try: + ids, scores = model.similar_users(userids, N=10) + except NotImplementedError: + # similar users isn't implemented for many models (ItemItemRecommeder/ ANN models) + return self.assertEqual(ids.shape, (50, 10)) @@ -189,7 +192,11 @@ def test_similar_users_filter(self): model.fit(get_checker_board(256), show_progress=False) userids = np.arange(50) - ids, _ = model.similar_users(userids, N=10, filter_users=np.arange(52) * 5) + try: + ids, _ = model.similar_users(userids, N=10, filter_users=np.arange(52) * 5) + except NotImplementedError: + return + for userid in userids: for r in ids[userid]: self.assertTrue(r % 5 != 0) @@ -292,9 +299,12 @@ def test_rank_items(self): for userid in range(50): selected_items = random.sample(range(50), 10) - ids, _ = model.recommend( - userid, user_items, items=selected_items, filter_already_liked_items=False - ) + try: + ids, _ = model.recommend( + userid, user_items, items=selected_items, filter_already_liked_items=False + ) + except NotImplementedError: + return # ranked list should have same items self.assertEqual(set(ids), set(selected_items)) @@ -318,7 +328,10 @@ def test_rank_items_batch(self): model.fit(item_users, show_progress=False) selected_items = np.arange(10) * 3 - ids, _ = model.recommend(np.arange(50), user_items, items=selected_items) + try: + ids, _ = model.recommend(np.arange(50), user_items, items=selected_items) + except NotImplementedError: + return for userid in range(50): current_ids = ids[userid] From d22cc1e9f091b39f8f9cc85e478a448cc689dace Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 3 Jan 2022 12:56:57 -0800 Subject: [PATCH 8/9] Update documentation to reflect api changes --- README.md | 12 +++++------- docs/ann.rst | 18 +++++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 172c7dc1..2b79b864 100644 --- a/README.md +++ b/README.md @@ -53,12 +53,11 @@ import implicit # initialize a model model = implicit.als.AlternatingLeastSquares(factors=50) -# train the model on a sparse matrix of item/user/confidence weights -model.fit(item_user_data) +# train the model on a sparse matrix of user/item/confidence weights +model.fit(user_item_data) # recommend items for a user -user_items = item_user_data.T.tocsr() -recommendations = model.recommend(userid, user_items) +recommendations = model.recommend(userid, user_item_data) # find related items related = model.similar_items(itemid) @@ -88,9 +87,8 @@ There are also several other blog posts about using Implicit to build recommenda #### Requirements -This library requires SciPy version 0.16 or later. Running on OSX requires an OpenMP compiler, -which can be installed with homebrew: ```brew install gcc```. Running on Windows requires Python -3.5+. +This library requires SciPy version 0.16 or later and Python version 3.6 or later. +Running on OSX requires an OpenMP compiler, which can be installed with homebrew: ```brew install gcc```. GPU Support requires at least version 11 of the [NVidia CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). The build will use the ```nvcc``` compiler that is found on the path, but this can be overridden by setting the CUDAHOME environment variable diff --git a/docs/ann.rst b/docs/ann.rst index f751c7bb..bc5c4705 100644 --- a/docs/ann.rst +++ b/docs/ann.rst @@ -13,20 +13,20 @@ See `this post comparing the different ANN libraries `_ for more details. -NMSLibAlternatingLeastSquares ------------------------------ -.. autoclass:: implicit.approximate_als.NMSLibAlternatingLeastSquares +NMSLibModel +----------- +.. autoclass:: implicit.ann.nmslib.NMSLibModel :members: :show-inheritance: -AnnoyAlternatingLeastSquares ----------------------------- -.. autoclass:: implicit.approximate_als.AnnoyAlternatingLeastSquares +AnnoyModel +---------- +.. autoclass:: implicit.ann.annoy.AnnoyModel :members: :show-inheritance: -FaissAlternatingLeastSquares ------------------------------ -.. autoclass:: implicit.approximate_als.FaissAlternatingLeastSquares +FaissModel +---------- +.. autoclass:: implicit.ann.faiss.FaissModel :members: :show-inheritance: From 2ca09275aa4087bc5b300611fde93cf2311e3393 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 3 Jan 2022 13:06:22 -0800 Subject: [PATCH 9/9] Fix pylint error --- implicit/lmf.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/implicit/lmf.py b/implicit/lmf.py index a0b863fb..25e7c784 100644 --- a/implicit/lmf.py +++ b/implicit/lmf.py @@ -55,14 +55,13 @@ def LogisticMatrixFactorization( """ if use_gpu: raise NotImplementedError - else: - return implicit.cpu.lmf.LogisticMatrixFactorization( - factors, - learning_rate, - regularization, - dtype=dtype, - iterations=iterations, - neg_prop=neg_prop, - num_threads=num_threads, - random_state=random_state, - ) + return implicit.cpu.lmf.LogisticMatrixFactorization( + factors, + learning_rate, + regularization, + dtype=dtype, + iterations=iterations, + neg_prop=neg_prop, + num_threads=num_threads, + random_state=random_state, + )