streamlit_app.py
# Necessary Imports
import streamlit as st
# Standard libraries
import time
import os
import base64
# Data manipulation and analysis
import pandas as pd
import numpy as np
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import euclidean_distances
# Plotting and visualization
import plotly.graph_objects as go
import plotly.express as px
import imageio.v2 as imageio_v2
# External libraries
import requests
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline, logging
logging.set_verbosity_error()
# Serialization
import pickle
# OpenAI
import openai
# Set up OpenAI API key
openai.api_key = st.secrets["OPENAI_KEY"]
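# NOTE: this uses the legacy (pre-1.0) openai Python client interface (openai.Embedding.create),
# and the API key is read from Streamlit secrets, so OPENAI_KEY must be defined in the app's
# secrets (e.g. .streamlit/secrets.toml or the Streamlit Cloud settings).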
# Helper Functions
# Function to compute embeddings for a given text
def get_embeddings(text):
    response = openai.Embedding.create(
        input=f'a {text} person',
        model="text-embedding-ada-002"
    )
    embedding = response['data'][0]['embedding']
    return np.array(embedding)
def save_embeddings(embeddings_dict, filename):
    with open(filename + '.pickle', 'wb') as handle:
        pickle.dump(embeddings_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
    # print(f"Embeddings saved to {filename}.pickle")
def load_embeddings(filename):
    with open(filename + '.pickle', 'rb') as handle:
        embeddings_dict = pickle.load(handle)
    # print(f"Embeddings loaded from {filename}.pickle")
    return embeddings_dict
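# The two functions below build the optional rotating-GIF view of the 3D plot:
# capture_frames_from_plotly orbits the camera around the z-axis (one PNG per angle, written
# via fig.write_image, which requires a static-image export engine such as kaleido), and
# create_gif_from_frames stitches those PNGs into a looping GIF with imageio.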
def capture_frames_from_plotly(fig, angles, directory="frames", zoom_factor=1):
    if not os.path.exists(directory):
        os.makedirs(directory)
    frame_paths = []
    for angle in angles:
        fig.update_layout(scene_camera=dict(up=dict(x=0, y=0, z=1),  # Keep the "up" direction fixed to the z-axis
                                            center=dict(x=0, y=0, z=0),
                                            eye=dict(x=zoom_factor * np.cos(np.radians(angle)),
                                                     y=zoom_factor * np.sin(np.radians(angle)),
                                                     z=0)))  # z=0 gives the straight-on view in the z dimension
        frame_path = os.path.join(directory, f"frame_{angle}.png")
        fig.write_image(frame_path)
        frame_paths.append(frame_path)
    return frame_paths
def create_gif_from_frames(frame_paths, gif_path="rotating_graph.gif", duration=150):
    """
    Create a GIF from the provided frame paths.

    Args:
    - frame_paths (list): List of paths to individual frames/images.
    - gif_path (str): Path where the GIF will be saved.
    - duration (int): Duration each frame is displayed, in milliseconds. Default is 150 ms.

    Returns:
    - None
    """
    with imageio_v2.get_writer(gif_path, mode='I', duration=duration, loop=0) as writer:
        for frame_path in frame_paths:
            image = imageio_v2.imread(frame_path)
            writer.append_data(image)
def is_valid_word(word):
    # Check the word using the Datamuse API
    response = requests.get(f"https://api.datamuse.com/words?sp={word}&max=1")
    return len(response.json()) > 0
def load_sentiment_dict(path):
    """Load the sentiment dictionary from a given path."""
    with open(path, "rb") as f:
        return pickle.load(f)
def get_N_closest_words(descriptor, sentiment_dict, N=5):
    """Get the N words with the closest sentiment score within the same sentiment label."""
    label, score = sentiment_dict[descriptor]["label"], sentiment_dict[descriptor]["score"]
    # Filter the words with the same sentiment label and exclude the provided descriptor
    same_label_words = {word: info["score"] for word, info in sentiment_dict.items()
                        if info["label"] == label and word != descriptor}
    # Sort the words by the difference in sentiment scores
    sorted_words = sorted(same_label_words.keys(), key=lambda word: abs(same_label_words[word] - score))
    # Return the top N closest words
    return sorted_words[:N]
# Function to load the RoBERTa model for sentiment analysis
def load_sentiment_model():
    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
    model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
    sentiment_task = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
    return sentiment_task
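# predict_sentiment first looks the word up in the pre-computed sentiment dictionary;
# unknown words are validated against the Datamuse API and then scored on the fly with the
# RoBERTa pipeline above, using the same "a <word> person" framing as the embeddings.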
def predict_sentiment(descriptor, N=5):
    """
    Predict the sentiment of a given descriptor using the sentiment dictionary and the RoBERTa model.

    Args:
    - descriptor (str): The human descriptor for which sentiment is to be predicted.
    - N (int): Number of closest words to return.

    Returns:
    - str, float, str, str: Sentiment label, sentiment score, and two formatted result strings.
    """
    sentiment_dict = load_sentiment_dict('sentiment_dict.pkl')
    # If the descriptor is in the sentiment dictionary, use the pre-computed values
    if descriptor in sentiment_dict:
        label = sentiment_dict[descriptor]["label"]
        score = sentiment_dict[descriptor]["score"]
    else:
        # Ensure the descriptor is a valid word
        if not is_valid_word(descriptor):
            raise ValueError(f"'{descriptor}' is not a recognized word.")
        # Predict sentiment using the pre-trained RoBERTa model
        sentiment_task = load_sentiment_model()
        output = sentiment_task(f'a {descriptor} person')
        label = output[0]['label'].upper()
        score = output[0]['score']
        # Add the new descriptor to the sentiment dictionary
        sentiment_dict[descriptor] = {"label": label, "score": score}
    # Get N closest words with similar sentiment
    closest_words = get_N_closest_words(descriptor, sentiment_dict, N)
    result1 = f"{descriptor.capitalize()} is considered to have a {label} connotation, {round(score * 100, 1)}% match."
    result2 = f"{N} most connotatively similar descriptors: {', '.join(closest_words)}"
    return label, score, result1, result2
# Inner functions
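# Clustering is done with KMeans on the full-dimensional embeddings; PCA to three components
# is used only to project the points (and the fitted cluster centroids) into the interactive
# 3D plot, so cluster membership does not depend on the projection.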
def reduce_dimensions_to_3D(embeddings):
    pca = PCA(n_components=3)
    transformed_embeddings = pca.fit_transform(embeddings)
    return transformed_embeddings, pca
def interpret_clusters(embeddings_dict, centroid, n_closest):
    word_centroid_list = [(word, embedding) for word, embedding in embeddings_dict.items()]
    word_centroid_list.sort(key=lambda x: euclidean_distances([centroid], [x[1]]))
    closest_words = [x[0] for x in word_centroid_list[:n_closest]]
    return closest_words
def get_clusters(embeddings, n_clusters):
    kmeans = KMeans(n_clusters=n_clusters, n_init='auto', random_state=23)
    labels = kmeans.fit_predict(embeddings)
    return labels, kmeans
def get_cluster_of_descriptor(descriptor, embeddings_dict, n_clusters, kmeans, labels):
    descriptor_embedding = embeddings_dict[descriptor]
    embeddings = list(embeddings_dict.values())
    # Find the index of the descriptor embedding in the embeddings list
    descriptor_index = next(i for i, emb in enumerate(embeddings) if np.array_equal(emb, descriptor_embedding))
    # Return the cluster ID of the specified descriptor
    return labels[descriptor_index]
def get_centroid_of_cluster(embeddings_dict, cluster_id, kmeans, n_clusters=13):
    return kmeans.cluster_centers_[cluster_id]
def get_similar_descriptors(descriptor, embeddings_dict, N=5):
    """Get N descriptors that are most similar to the input descriptor."""
    # Check if the descriptor exists in the embeddings_dict
    if descriptor not in embeddings_dict:
        raise ValueError(f"{descriptor} not found in the embeddings dictionary.")
    # Calculate the cosine similarity between the descriptor and all other words
    descriptor_embedding = embeddings_dict[descriptor]
    cosine_similarities = {word: np.dot(descriptor_embedding, word_embedding) /
                                 (np.linalg.norm(descriptor_embedding) * np.linalg.norm(word_embedding))
                           for word, word_embedding in embeddings_dict.items()}
    # Sort the words by similarity
    sorted_words = sorted(cosine_similarities.keys(), key=lambda word: cosine_similarities[word], reverse=True)
    # Remove the input descriptor from the list
    sorted_words = [word for word in sorted_words if word != descriptor]
    # Return the top N similar words
    return sorted_words[:N]
def visualize_embeddings_complete(embeddings_dict, kmeans, labels, n_clusters, n_words, highlight_word=None, gif=False):
    """
    Visualize the embeddings in a 3D cluster space.

    Args:
    - embeddings_dict (dict): Dictionary containing words and their embeddings.
    - kmeans (KMeans): Fitted KMeans model.
    - labels (array-like): Cluster label for each word in embeddings_dict.
    - n_clusters (int): Number of clusters for KMeans.
    - n_words (int): Number of closest words to the centroid to display.
    - highlight_word (str, optional): Word to highlight in the visualization.
    - gif (bool): Whether to also capture frames and build a rotating GIF.

    Returns:
    - fig (plotly.graph_objects.Figure): The 3D visualization.
    - word_cluster_map (dict): Mapping of each word to its cluster label.
    - cluster_texts (list): Legend text describing each cluster.
    """
    if highlight_word:
        cluster_id = get_cluster_of_descriptor(highlight_word, embeddings_dict, n_clusters, kmeans, labels)
    # Begin main function
    transformed_embeddings, pca = reduce_dimensions_to_3D(list(embeddings_dict.values()))
    word_cluster_map = dict(zip(embeddings_dict.keys(), labels))
    df = pd.DataFrame(transformed_embeddings, columns=['x', 'y', 'z'])
    df['label'] = labels
    df['adjective'] = list(embeddings_dict.keys())
    color_list = list(px.colors.qualitative.Plotly)
    colors = [color_list[label % len(color_list)] for label in df['label']]
    sizes = [15 if word == highlight_word else 5 for word in df['adjective']]
    highlight_colors = ['red' if word == highlight_word else color for word, color in zip(df['adjective'], colors)]
    fig = go.Figure()
    scatter = go.Scatter3d(x=df['x'], y=df['y'], z=df['z'], mode='markers',
                           marker=dict(color=highlight_colors, size=sizes),
                           text=df['adjective'], hoverinfo='text', showlegend=False)
    fig.add_trace(scatter)
    transformed_centroids = pca.transform(kmeans.cluster_centers_)
    # Place the user's word first in the legend
    if highlight_word:
        fig.add_trace(go.Scatter3d(x=[df[df['adjective'] == highlight_word]['x'].values[0]],
                                   y=[df[df['adjective'] == highlight_word]['y'].values[0]],
                                   z=[df[df['adjective'] == highlight_word]['z'].values[0]],
                                   mode='markers',
                                   marker=dict(size=15, color='red', symbol='cross', line=dict(color='Black', width=1)),
                                   showlegend=True, name=f"Selected word: {highlight_word} (part of cluster {cluster_id + 1})"))
    cluster_texts = []
    for i, centroid in enumerate(kmeans.cluster_centers_):
        closest_words = interpret_clusters(embeddings_dict, centroid, 8)
        legend_text = f"Cluster {i + 1}: {', '.join(closest_words)}"
        cluster_texts.append(legend_text)
        fig.add_trace(go.Scatter3d(x=[transformed_centroids[i][0]], y=[transformed_centroids[i][1]], z=[transformed_centroids[i][2]],
                                   mode='markers',
                                   marker=dict(size=8, color=color_list[i % len(color_list)], symbol='diamond',
                                               line=dict(color='Black', width=1)),
                                   showlegend=True, name=legend_text))
    fig.update_layout(title_text=f"{n_clusters} Clusters of Human Descriptors (Interactive)",
                      title_x=0.21, title_y=0.92, title_font_size=24,
                      scene=dict(xaxis_title='PC1', yaxis_title='PC2', zaxis_title='PC3'),
                      autosize=False, width=1200, height=1000,
                      legend=dict(y=-0.1, x=0.5, xanchor='center', orientation='h'))
    if gif:
        angles = list(range(0, 360, 5))
        frame_paths = capture_frames_from_plotly(fig, angles)
        # Generate the gif
        create_gif_from_frames(frame_paths)
    return fig, word_cluster_map, cluster_texts
# Main Function for Analyze Descriptor
def analyze_descriptor_text(descriptor, embeddings_dict, kmeans, labels, n_clusters=13, n_words=15):
    """
    Analyze a given descriptor:
    - Identify the descriptors most similar to it.
    - Report its sentiment (connotation) and connotatively similar descriptors.
    - Identify the cluster it belongs to and the words that characterize that cluster.
    """
    results = []
    sentiment_dict = load_sentiment_dict('sentiment_dict.pkl')
    # embeddings_dict = load_embeddings('condon_cleaned')
    # if descriptor not in embeddings_dict:
    #     if not is_valid_word(descriptor):
    #         raise ValueError(f"'{descriptor}' is not a recognized word.")
    #     embeddings_dict[descriptor] = get_embeddings(descriptor)
    # Identifying the cluster of the descriptor
    cluster_id = get_cluster_of_descriptor(descriptor, embeddings_dict, n_clusters, kmeans, labels)
    centroid = get_centroid_of_cluster(embeddings_dict, cluster_id, kmeans, n_clusters=n_clusters)
    closest_words_to_centroid = interpret_clusters(embeddings_dict, centroid, n_words)
    # results.append(f"'{descriptor.capitalize()}' belongs to cluster {cluster_id+1} of {n_clusters}: {', '.join(closest_words_to_centroid)}")
    # Identifying descriptors most similar to the input word
    similar = get_similar_descriptors(descriptor, embeddings_dict, N=n_words)
    results.append(f"Descriptors most mathematically similar to '{descriptor}': {', '.join(similar)}")
    # Sentiment results
    __, __, r1, r2 = predict_sentiment(descriptor)
    results.append(r1)
    results.append(r2)
    results.append(f"'{descriptor.capitalize()}' belongs to cluster {cluster_id + 1} of {n_clusters}: {', '.join(closest_words_to_centroid)}")
    results.append("(Visualize your word amongst all of these clusters by clicking the button below!)")
    return results
def analyze_descriptor_visual(descriptor, embeddings_dict, kmeans, labels, n_clusters=13, n_words=15, gif=False):
    """
    Visualize the descriptor in a 3D cluster space.
    """
    # embeddings_dict = load_embeddings('condon_cleaned')
    if descriptor not in embeddings_dict:
        if not is_valid_word(descriptor):
            raise ValueError(f"'{descriptor}' is not a recognized word.")
        embeddings_dict[descriptor] = get_embeddings(descriptor)
    fig, word_cluster_map, cluster_text = visualize_embeddings_complete(embeddings_dict, kmeans, labels, n_clusters,
                                                                        n_words, highlight_word=descriptor, gif=gif)
    return fig
# Main Function for Descriptor Blender
def descriptor_blender(descriptors, N=10):
    """
    Combine a list of descriptors using the additive method and find the words in the
    embeddings dictionary that are closest to this combined representation, without clustering.

    Args:
    - descriptors (list of str): List of descriptors.
    - N (int): Number of closest descriptors to return. Default is 10.

    Returns:
    - list of str: The N descriptors closest to the combined representation of the input descriptors.
    """
    # Check and compute embeddings for missing descriptors
    embeddings_dict = load_embeddings('condon_cleaned')
    descriptors = [desc.strip().lower() for desc in descriptors]
    for descriptor in descriptors:
        if descriptor not in embeddings_dict:
            if is_valid_word(descriptor):
                embedding = get_embeddings(descriptor)
                embeddings_dict[descriptor] = embedding
            else:
                raise ValueError(f"{descriptor} not found in the embeddings dictionary, and it's not a valid descriptor.")
    # Compute the combined embedding using the additive method
    combined_embedding = sum([embeddings_dict[descriptor] for descriptor in descriptors])
    # Remove the original descriptors so they do not appear in the result
    words_to_compare = {word: embedding for word, embedding in embeddings_dict.items() if word not in descriptors}
    # Calculate enriched scores: distance to the combined embedding, weighted by the summed
    # distance to each individual input descriptor (lower score = better blend)
    distances_to_combined = {word: np.linalg.norm(embedding - combined_embedding) for word, embedding in words_to_compare.items()}
    enrich_scores = {word: distance * sum([np.linalg.norm(embeddings_dict[descriptor] - embeddings_dict[word]) for descriptor in descriptors])
                     for word, distance in distances_to_combined.items()}
    # Get the top N words based on enriched scores
    closest_words = sorted(enrich_scores.keys(), key=lambda word: enrich_scores[word])[:N]
    return closest_words
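# Hypothetical usage sketch (assumes 'condon_cleaned.pickle' sits alongside the app):
#   descriptor_blender(["funny", "awkward", "endearing"], N=10)
# returns the 10 words with the lowest enrichment score, i.e. words that are close to the
# summed embedding and, on average, close to each individual input word.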
# Streamlit App
st.title("Human Descriptor Analyzer & Blender")
st.markdown("Analyze descriptors from [Condon et al adjective dataset](https://pie-lab.github.io/tda/tda-difficulty.html) (a superset of the trait descriptive adjectives used to construct the Five Factor Model), and blend them to find interesting intersections in the high-dimensional embeddings space.")
st.header("Analyze Descriptor")
st.markdown("""
Input a word that describes human personality, and receive insights on:
- **Similar Descriptors**: Personality descriptors that have the most mathematically similar meaning to the input word in embedding-space.
- **Word Connotation**: Understand the sentiment of the word using the RoBERTa sentiment analysis model.
- **Similar Connotation**: Words that share a similar sentiment (not likely to have the same *meaning*).
- **Cluster Information**: Discover the cluster the word belongs to, and get to know other words that signify that cluster.
- **Interactive 3D Visualization** (Optional): Explore an interactive 3D space of all human descriptors in the Condon set, including the target word. Dive deep into the vast space of possible ways to describe a personality!
""")
descriptor = st.text_input("Enter any adjective that describes human personality—'they are a ________ person':")
descriptor = descriptor.replace(" ", "").replace(",", "").lower()  # normalize input (the blender below also lowercases its inputs)
n_clusters = st.number_input("How many adjective groupings (more groupings = more fine-grained categories):", min_value=1, value=23, step=1)
n_similar = st.number_input("How many related words to return:", min_value=1, value=15, step=1)
# Placeholders for analyze button, results, visualize button, and visualization
analyze_button_placeholder = st.empty()
analysis_results_placeholder = st.empty()
visualize_button_placeholder = st.empty()
visualization_placeholder = st.empty()
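# Streamlit re-runs this script from the top on every widget interaction, so the analysis
# results, the fitted KMeans model, and the figure are kept in st.session_state and
# re-rendered into the st.empty() placeholders above on each rerun.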
# When the "Analyze" button is pressed, only textual insights will be shown
if analyze_button_placeholder.button("Analyze"):
with st.spinner('Analyzing the descriptor...'):
embeddings_dict = load_embeddings('condon_cleaned')
if descriptor not in embeddings_dict:
if not is_valid_word(descriptor):
raise ValueError(f"'{descriptor}' is not a recognized word.")
embeddings_dict[descriptor] = get_embeddings(descriptor)
labels, kmeans = get_clusters(list(embeddings_dict.values()), n_clusters)
results = analyze_descriptor_text(descriptor, embeddings_dict, kmeans, labels, n_clusters, n_similar)
# Store kmeans and labels in st.session_state for later use
st.session_state['kmeans'] = kmeans
st.session_state['labels'] = labels
st.session_state['embeddings'] = embeddings_dict
st.session_state['analysis_results'] = results
# New button for visualization
if visualize_button_placeholder.button("Visualize your descriptor in full 3D space (interactive)"):
    with st.spinner('Generating the 3D clustering visualization...this will take about 10 seconds.'):
        # Retrieve kmeans and labels from st.session_state
        kmeans = st.session_state.get('kmeans', None)
        labels = st.session_state.get('labels', None)
        embeddings_dict = st.session_state.get('embeddings', None)
        # Ensure kmeans and labels are available before proceeding
        if kmeans is not None and labels is not None:
            plot = analyze_descriptor_visual(descriptor, embeddings_dict, kmeans, labels, n_clusters, n_similar)
            st.session_state['visualization'] = plot
        else:
            st.warning("Please analyze the descriptor first before visualizing.")
# Display stored results and visualization in their respective placeholders
if 'analysis_results' in st.session_state:
    all_results = "\n\n".join(st.session_state['analysis_results'])
    analysis_results_placeholder.markdown(all_results)
if 'visualization' in st.session_state:
    visualization_placeholder.plotly_chart(st.session_state['visualization'], use_container_width=True)
# Remaining Streamlit code related to the Descriptor Blender
st.header("Descriptor Blender")
st.write("This tool blends the meanings of your input words to find words that capture aspects of all of them. For example, inputting the words 'funny', 'awkward', and 'endearing' might produce a blended word like 'quirky'.")
words_to_blend = st.text_area("Enter words to blend (comma-separated):").split(',')
num_output_words = st.number_input("How many output words:", min_value=1, value=20, step=1)
if st.button("Blend"):
if len(words_to_blend) < 2:
st.warning("Please enter at least two descriptors to blend.")
intersection_words = descriptor_blender(words_to_blend, num_output_words)
st.write("Descriptors that blend your input (higher in list = better blend score):")
# Create two columns to display the words
col1, col2 = st.columns(2)
# Iterate through words in pairs
for i in range(0, len(intersection_words), 2):
word1 = intersection_words[i]
word2 = intersection_words[i + 1] if i + 1 < len(intersection_words) else ""
col1.write(word1)
col2.write(word2)