Recommendation_using_embeddings.py
# imports
import pandas as pd
import pickle
import openai
from openai.embeddings_utils import (
    get_embedding,
    distances_from_embeddings,
    tsne_components_from_embeddings,
    chart_from_components,
    indices_of_nearest_neighbors_from_distances,
)
# constants
EMBEDDING_MODEL = "text-embedding-ada-002"
def open_file(filepath):
    with open(filepath, 'r', encoding='utf-8') as infile:
        return infile.read()
openai.api_key = open_file('openaiapikey.txt')
# load data (full dataset available at http://groups.di.unipi.it/~gulli/AG_corpus_of_news_articles.html)
dataset_path = "data/AG_news_samples.csv"
df = pd.read_csv(dataset_path)
# print dataframe
n_examples = 5
print(df.head(n_examples))
# print the title, description, and label of each example
for idx, row in df.head(n_examples).iterrows():
    print("")
    print(f"Title: {row['title']}")
    print(f"Description: {row['description']}")
    print(f"Label: {row['label']}")
# establish a cache of embeddings to avoid recomputing
# cache is a dict of tuples (text, model) -> embedding, saved as a pickle file
# set path to embedding cache
embedding_cache_path = "data/recommendations_embeddings_cache.pkl"
# load the cache if it exists, and save a copy to disk
try:
    embedding_cache = pd.read_pickle(embedding_cache_path)
except FileNotFoundError:
    embedding_cache = {}
with open(embedding_cache_path, "wb") as embedding_cache_file:
    pickle.dump(embedding_cache, embedding_cache_file)
# define a function to retrieve embeddings from the cache if present, and otherwise request via the API
def embedding_from_string(
    string: str,
    model: str = EMBEDDING_MODEL,
    embedding_cache=embedding_cache,
) -> list:
    """Return embedding of given string, using a cache to avoid recomputing."""
    if (string, model) not in embedding_cache.keys():
        embedding_cache[(string, model)] = get_embedding(string, model)
        with open(embedding_cache_path, "wb") as embedding_cache_file:
            pickle.dump(embedding_cache, embedding_cache_file)
    return embedding_cache[(string, model)]
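# For reference only: a minimal sketch of the raw Embeddings API request that the
# get_embedding helper wraps, assuming the pre-1.0 openai Python library used by
# openai.embeddings_utils. The helper above is what this script actually calls.
def embedding_via_api_sketch(text: str, model: str = EMBEDDING_MODEL) -> list:
    """Request an embedding directly from the API (illustrative sketch, not called below)."""
    response = openai.Embedding.create(input=[text.replace("\n", " ")], model=model)
    return response["data"][0]["embedding"]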
# as an example, take the first description from the dataset
example_string = df["description"].values[0]
print(f"\nExample string: {example_string}")
# print the first 10 dimensions of the embedding
example_embedding = embedding_from_string(example_string)
print(f"\nExample embedding: {example_embedding[:10]}...")
def print_recommendations_from_strings(
    strings: list[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1,
    model=EMBEDDING_MODEL,
) -> list[int]:
    """Print out the k nearest neighbors of a given string."""
    # get embeddings for all strings
    embeddings = [embedding_from_string(string, model=model) for string in strings]
    # get the embedding of the source string
    query_embedding = embeddings[index_of_source_string]
    # get distances between the source embedding and other embeddings (function from embeddings_utils.py)
    distances = distances_from_embeddings(query_embedding, embeddings, distance_metric="cosine")
    # get indices of nearest neighbors (function from embeddings_utils.py)
    indices_of_nearest_neighbors = indices_of_nearest_neighbors_from_distances(distances)
    # print out source string
    query_string = strings[index_of_source_string]
    print(f"Source string: {query_string}")
    # print out its k nearest neighbors
    k_counter = 0
    for i in indices_of_nearest_neighbors:
        # skip any strings that are identical matches to the starting string
        if query_string == strings[i]:
            continue
        # stop after printing out k articles
        if k_counter >= k_nearest_neighbors:
            break
        k_counter += 1
        # print out the similar strings and their distances
        print(
            f"""
        --- Recommendation #{k_counter} (nearest neighbor {k_counter} of {k_nearest_neighbors}) ---
        String: {strings[i]}
        Distance: {distances[i]:0.3f}"""
        )
    return indices_of_nearest_neighbors
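# For illustration only: roughly what the two embeddings_utils helpers used above
# compute. A minimal numpy sketch (assumption: plain cosine distance plus argsort);
# the library functions remain the ones this script actually calls.
import numpy as np  # imported here only to keep the sketch self-contained

def cosine_distances_sketch(query_embedding: list, embeddings: list) -> list:
    """Cosine distance (1 - cosine similarity) from the query to each embedding."""
    q = np.array(query_embedding)
    distances = []
    for emb in embeddings:
        e = np.array(emb)
        distances.append(1.0 - float(np.dot(q, e) / (np.linalg.norm(q) * np.linalg.norm(e))))
    return distances

def nearest_indices_sketch(distances: list) -> list:
    """Indices sorted from smallest to largest distance (nearest first)."""
    return list(np.argsort(distances))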
article_descriptions = df["description"].tolist()
tony_blair_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=0,  # let's look at articles similar to the first one about Tony Blair
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)
chipset_security_articles = print_recommendations_from_strings(
    strings=article_descriptions,  # let's base similarity off of the article description
    index_of_source_string=1,  # let's look at articles similar to the second one about a more secure chipset
    k_nearest_neighbors=5,  # let's look at the 5 most similar articles
)
# get embeddings for all article descriptions
embeddings = [embedding_from_string(string) for string in article_descriptions]
# compress the 1536-dimensional embeddings into 2 dimensions using t-SNE
tsne_components = tsne_components_from_embeddings(embeddings)
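# For illustration only: tsne_components_from_embeddings is roughly equivalent to
# running scikit-learn's t-SNE on the embedding matrix. A minimal sketch with mostly
# default parameters (an assumption; the helper above is what the script actually uses).
def tsne_components_sketch(embeddings: list):
    """Reduce a list of embeddings to 2D components with scikit-learn's t-SNE."""
    from sklearn.manifold import TSNE
    return TSNE(n_components=2, random_state=42).fit_transform(np.array(embeddings))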
# get the article labels for coloring the chart
labels = df["label"].tolist()
chart_from_components(
    components=tsne_components,
    labels=labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="t-SNE components of article descriptions",
)
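# For illustration only: a minimal plotly express sketch of the kind of scatter plot
# that chart_from_components draws (an assumption about the helper's internals;
# the helper above is what the script actually uses).
def scatter_sketch(components, labels, strings, title):
    """Plot 2D components colored by label, showing each string on hover."""
    import plotly.express as px
    fig = px.scatter(
        x=components[:, 0],
        y=components[:, 1],
        color=labels,
        hover_name=strings,
        title=title,
        width=600,
        height=500,
    )
    fig.show()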
# create labels for the recommended articles
def nearest_neighbor_labels(
    list_of_indices: list[int],
    k_nearest_neighbors: int = 5,
) -> list[str]:
    """Return a list of labels to color the k nearest neighbors."""
    labels = ["Other" for _ in list_of_indices]
    source_index = list_of_indices[0]
    labels[source_index] = "Source"
    for i in range(k_nearest_neighbors):
        nearest_neighbor_index = list_of_indices[i + 1]
        labels[nearest_neighbor_index] = f"Nearest neighbor (top {k_nearest_neighbors})"
    return labels
tony_blair_labels = nearest_neighbor_labels(tony_blair_articles, k_nearest_neighbors=5)
chipset_security_labels = nearest_neighbor_labels(chipset_security_articles, k_nearest_neighbors=5)
# a 2D chart of nearest neighbors of the Tony Blair article
chart_from_components(
    components=tsne_components,
    labels=tony_blair_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the Tony Blair article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)
# a 2D chart of nearest neighbors of the chipset security article
chart_from_components(
    components=tsne_components,
    labels=chipset_security_labels,
    strings=article_descriptions,
    width=600,
    height=500,
    title="Nearest neighbors of the chipset security article",
    category_orders={"label": ["Other", "Nearest neighbor (top 5)", "Source"]},
)