Skip to content

Commit

Permalink
In multi-aspect context, allow Main model to be chained [Issue #1846]
Browse files Browse the repository at this point in the history
  • Loading branch information
David DiCato committed Jun 5, 2024
1 parent 72cc3e0 commit b0a9602
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4056,35 +4056,58 @@ def _extract_words_per_topic(self,
scores = np.take_along_axis(scores, sorted_indices, axis=1)

# Get top 30 words per topic based on c-TF-IDF score
topics = {label: [(words[word_index], score)
if word_index is not None and score > 0
else ("", 0.00001)
for word_index, score in zip(indices[index][::-1], scores[index][::-1])
]
for index, label in enumerate(labels)}
base_topics = {label: [(words[word_index], score)
if word_index is not None and score > 0
else ("", 0.00001)
for word_index, score in zip(indices[index][::-1], scores[index][::-1])
]
for index, label in enumerate(labels)}

# Fine-tune the topic representations
if isinstance(self.representation_model, list):
topics = base_topics.copy()
if not self.representation_model:
# Default Main behavior: c_tf_idf + top_n_words
topics = {label: values[:self.top_n_words] for label, values in topics.items()}
elif isinstance(self.representation_model, list):
for tuner in self.representation_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, BaseRepresentation):
topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, dict):
if self.representation_model.get("Main"):
topics = self.representation_model["Main"].extract_topics(self, documents, c_tf_idf, topics)
topics = {label: values[:self.top_n_words] for label, values in topics.items()}
main_model = self.representation_model["Main"]
if isinstance(main_model, BaseRepresentation):
topics = main_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(main_model, list):
for tuner in main_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
else:
raise TypeError(
f"unsupported type {type(main_model).__name__} for representation_model['Main']")
else:
# Default Main behavior: c_tf_idf + top_n_words
topics = {label: values[:self.top_n_words] for label, values in topics.items()}
else:
raise TypeError(
f"unsupported type {type(self.representation_model).__name__} for representation_model")

# Extract additional topic aspects
if calculate_aspects and isinstance(self.representation_model, dict):
for aspect, aspect_model in self.representation_model.items():
aspects = topics.copy()
if aspect != "Main":
aspects = base_topics.copy()
if not aspect_model:
# Default non-Main behavior: c_tf_idf
pass
if isinstance(aspect_model, list):
for tuner in aspect_model:
aspects = tuner.extract_topics(self, documents, c_tf_idf, aspects)
self.topic_aspects_[aspect] = aspects
elif isinstance(aspect_model, BaseRepresentation):
self.topic_aspects_[aspect] = aspect_model.extract_topics(self, documents, c_tf_idf, aspects)
aspects = aspect_model.extract_topics(self, documents, c_tf_idf, aspects)
else:
raise TypeError(
f"unsupported type {type(aspect_model).__name__} for representation_model[{repr(aspect)}]")
self.topic_aspects_[aspect] = aspects

return topics

Expand Down

0 comments on commit b0a9602

Please sign in to comment.