Fix issue with zeroshot topic modeling missing outlier #1957
Thanks for picking up this bug so quickly (before I'd even noticed it!)
I've tested the changes locally using the example in the tutorial:
from datasets import load_dataset
from pandas import DataFrame
from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
# %%
# We select a subsample of 5000 abstracts from ArXiv
dataset = load_dataset("CShorten/ML-ArXiv-Papers")["train"]
docs = dataset["abstract"][:5_000]
# We define a number of topics that we know are in the documents
zeroshot_topic_list = ["Clustering", "Topic Modeling", "Large Language Models"]
# We fit our model using the zero-shot topics
# and we define a minimum similarity. For each document,
# if the similarity does not exceed that value, it will be used
# for clustering instead.
topic_model = BERTopic(
    embedding_model="thenlper/gte-small",
    min_topic_size=15,
    zeroshot_topic_list=zeroshot_topic_list,
    zeroshot_min_similarity=.85,
    representation_model=KeyBERTInspired()
)
topics, probabilities = topic_model.fit_transform(docs)
With this change I get the following error:
File /opt/homebrew/lib/python3.11/site-packages/bertopic/_bertopic.py:448, in BERTopic.fit_transform(self, documents, embeddings, images, y)
    446 # Combine Zero-shot with outliers
    447 if self._is_zeroshot() and len(documents) != len(doc_ids):
--> 448     predictions = self._combine_zeroshot_topics(documents, assigned_documents, assigned_embeddings)
    450 return predictions, self.probabilities_

File /opt/homebrew/lib/python3.11/site-packages/bertopic/_bertopic.py:3717, in BERTopic._combine_zeroshot_topics(self, documents, assigned_documents, embeddings)
   3714     new_mappings[topic] = topic - 1
   3716 # Re-map the topics including all representations (labels, sizes, embeddings, etc.)
-> 3717 self.topics_ = [new_mappings[topic] for topic in self.topics_]
   3718 self.topic_representations_ = {new_mappings[topic]: repr for topic, repr in self.topic_representations_.items()}
   3719 self.topic_labels_ = {new_mappings[topic]: label for topic, label in self.topic_labels_.items()}

File /opt/homebrew/lib/python3.11/site-packages/bertopic/_bertopic.py:3717, in <listcomp>(.0)
   3714     new_mappings[topic] = topic - 1
   3716 # Re-map the topics including all representations (labels, sizes, embeddings, etc.)
-> 3717 self.topics_ = [new_mappings[topic] for topic in self.topics_]
   3718 self.topic_representations_ = {new_mappings[topic]: repr for topic, repr in self.topic_representations_.items()}
   3719 self.topic_labels_ = {new_mappings[topic]: label for topic, label in self.topic_labels_.items()}

KeyError: nan
I think this is because reverse_topic_labels is applied before "Outliers" is added to it. If I move the assignment up by 2 lines then the tutorial produces the expected result.
bertopic/_bertopic.py (outdated)
df.Label = df.Label.map(reverse_topic_labels)
merged_model.topics_ = df.Label.values
if self._outliers:
    reverse_topic_labels["Outliers"] = -1
Suggested change:
- df.Label = df.Label.map(reverse_topic_labels)
- merged_model.topics_ = df.Label.values
- if self._outliers:
-     reverse_topic_labels["Outliers"] = -1
+ if self._outliers:
+     reverse_topic_labels["Outliers"] = -1
+ df.Label = df.Label.map(reverse_topic_labels)
+ merged_model.topics_ = df.Label.values
Otherwise "Outliers" isn't added to reverse_topic_labels until after it's used, meaning that outliers are assigned nan instead of -1 in merged_model.topics_ and line 3718 throws an error.
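To illustrate the ordering problem outside of BERTopic, here is a minimal sketch using plain pandas. The labels, the `new_mappings` dict, and the variable names are made up for the example and are not taken from the library's code; only `Series.map` with a dict is the real mechanism being shown.

```python
import pandas as pd

# Hypothetical label -> topic id mapping (illustrative values only).
reverse_topic_labels = {"Clustering": 0, "Topic Modeling": 1}
labels = pd.Series(["Clustering", "Outliers", "Topic Modeling"])

# Mapping before the "Outliers" key exists: the unmatched label becomes NaN.
mapped_too_early = labels.map(reverse_topic_labels)

# A later dict lookup on those values then fails on the NaN entry.
new_mappings = {-1: -1, 0: 0, 1: 1}
try:
    remapped = [new_mappings[topic] for topic in mapped_too_early]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# Adding the key first gives every label an integer id.
reverse_topic_labels["Outliers"] = -1
mapped_after_fix = labels.map(reverse_topic_labels)
```

With the key added before the `map` call, the second lookup pattern would find every topic id, which is the effect the suggested reordering is meant to achieve.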
Aside: I think that the way this is set up means that the values in merged_model.topics_ have data type Int64, because initially df.Label has integers and nan values in them, before the nans are replaced with -1. Not a big issue but it means that the data type for self.topics_ is different if you merge models.
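A rough way to see the kind of type difference described here, again with plain pandas rather than BERTopic itself: taking `.values` from a mapped column yields NumPy integers rather than Python `int`s. The data and names below are hypothetical, and the list conversion at the end is just one possible normalisation, not something the PR does.

```python
import pandas as pd

# Hypothetical mapping that already includes the outlier label.
reverse_topic_labels = {"Clustering": 0, "Outliers": -1}
labels = pd.Series(["Clustering", "Outliers", "Clustering"])

# .values returns a NumPy array, so the elements are NumPy integers.
topics_array = labels.map(reverse_topic_labels).values
element_is_python_int = isinstance(topics_array[0], int)

# One possible normalisation: convert to a list of plain Python ints.
topics_list = [int(topic) for topic in topics_array]
```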
Awesome! Thanks for looking into this, it is highly appreciated. I made both changes you suggested, which should hopefully resolve this issue. If these changes do fix the underlying issue, I will most likely create a new minor release (0.16.2), considering that zero-shot topic modeling in BERTopic is widely used.
Addresses #1946