From b0a9602929162f22afff708be78691936fc334cc Mon Sep 17 00:00:00 2001 From: David DiCato Date: Tue, 21 May 2024 19:36:48 -0600 Subject: [PATCH] In multi-aspect context, allow Main model to be chained [Issue #1846] --- bertopic/_bertopic.py | 47 ++++++++++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index de57c35a..73ffe42d 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -4056,35 +4056,58 @@ def _extract_words_per_topic(self, scores = np.take_along_axis(scores, sorted_indices, axis=1) # Get top 30 words per topic based on c-TF-IDF score - topics = {label: [(words[word_index], score) - if word_index is not None and score > 0 - else ("", 0.00001) - for word_index, score in zip(indices[index][::-1], scores[index][::-1]) - ] - for index, label in enumerate(labels)} + base_topics = {label: [(words[word_index], score) + if word_index is not None and score > 0 + else ("", 0.00001) + for word_index, score in zip(indices[index][::-1], scores[index][::-1]) + ] + for index, label in enumerate(labels)} # Fine-tune the topic representations - if isinstance(self.representation_model, list): + topics = base_topics.copy() + if not self.representation_model: + # Default Main behavior: c_tf_idf + top_n_words + topics = {label: values[:self.top_n_words] for label, values in topics.items()} + elif isinstance(self.representation_model, list): for tuner in self.representation_model: topics = tuner.extract_topics(self, documents, c_tf_idf, topics) elif isinstance(self.representation_model, BaseRepresentation): topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics) elif isinstance(self.representation_model, dict): if self.representation_model.get("Main"): - topics = self.representation_model["Main"].extract_topics(self, documents, c_tf_idf, topics) - topics = {label: values[:self.top_n_words] for label, values in topics.items()} + main_model = self.representation_model["Main"] + if isinstance(main_model, BaseRepresentation): + topics = main_model.extract_topics(self, documents, c_tf_idf, topics) + elif isinstance(main_model, list): + for tuner in main_model: + topics = tuner.extract_topics(self, documents, c_tf_idf, topics) + else: + raise TypeError( + f"unsupported type {type(main_model).__name__} for representation_model['Main']") + else: + # Default Main behavior: c_tf_idf + top_n_words + topics = {label: values[:self.top_n_words] for label, values in topics.items()} + else: + raise TypeError( + f"unsupported type {type(self.representation_model).__name__} for representation_model") # Extract additional topic aspects if calculate_aspects and isinstance(self.representation_model, dict): for aspect, aspect_model in self.representation_model.items(): - aspects = topics.copy() if aspect != "Main": + aspects = base_topics.copy() + if not aspect_model: + # Default non-Main behavior: c_tf_idf + pass if isinstance(aspect_model, list): for tuner in aspect_model: aspects = tuner.extract_topics(self, documents, c_tf_idf, aspects) - self.topic_aspects_[aspect] = aspects elif isinstance(aspect_model, BaseRepresentation): - self.topic_aspects_[aspect] = aspect_model.extract_topics(self, documents, c_tf_idf, aspects) + aspects = aspect_model.extract_topics(self, documents, c_tf_idf, aspects) + else: + raise TypeError( + f"unsupported type {type(aspect_model).__name__} for representation_model[{repr(aspect)}]") + self.topic_aspects_[aspect] = aspects return topics