diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index a18325ae..de57c35a 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -3679,15 +3679,20 @@ def _combine_zeroshot_topics(self, cluster_indices = list(documents.Old_ID.values) cluster_names = list(merged_model.topic_labels_.values())[len(set(y)):] - cluster_topics = [cluster_names[topic + self._outliers] for topic in documents.Topic.values] + if self._outliers: + cluster_topics = [cluster_names[topic] if topic != -1 else "Outliers" for topic in documents.Topic.values] + else: + cluster_topics = [cluster_names[topic] for topic in documents.Topic.values] df = pd.DataFrame({ "Indices": zeroshot_indices + cluster_indices, "Label": zeroshot_topics + cluster_topics} ).sort_values("Indices") reverse_topic_labels = dict((v, k) for k, v in merged_model.topic_labels_.items()) + if self._outliers: + reverse_topic_labels["Outliers"] = -1 df.Label = df.Label.map(reverse_topic_labels) - merged_model.topics_ = df.Label.values + merged_model.topics_ = df.Label.astype(int).tolist() # Update the class internally has_outliers = bool(self._outliers)