Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Extend 'sample_listwise' to support more movielens Dataset features #735

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 15 additions & 22 deletions docs/examples/listwise_ranking.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,20 @@
"source": [
"# Listwise ranking\n",
"\n",
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/listwise_ranking\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/listwise_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/listwise_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/listwise_ranking.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/recommenders/examples/listwise_ranking\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/recommenders/blob/main/docs/examples/listwise_ranking.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/recommenders/blob/main/docs/examples/listwise_ranking.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/recommenders/docs/examples/listwise_ranking.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
Expand Down Expand Up @@ -309,18 +309,11 @@
" # Movie embeddings are a [batch_size, num_movies_in_list, embedding_dim]\n",
" # tensor.\n",
" movie_embeddings = self.movie_embeddings(features[\"movie_title\"])\n",
" \n",
" # We want to concatenate user embeddings with movie emebeddings to pass\n",
" # them into the ranking model. To do so, we need to reshape the user\n",
" # embeddings to match the shape of movie embeddings.\n",
" list_length = features[\"movie_title\"].shape[1]\n",
" user_embedding_repeated = tf.repeat(\n",
" tf.expand_dims(user_embeddings, 1), [list_length], axis=1)\n",
"\n",
" # Once reshaped, we concatenate and pass into the dense layers to generate\n",
" # predictions.\n",
" concatenated_embeddings = tf.concat(\n",
" [user_embedding_repeated, movie_embeddings], 2)\n",
" [user_embeddings, movie_embeddings], 2)\n",
" \n",
" return self.score_model(concatenated_embeddings)\n",
"\n",
Expand Down
74 changes: 34 additions & 40 deletions tensorflow_recommenders/examples/movielens.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,9 @@ def evaluate(user_model: tf.keras.Model,
}


def _create_feature_dict() -> Dict[Text, List[tf.Tensor]]:
def _create_feature_dict(features: List[Text]) -> Dict[Text, List[tf.Tensor]]:
"""Helper function for creating an empty feature dict for defaultdict."""
return {"movie_title": [], "user_rating": []}
return {key: [] for key in features}


def _sample_list(
Expand All @@ -108,22 +108,20 @@ def _sample_list(
random_state = np.random.RandomState()

sampled_indices = random_state.choice(
range(len(feature_lists["movie_title"])),
range(len(feature_lists["user_rating"])),
size=num_examples_per_list,
replace=False,
)
sampled_movie_titles = [
feature_lists["movie_title"][idx] for idx in sampled_indices
]
sampled_ratings = [
feature_lists["user_rating"][idx]
for idx in sampled_indices
]

return (
tf.stack(sampled_movie_titles, 0),
tf.stack(sampled_ratings, 0),
)
sampled_features = {}
for name, values in feature_lists.items():
sampled_features[name] = [
values[idx] for idx in sampled_indices
]

return {
name: tf.stack(values, 0)
for name, values in sampled_features.items()
}


def sample_listwise(
Expand All @@ -136,8 +134,8 @@ def sample_listwise(

Args:
rating_dataset:
The MovieLens ratings dataset loaded from TFDS with features
"movie_title", "user_id", and "user_rating".
The MovieLens ratings dataset loaded from TFDS. Feature must be provided
in the dataset. The dataset must contain the "user_rating" feature.
num_list_per_user:
An integer representing the number of lists that should be sampled for
each user in the training dataset.
Expand All @@ -150,28 +148,24 @@ def sample_listwise(
Returns:
A tf.data.Dataset containing list examples.

Each example contains three keys: "user_id", "movie_title", and
"user_rating". "user_id" maps to a string tensor that represents the user
id for the example. "movie_title" maps to a tensor of shape
[sum(num_example_per_list)] with dtype tf.string. It represents the list
of candidate movie ids. "user_rating" maps to a tensor of shape
[sum(num_example_per_list)] with dtype tf.float32. It represents the
rating of each movie in the candidate list.
Each example contains multiple keys. "user_id" maps to a string
tensor that represents the user id for the example. "movie_title" maps
to a tensor of shape [sum(num_example_per_list)] with dtype tf.string.
It represents the list of candidate movie ids. "user_rating" maps to
a tensor of shape [sum(num_example_per_list)] with dtype tf.float32.
It represents the rating of each movie in the candidate list.
"""
random_state = np.random.RandomState(seed)

example_lists_by_user = collections.defaultdict(_create_feature_dict)
features = rating_dataset.take(1).get_single_element().keys()
example_lists_by_user = collections.defaultdict(lambda: _create_feature_dict(features))

movie_title_vocab = set()
for example in rating_dataset:
user_id = example["user_id"].numpy()
example_lists_by_user[user_id]["movie_title"].append(
example["movie_title"])
example_lists_by_user[user_id]["user_rating"].append(
example["user_rating"])
movie_title_vocab.add(example["movie_title"].numpy())
user_id = example.get('user_id').numpy()
for key, value in example.items():
example_lists_by_user[user_id][key].append(value.numpy())

tensor_slices = {"user_id": [], "movie_title": [], "user_rating": []}
tensor_slices = {key: [] for key in features}

for user_id, feature_lists in example_lists_by_user.items():
for _ in range(num_list_per_user):
Expand All @@ -180,13 +174,13 @@ def sample_listwise(
if len(feature_lists["movie_title"]) < num_examples_per_list:
continue

sampled_movie_titles, sampled_ratings = _sample_list(
feature_lists,
num_examples_per_list,
random_state=random_state,
sampled_features = _sample_list(
feature_lists,
num_examples_per_list,
random_state=random_state,
)
tensor_slices["user_id"].append(user_id)
tensor_slices["movie_title"].append(sampled_movie_titles)
tensor_slices["user_rating"].append(sampled_ratings)

for feature, samples in sampled_features.items():
tensor_slices[feature].append(samples)

return tf.data.Dataset.from_tensor_slices(tensor_slices)