diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..c7c1df87 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,32 @@ +name: Lint + +on: + push: + branches: + - master + - dev + pull_request: + branches: + - master + - dev + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: 3.12 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[test]" + - name: Ruff Format Check + run: ruff format --check . + id: format + - name: Ruff Lint Check + run: ruff check --output-format=github . + # Still run if format check fails + if: success() || steps.format.conclusion == 'failure' \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index f4224779..6d754ca9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,6 +40,17 @@ If you are using conda, you can approach it as follows: ❗Note: Unit testing the package can take quite some time since it needs to run several variants of the BERTopic pipeline. +## 🧹 Linting and Formatting + +We use [Ruff](https://docs.astral.sh/ruff/) to ensure code is uniformly formatted and to avoid common mistakes and bad practices. + +* To automatically re-format code, run `make format` +* To check for linting issues, run `make lint` - some issues may be fixed automatically, while others have to be resolved manually + +When a pull request is made, the CI will automatically check for linting and formatting issues. However, it will not apply any fixes, so it is easiest to run these checks locally. + +If you believe an error is incorrectly flagged, use a [`# noqa` comment](https://docs.astral.sh/ruff/linter/#error-suppression) to suppress it, but this is discouraged unless strictly necessary. + ## 🤓 Collaborative Efforts When you run into any issue with the above or need help to start with a pull request, feel free to reach out in the issues! As with all repositories, this one has its particularities as a result of the maintainer's view. Each repository is quite different and so will their processes. diff --git a/Makefile b/Makefile index 7647f782..bc7f2ba1 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,12 @@ test: coverage: pytest --cov +format: + ruff format + +lint: + ruff check --fix + install: python -m pip install -e .
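For context on the `# noqa` suppression mentioned in the CONTRIBUTING.md changes above, the sketch below shows what such a comment looks like in practice. It is illustrative only and not part of this diff; it uses rule E402 (module-level import not at the top of the file), the same rule suppressed file-wide by the `# ruff: noqa: E402` directive added at the top of `bertopic/_bertopic.py` further down.

```python
# Illustrative sketch only - not part of this PR.
import warnings

# A statement before an import means any later module-level import
# violates E402 ("module level import not at top of file").
warnings.filterwarnings("ignore", category=FutureWarning)

import re  # noqa: E402  # suppress only E402, and only on this line

print(re.__name__)
```

A bare `# noqa` with no rule code would silence every rule on that line, which is why naming the specific rule is preferred. The Makefile targets added above wrap the two Ruff commands: `make format` runs `ruff format` (rewrites files in place) and `make lint` runs `ruff check --fix` (reports issues and applies the fixes Ruff considers safe).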
diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index eba54162..4a94d139 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -1,11 +1,13 @@ +# ruff: noqa: E402 import yaml import warnings + warnings.filterwarnings("ignore", category=FutureWarning) warnings.filterwarnings("ignore", category=UserWarning) try: yaml._warnings_enabled["YAMLLoadWarning"] = False -except (KeyError, AttributeError, TypeError) as e: +except (KeyError, AttributeError, TypeError): pass import re @@ -27,6 +29,7 @@ # Typing import sys + if sys.version_info >= (3, 8): from typing import Literal else: @@ -53,9 +56,13 @@ from bertopic.dimensionality import BaseDimensionalityReduction from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan from bertopic._utils import ( - MyLogger, check_documents_type, check_embeddings_shape, - check_is_fitted, validate_distance_matrix, select_topic_representation, - get_unique_distances + MyLogger, + check_documents_type, + check_embeddings_shape, + check_is_fitted, + validate_distance_matrix, + select_topic_representation, + get_unique_distances, ) import bertopic._save_utils as save_utils @@ -96,7 +103,6 @@ class BERTopic: representative_docs_ (Mapping[int, str]) : The representative documents for each topic. Examples: - ```python from bertopic import BERTopic from sklearn.datasets import fetch_20newsgroups @@ -123,26 +129,28 @@ class BERTopic: try out BERTopic several times until you find the topics that suit you best. """ - def __init__(self, - language: str = "english", - top_n_words: int = 10, - n_gram_range: Tuple[int, int] = (1, 1), - min_topic_size: int = 10, - nr_topics: Union[int, str] = None, - low_memory: bool = False, - calculate_probabilities: bool = False, - seed_topic_list: List[List[str]] = None, - zeroshot_topic_list: List[str] = None, - zeroshot_min_similarity: float = .7, - embedding_model=None, - umap_model: UMAP = None, - hdbscan_model: hdbscan.HDBSCAN = None, - vectorizer_model: CountVectorizer = None, - ctfidf_model: TfidfTransformer = None, - representation_model: BaseRepresentation = None, - verbose: bool = False, - ): - """BERTopic initialization + + def __init__( + self, + language: str = "english", + top_n_words: int = 10, + n_gram_range: Tuple[int, int] = (1, 1), + min_topic_size: int = 10, + nr_topics: Union[int, str] = None, + low_memory: bool = False, + calculate_probabilities: bool = False, + seed_topic_list: List[List[str]] = None, + zeroshot_topic_list: List[str] = None, + zeroshot_min_similarity: float = 0.7, + embedding_model=None, + umap_model: UMAP = None, + hdbscan_model: hdbscan.HDBSCAN = None, + vectorizer_model: CountVectorizer = None, + ctfidf_model: TfidfTransformer = None, + representation_model: BaseRepresentation = None, + verbose: bool = False, + ): + """BERTopic initialization. Arguments: language: The main language used in your documents. The default sentence-transformers @@ -160,7 +168,7 @@ def __init__(self, NOTE: This param will not be used if you pass in your own CountVectorizer. min_topic_size: The minimum size of the topic. Increasing this value will lead - to a lower number of clusters/topics and vice versa. + to a lower number of clusters/topics and vice versa. It is the same parameter as `min_cluster_size` in HDBSCAN. NOTE: This param will not be used if you are using `hdbscan_model`. 
nr_topics: Specifying the number of topics will reduce the initial @@ -212,8 +220,10 @@ def __init__(self, """ # Topic-based parameters if top_n_words > 100: - logger.warning("Note that extracting more than 100 words from a sparse " - "can slow down computation quite a bit.") + logger.warning( + "Note that extracting more than 100 words from a sparse " + "can slow down computation quite a bit." + ) self.top_n_words = top_n_words self.min_topic_size = min_topic_size @@ -231,25 +241,31 @@ def __init__(self, # Vectorizer self.n_gram_range = n_gram_range - self.vectorizer_model = vectorizer_model or CountVectorizer(ngram_range=self.n_gram_range) + self.vectorizer_model = vectorizer_model or CountVectorizer( + ngram_range=self.n_gram_range + ) self.ctfidf_model = ctfidf_model or ClassTfidfTransformer() # Representation model self.representation_model = representation_model # UMAP or another algorithm that has .fit and .transform functions - self.umap_model = umap_model or UMAP(n_neighbors=15, - n_components=5, - min_dist=0.0, - metric='cosine', - low_memory=self.low_memory) + self.umap_model = umap_model or UMAP( + n_neighbors=15, + n_components=5, + min_dist=0.0, + metric="cosine", + low_memory=self.low_memory, + ) # HDBSCAN or another clustering algorithm that has .fit and .predict functions and # the .labels_ variable to extract the labels - self.hdbscan_model = hdbscan_model or hdbscan.HDBSCAN(min_cluster_size=self.min_topic_size, - metric='euclidean', - cluster_selection_method='eom', - prediction_data=True) + self.hdbscan_model = hdbscan_model or hdbscan.HDBSCAN( + min_cluster_size=self.min_topic_size, + metric="euclidean", + cluster_selection_method="eom", + prediction_data=True, + ) # Public attributes self.topics_ = None @@ -274,12 +290,14 @@ def __init__(self, else: logger.set_level("WARNING") - def fit(self, - documents: List[str], - embeddings: np.ndarray = None, - images: List[str] = None, - y: Union[List[int], np.ndarray] = None): - """ Fit the models (Bert, UMAP, and, HDBSCAN) on a collection of documents and generate topics + def fit( + self, + documents: List[str], + embeddings: np.ndarray = None, + images: List[str] = None, + y: Union[List[int], np.ndarray] = None, + ): + """Fit the models (Bert, UMAP, and, HDBSCAN) on a collection of documents and generate topics. Arguments: documents: A list of documents to fit on @@ -290,7 +308,6 @@ def fit(self, specific instance is specified. Examples: - ```python from bertopic import BERTopic from sklearn.datasets import fetch_20newsgroups @@ -315,16 +332,19 @@ def fit(self, topic_model = BERTopic().fit(docs, embeddings) ``` """ - self.fit_transform(documents=documents, embeddings=embeddings, y=y, images=images) + self.fit_transform( + documents=documents, embeddings=embeddings, y=y, images=images + ) return self - def fit_transform(self, - documents: List[str], - embeddings: np.ndarray = None, - images: List[str] = None, - y: Union[List[int], np.ndarray] = None) -> Tuple[List[int], - Union[np.ndarray, None]]: - """ Fit the models on a collection of documents, generate topics, + def fit_transform( + self, + documents: List[str], + embeddings: np.ndarray = None, + images: List[str] = None, + y: Union[List[int], np.ndarray] = None, + ) -> Tuple[List[int], Union[np.ndarray, None]]: + """Fit the models on a collection of documents, generate topics, and return the probabilities and topic per document. Arguments: @@ -344,7 +364,6 @@ def fit_transform(self, computation and may increase memory usage. 
Examples: - ```python from bertopic import BERTopic from sklearn.datasets import fetch_20newsgroups @@ -376,26 +395,28 @@ def fit_transform(self, check_embeddings_shape(embeddings, documents) doc_ids = range(len(documents)) if documents is not None else range(len(images)) - documents = pd.DataFrame({"Document": documents, - "ID": doc_ids, - "Topic": None, - "Image": images}) + documents = pd.DataFrame( + {"Document": documents, "ID": doc_ids, "Topic": None, "Image": images} + ) # Extract embeddings if embeddings is None: logger.info("Embedding - Transforming documents to embeddings.") - self.embedding_model = select_backend(self.embedding_model, - language=self.language, - verbose=self.verbose) - embeddings = self._extract_embeddings(documents.Document.values.tolist(), - images=images, - method="document", - verbose=self.verbose) + self.embedding_model = select_backend( + self.embedding_model, language=self.language, verbose=self.verbose + ) + embeddings = self._extract_embeddings( + documents.Document.values.tolist(), + images=images, + method="document", + verbose=self.verbose, + ) logger.info("Embedding - Completed \u2713") else: if self.embedding_model is not None: - self.embedding_model = select_backend(self.embedding_model, - language=self.language) + self.embedding_model = select_backend( + self.embedding_model, language=self.language + ) # Guided Topic Modeling if self.seed_topic_list is not None and self.embedding_model is not None: @@ -403,15 +424,21 @@ def fit_transform(self, # Zero-shot Topic Modeling if self._is_zeroshot(): - documents, embeddings, assigned_documents, assigned_embeddings = self._zeroshot_topic_modeling(documents, embeddings) + documents, embeddings, assigned_documents, assigned_embeddings = ( + self._zeroshot_topic_modeling(documents, embeddings) + ) if documents is None: - return self._combine_zeroshot_topics(documents, assigned_documents, assigned_embeddings) + return self._combine_zeroshot_topics( + documents, assigned_documents, assigned_embeddings + ) # Reduce dimensionality umap_embeddings = self._reduce_dimensionality(embeddings, y) # Cluster reduced embeddings - documents, probabilities = self._cluster_embeddings(umap_embeddings, documents, y=y) + documents, probabilities = self._cluster_embeddings( + umap_embeddings, documents, y=y + ) # Sort and Map Topic IDs by their frequency if not self.nr_topics: @@ -443,20 +470,26 @@ def fit_transform(self, self._save_representative_docs(documents) # Resulting output - self.probabilities_ = self._map_probabilities(probabilities, original_topics=True) + self.probabilities_ = self._map_probabilities( + probabilities, original_topics=True + ) predictions = documents.Topic.to_list() # Combine Zero-shot with outliers if self._is_zeroshot() and len(documents) != len(doc_ids): - predictions = self._combine_zeroshot_topics(documents, assigned_documents, assigned_embeddings) + predictions = self._combine_zeroshot_topics( + documents, assigned_documents, assigned_embeddings + ) return predictions, self.probabilities_ - def transform(self, - documents: Union[str, List[str]], - embeddings: np.ndarray = None, - images: List[str] = None) -> Tuple[List[int], np.ndarray]: - """ After having fit a model, use transform to predict new instances + def transform( + self, + documents: Union[str, List[str]], + embeddings: np.ndarray = None, + images: List[str] = None, + ) -> Tuple[List[int], np.ndarray]: + """After having fit a model, use transform to predict new instances. 
Arguments: documents: A single document or a list of documents to predict on @@ -472,7 +505,6 @@ def transform(self, decrease memory usage. Examples: - ```python from bertopic import BERTopic from sklearn.datasets import fetch_20newsgroups @@ -506,20 +538,23 @@ def transform(self, documents = [documents] if embeddings is None: - embeddings = self._extract_embeddings(documents, - images=images, - method="document", - verbose=self.verbose) + embeddings = self._extract_embeddings( + documents, images=images, method="document", verbose=self.verbose + ) # Check if an embedding model was found if embeddings is None: - raise ValueError("No embedding model was found to embed the documents." - "Make sure when loading in the model using BERTopic.load()" - "to also specify the embedding model.") + raise ValueError( + "No embedding model was found to embed the documents." + "Make sure when loading in the model using BERTopic.load()" + "to also specify the embedding model." + ) # Transform without hdbscan_model and umap_model using only cosine similarity elif type(self.hdbscan_model) == BaseCluster: - logger.info("Predicting topic assignments through cosine similarity of topic and document embeddings.") + logger.info( + "Predicting topic assignments through cosine similarity of topic and document embeddings." + ) sim_matrix = cosine_similarity(embeddings, np.array(self.topic_embeddings_)) predictions = np.argmax(sim_matrix, axis=1) - self._outliers @@ -537,12 +572,18 @@ def transform(self, # Extract predictions and probabilities if it is a HDBSCAN-like model logger.info("Clustering - Approximating new points with `hdbscan_model`") if is_supported_hdbscan(self.hdbscan_model): - predictions, probabilities = hdbscan_delegator(self.hdbscan_model, "approximate_predict", umap_embeddings) + predictions, probabilities = hdbscan_delegator( + self.hdbscan_model, "approximate_predict", umap_embeddings + ) # Calculate probabilities if self.calculate_probabilities: - logger.info("Probabilities - Start calculation of probabilities with HDBSCAN") - probabilities = hdbscan_delegator(self.hdbscan_model, "membership_vector", umap_embeddings) + logger.info( + "Probabilities - Start calculation of probabilities with HDBSCAN" + ) + probabilities = hdbscan_delegator( + self.hdbscan_model, "membership_vector", umap_embeddings + ) logger.info("Probabilities - Completed \u2713") else: predictions = self.hdbscan_model.predict(umap_embeddings) @@ -554,11 +595,13 @@ def transform(self, predictions = self._map_predictions(predictions) return predictions, probabilities - def partial_fit(self, - documents: List[str], - embeddings: np.ndarray = None, - y: Union[List[int], np.ndarray] = None): - """ Fit BERTopic on a subset of the data and perform online learning + def partial_fit( + self, + documents: List[str], + embeddings: np.ndarray = None, + y: Union[List[int], np.ndarray] = None, + ): + """Fit BERTopic on a subset of the data and perform online learning with batch-like data. Online topic modeling in BERTopic is performed by using dimensionality @@ -591,7 +634,6 @@ def partial_fit(self, specific instance is specified. 
Examples: - ```python from sklearn.datasets import fetch_20newsgroups from sklearn.cluster import MiniBatchKMeans @@ -619,30 +661,34 @@ def partial_fit(self, # Checks check_embeddings_shape(embeddings, documents) if not hasattr(self.hdbscan_model, "partial_fit"): - raise ValueError("In order to use `.partial_fit`, the cluster model should have " - "a `.partial_fit` function.") + raise ValueError( + "In order to use `.partial_fit`, the cluster model should have " + "a `.partial_fit` function." + ) # Prepare documents if isinstance(documents, str): documents = [documents] - documents = pd.DataFrame({"Document": documents, - "ID": range(len(documents)), - "Topic": None}) + documents = pd.DataFrame( + {"Document": documents, "ID": range(len(documents)), "Topic": None} + ) # Extract embeddings if embeddings is None: if self.topic_representations_ is None: - self.embedding_model = select_backend(self.embedding_model, - language=self.language, - verbose=self.verbose) - embeddings = self._extract_embeddings(documents.Document.values.tolist(), - method="document", - verbose=self.verbose) + self.embedding_model = select_backend( + self.embedding_model, language=self.language, verbose=self.verbose + ) + embeddings = self._extract_embeddings( + documents.Document.values.tolist(), + method="document", + verbose=self.verbose, + ) else: if self.embedding_model is not None and self.topic_representations_ is None: - self.embedding_model = select_backend(self.embedding_model, - language=self.language, - verbose=self.verbose) + self.embedding_model = select_backend( + self.embedding_model, language=self.language, verbose=self.verbose + ) # Reduce dimensionality if self.seed_topic_list is not None and self.embedding_model is not None: @@ -650,7 +696,9 @@ def partial_fit(self, umap_embeddings = self._reduce_dimensionality(embeddings, y, partial_fit=True) # Cluster reduced embeddings - documents, self.probabilities_ = self._cluster_embeddings(umap_embeddings, documents, partial_fit=True) + documents, self.probabilities_ = self._cluster_embeddings( + umap_embeddings, documents, partial_fit=True + ) topics = documents.Topic.to_list() # Map and find new topics @@ -658,7 +706,10 @@ def partial_fit(self, self.topic_mapper_ = TopicMapper(topics) mappings = self.topic_mapper_.get_mappings() new_topics = set(topics).difference(set(mappings.keys())) - new_topic_ids = {topic: max(mappings.values()) + index + 1 for index, topic in enumerate(new_topics)} + new_topic_ids = { + topic: max(mappings.values()) + index + 1 + for index, topic in enumerate(new_topics) + } self.topic_mapper_.add_new_topics(new_topic_ids) updated_mappings = self.topic_mapper_.get_mappings() updated_topics = [updated_mappings[topic] for topic in topics] @@ -666,35 +717,48 @@ def partial_fit(self, # Add missing topics (topics that were originally created but are now missing) if self.topic_representations_: - missing_topics = set(self.topic_representations_.keys()).difference(set(updated_topics)) + missing_topics = set(self.topic_representations_.keys()).difference( + set(updated_topics) + ) for missing_topic in missing_topics: documents.loc[len(documents), :] = [" ", len(documents), missing_topic] else: missing_topics = {} # Prepare documents - documents_per_topic = documents.sort_values("Topic").groupby(['Topic'], as_index=False) + documents_per_topic = documents.sort_values("Topic").groupby( + ["Topic"], as_index=False + ) updated_topics = documents_per_topic.first().Topic.astype(int) - documents_per_topic = documents_per_topic.agg({'Document': ' 
'.join}) + documents_per_topic = documents_per_topic.agg({"Document": " ".join}) # Update topic representations - self.c_tf_idf_, updated_words = self._c_tf_idf(documents_per_topic, partial_fit=True) - self.topic_representations_ = self._extract_words_per_topic(updated_words, documents, self.c_tf_idf_, calculate_aspects=False) + self.c_tf_idf_, updated_words = self._c_tf_idf( + documents_per_topic, partial_fit=True + ) + self.topic_representations_ = self._extract_words_per_topic( + updated_words, documents, self.c_tf_idf_, calculate_aspects=False + ) self._create_topic_vectors() - self.topic_labels_ = {key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in self.topic_representations_.items()} + self.topic_labels_ = { + key: f"{key}_" + "_".join([word[0] for word in values[:4]]) + for key, values in self.topic_representations_.items() + } # Update topic sizes if len(missing_topics) > 0: - documents = documents.iloc[:-len(missing_topics)] + documents = documents.iloc[: -len(missing_topics)] if self.topic_sizes_ is None: self._update_topic_size(documents) else: - sizes = documents.groupby(['Topic'], as_index=False).count() + sizes = documents.groupby(["Topic"], as_index=False).count() for _, row in sizes.iterrows(): topic = int(row.Topic) - if self.topic_sizes_.get(topic) is not None and topic not in missing_topics: + if ( + self.topic_sizes_.get(topic) is not None + and topic not in missing_topics + ): self.topic_sizes_[topic] += int(row.Document) elif self.topic_sizes_.get(topic) is None: self.topic_sizes_[topic] = int(row.Document) @@ -702,16 +766,17 @@ def partial_fit(self, return self - def topics_over_time(self, - docs: List[str], - timestamps: Union[List[str], - List[int]], - topics: List[int] = None, - nr_bins: int = None, - datetime_format: str = None, - evolution_tuning: bool = True, - global_tuning: bool = True) -> pd.DataFrame: - """ Create topics over time + def topics_over_time( + self, + docs: List[str], + timestamps: Union[List[str], List[int]], + topics: List[int] = None, + nr_bins: int = None, + datetime_format: str = None, + evolution_tuning: bool = True, + global_tuning: bool = True, + ) -> pd.DataFrame: + """Create topics over time. To create the topics over time, BERTopic needs to be already fitted once. From the fitted models, the c-TF-IDF representations are calculate at @@ -719,7 +784,7 @@ def topics_over_time(self, averaged with the global c-TF-IDF representations in order to fine-tune the local representations. - NOTE: + Note: Make sure to use a limited number of unique timestamps (<100) as the c-TF-IDF representation will be calculated at each single unique timestamp. Having a large number of unique timestamps can take some time to be calculated. @@ -755,7 +820,6 @@ def topics_over_time(self, at timestamp *t*. Examples: - The timestamps variable represents the timestamp of each document. 
If you have over 100 unique timestamps, it is advised to bin the timestamps as shown below: @@ -769,17 +833,21 @@ def topics_over_time(self, check_is_fitted(self) check_documents_type(docs) selected_topics = topics if topics else self.topics_ - documents = pd.DataFrame({"Document": docs, "Topic": selected_topics, "Timestamps": timestamps}) - global_c_tf_idf = normalize(self.c_tf_idf_, axis=1, norm='l1', copy=False) + documents = pd.DataFrame( + {"Document": docs, "Topic": selected_topics, "Timestamps": timestamps} + ) + global_c_tf_idf = normalize(self.c_tf_idf_, axis=1, norm="l1", copy=False) all_topics = sorted(list(documents.Topic.unique())) all_topics_indices = {topic: index for index, topic in enumerate(all_topics)} if isinstance(timestamps[0], str): infer_datetime_format = True if not datetime_format else False - documents["Timestamps"] = pd.to_datetime(documents["Timestamps"], - infer_datetime_format=infer_datetime_format, - format=datetime_format) + documents["Timestamps"] = pd.to_datetime( + documents["Timestamps"], + infer_datetime_format=infer_datetime_format, + format=datetime_format, + ) if nr_bins: documents["Bins"] = pd.cut(documents.Timestamps, bins=nr_bins) @@ -789,64 +857,93 @@ def topics_over_time(self, documents = documents.sort_values("Timestamps") timestamps = documents.Timestamps.unique() if len(timestamps) > 100: - logger.warning(f"There are more than 100 unique timestamps (i.e., {len(timestamps)}) " - "which significantly slows down the application. Consider setting `nr_bins` " - "to a value lower than 100 to speed up calculation. ") + logger.warning( + f"There are more than 100 unique timestamps (i.e., {len(timestamps)}) " + "which significantly slows down the application. Consider setting `nr_bins` " + "to a value lower than 100 to speed up calculation. 
" + ) # For each unique timestamp, create topic representations topics_over_time = [] for index, timestamp in tqdm(enumerate(timestamps), disable=not self.verbose): - # Calculate c-TF-IDF representation for a specific timestamp selection = documents.loc[documents.Timestamps == timestamp, :] - documents_per_topic = selection.groupby(['Topic'], as_index=False).agg({'Document': ' '.join, - "Timestamps": "count"}) + documents_per_topic = selection.groupby(["Topic"], as_index=False).agg( + {"Document": " ".join, "Timestamps": "count"} + ) c_tf_idf, words = self._c_tf_idf(documents_per_topic, fit=False) if global_tuning or evolution_tuning: - c_tf_idf = normalize(c_tf_idf, axis=1, norm='l1', copy=False) + c_tf_idf = normalize(c_tf_idf, axis=1, norm="l1", copy=False) # Fine-tune the c-TF-IDF matrix at timestamp t by averaging it with the c-TF-IDF # matrix at timestamp t-1 if evolution_tuning and index != 0: current_topics = sorted(list(documents_per_topic.Topic.values)) - overlapping_topics = sorted(list(set(previous_topics).intersection(set(current_topics)))) - - current_overlap_idx = [current_topics.index(topic) for topic in overlapping_topics] - previous_overlap_idx = [previous_topics.index(topic) for topic in overlapping_topics] - - c_tf_idf.tolil()[current_overlap_idx] = ((c_tf_idf[current_overlap_idx] + - previous_c_tf_idf[previous_overlap_idx]) / 2.0).tolil() + overlapping_topics = sorted( + list(set(previous_topics).intersection(set(current_topics))) # noqa: F821 + ) + + current_overlap_idx = [ + current_topics.index(topic) for topic in overlapping_topics + ] + previous_overlap_idx = [ + previous_topics.index(topic) # noqa: F821 + for topic in overlapping_topics + ] + + c_tf_idf.tolil()[current_overlap_idx] = ( + ( + c_tf_idf[current_overlap_idx] + + previous_c_tf_idf[previous_overlap_idx] # noqa: F821 + ) + / 2.0 + ).tolil() # Fine-tune the timestamp c-TF-IDF representation based on the global c-TF-IDF representation # by simply taking the average of the two if global_tuning: - selected_topics = [all_topics_indices[topic] for topic in documents_per_topic.Topic.values] + selected_topics = [ + all_topics_indices[topic] + for topic in documents_per_topic.Topic.values + ] c_tf_idf = (global_c_tf_idf[selected_topics] + c_tf_idf) / 2.0 # Extract the words per topic - words_per_topic = self._extract_words_per_topic(words, selection, c_tf_idf, calculate_aspects=False) - topic_frequency = pd.Series(documents_per_topic.Timestamps.values, - index=documents_per_topic.Topic).to_dict() + words_per_topic = self._extract_words_per_topic( + words, selection, c_tf_idf, calculate_aspects=False + ) + topic_frequency = pd.Series( + documents_per_topic.Timestamps.values, index=documents_per_topic.Topic + ).to_dict() # Fill dataframe with results - topics_at_timestamp = [(topic, - ", ".join([words[0] for words in values][:5]), - topic_frequency[topic], - timestamp) for topic, values in words_per_topic.items()] + topics_at_timestamp = [ + ( + topic, + ", ".join([words[0] for words in values][:5]), + topic_frequency[topic], + timestamp, + ) + for topic, values in words_per_topic.items() + ] topics_over_time.extend(topics_at_timestamp) if evolution_tuning: - previous_topics = sorted(list(documents_per_topic.Topic.values)) - previous_c_tf_idf = c_tf_idf.copy() + previous_topics = sorted(list(documents_per_topic.Topic.values)) # noqa: F841 + previous_c_tf_idf = c_tf_idf.copy() # noqa: F841 - return pd.DataFrame(topics_over_time, columns=["Topic", "Words", "Frequency", "Timestamp"]) + return pd.DataFrame( + 
topics_over_time, columns=["Topic", "Words", "Frequency", "Timestamp"] + ) - def topics_per_class(self, - docs: List[str], - classes: Union[List[int], List[str]], - global_tuning: bool = True) -> pd.DataFrame: - """ Create topics per class + def topics_per_class( + self, + docs: List[str], + classes: Union[List[int], List[str]], + global_tuning: bool = True, + ) -> pd.DataFrame: + """Create topics per class. To create the topics per class, BERTopic needs to be already fitted once. From the fitted models, the c-TF-IDF representations are calculated at @@ -855,7 +952,7 @@ def topics_per_class(self, local representations. This can be turned off if the pure representation is needed. - NOTE: + Note: Make sure to use a limited number of unique classes (<100) as the c-TF-IDF representation will be calculated at each single unique class. Having a large number of unique classes can take some time to be calculated. @@ -872,7 +969,6 @@ def topics_per_class(self, for each class. Examples: - ```python from bertopic import BERTopic topic_model = BERTopic() @@ -881,47 +977,64 @@ def topics_per_class(self, ``` """ check_documents_type(docs) - documents = pd.DataFrame({"Document": docs, "Topic": self.topics_, "Class": classes}) - global_c_tf_idf = normalize(self.c_tf_idf_, axis=1, norm='l1', copy=False) + documents = pd.DataFrame( + {"Document": docs, "Topic": self.topics_, "Class": classes} + ) + global_c_tf_idf = normalize(self.c_tf_idf_, axis=1, norm="l1", copy=False) # For each unique timestamp, create topic representations topics_per_class = [] for _, class_ in tqdm(enumerate(set(classes)), disable=not self.verbose): - # Calculate c-TF-IDF representation for a specific timestamp selection = documents.loc[documents.Class == class_, :] - documents_per_topic = selection.groupby(['Topic'], as_index=False).agg({'Document': ' '.join, - "Class": "count"}) + documents_per_topic = selection.groupby(["Topic"], as_index=False).agg( + {"Document": " ".join, "Class": "count"} + ) c_tf_idf, words = self._c_tf_idf(documents_per_topic, fit=False) # Fine-tune the timestamp c-TF-IDF representation based on the global c-TF-IDF representation # by simply taking the average of the two if global_tuning: - c_tf_idf = normalize(c_tf_idf, axis=1, norm='l1', copy=False) - c_tf_idf = (global_c_tf_idf[documents_per_topic.Topic.values + self._outliers] + c_tf_idf) / 2.0 + c_tf_idf = normalize(c_tf_idf, axis=1, norm="l1", copy=False) + c_tf_idf = ( + global_c_tf_idf[documents_per_topic.Topic.values + self._outliers] + + c_tf_idf + ) / 2.0 # Extract the words per topic - words_per_topic = self._extract_words_per_topic(words, selection, c_tf_idf, calculate_aspects=False) - topic_frequency = pd.Series(documents_per_topic.Class.values, - index=documents_per_topic.Topic).to_dict() + words_per_topic = self._extract_words_per_topic( + words, selection, c_tf_idf, calculate_aspects=False + ) + topic_frequency = pd.Series( + documents_per_topic.Class.values, index=documents_per_topic.Topic + ).to_dict() # Fill dataframe with results - topics_at_class = [(topic, - ", ".join([words[0] for words in values][:5]), - topic_frequency[topic], - class_) for topic, values in words_per_topic.items()] + topics_at_class = [ + ( + topic, + ", ".join([words[0] for words in values][:5]), + topic_frequency[topic], + class_, + ) + for topic, values in words_per_topic.items() + ] topics_per_class.extend(topics_at_class) - topics_per_class = pd.DataFrame(topics_per_class, columns=["Topic", "Words", "Frequency", "Class"]) + topics_per_class = pd.DataFrame( + 
topics_per_class, columns=["Topic", "Words", "Frequency", "Class"] + ) return topics_per_class - def hierarchical_topics(self, - docs: List[str], - use_ctfidf: bool = True, - linkage_function: Callable[[csr_matrix], np.ndarray] = None, - distance_function: Callable[[csr_matrix], csr_matrix] = None) -> pd.DataFrame: - """ Create a hierarchy of topics + def hierarchical_topics( + self, + docs: List[str], + use_ctfidf: bool = True, + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None, + ) -> pd.DataFrame: + """Create a hierarchy of topics. To create this hierarchy, BERTopic needs to be already fitted once. Then, a hierarchy is calculated on the distance matrix of the c-TF-IDF or topic embeddings @@ -951,7 +1064,6 @@ def hierarchical_topics(self, represented by their parents and their children Examples: - ```python from bertopic import BERTopic topic_model = BERTopic() @@ -977,10 +1089,12 @@ def hierarchical_topics(self, distance_function = lambda x: 1 - cosine_similarity(x) if linkage_function is None: - linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) + linkage_function = lambda x: sch.linkage(x, "ward", optimal_ordering=True) # Calculate distance - embeddings = select_topic_representation(self.c_tf_idf_, self.topic_embeddings_, use_ctfidf)[0][self._outliers:] + embeddings = select_topic_representation( + self.c_tf_idf_, self.topic_embeddings_, use_ctfidf + )[0][self._outliers :] X = distance_function(embeddings) X = validate_distance_matrix(X, embeddings.shape[0]) @@ -993,11 +1107,15 @@ def hierarchical_topics(self, Z[:, 2] = get_unique_distances(Z[:, 2]) # Calculate basic bag-of-words to be iteratively merged later - documents = pd.DataFrame({"Document": docs, - "ID": range(len(docs)), - "Topic": self.topics_}) - documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join}) - documents_per_topic = documents_per_topic.loc[documents_per_topic.Topic != -1, :] + documents = pd.DataFrame( + {"Document": docs, "ID": range(len(docs)), "Topic": self.topics_} + ) + documents_per_topic = documents.groupby(["Topic"], as_index=False).agg( + {"Document": " ".join} + ) + documents_per_topic = documents_per_topic.loc[ + documents_per_topic.Topic != -1, : + ] clean_documents = self._preprocess_text(documents_per_topic.Document.values) # Scikit-Learn Deprecation: get_feature_names is deprecated in 1.0 @@ -1010,13 +1128,22 @@ def hierarchical_topics(self, bow = self.vectorizer_model.transform(clean_documents) # Extract clusters - hier_topics = pd.DataFrame(columns=["Parent_ID", "Parent_Name", "Topics", - "Child_Left_ID", "Child_Left_Name", - "Child_Right_ID", "Child_Right_Name"]) + hier_topics = pd.DataFrame( + columns=[ + "Parent_ID", + "Parent_Name", + "Topics", + "Child_Left_ID", + "Child_Left_Name", + "Child_Right_ID", + "Child_Right_Name", + ] + ) for index in tqdm(range(len(Z))): - # Find clustered documents - clusters = sch.fcluster(Z, t=Z[index][2], criterion='distance') - self._outliers + clusters = ( + sch.fcluster(Z, t=Z[index][2], criterion="distance") - self._outliers + ) nr_clusters = len(clusters) # Extract first topic we find to get the set of topics in a merged topic @@ -1027,14 +1154,18 @@ def hierarchical_topics(self, topic = int(val) else: val = Z[int(val - len(clusters))][0] - clustered_topics = [i for i, x in enumerate(clusters) if x == clusters[topic]] + clustered_topics = [ + i for i, x in enumerate(clusters) if x == clusters[topic] + ] # Group bow per 
cluster, calculate c-TF-IDF and extract words grouped = csr_matrix(bow[clustered_topics].sum(axis=0)) c_tf_idf = self.ctfidf_model.transform(grouped) selection = documents.loc[documents.Topic.isin(clustered_topics), :] selection.Topic = 0 - words_per_topic = self._extract_words_per_topic(words, selection, c_tf_idf, calculate_aspects=False) + words_per_topic = self._extract_words_per_topic( + words, selection, c_tf_idf, calculate_aspects=False + ) # Extract parent's name and ID parent_id = index + len(clusters) @@ -1059,29 +1190,37 @@ def hierarchical_topics(self, child_right_name = hier_topics.iloc[int(child_right_id)].Parent_Name # Save results - hier_topics.loc[len(hier_topics), :] = [parent_id, parent_name, - clustered_topics, - int(Z[index][0]), child_left_name, - int(Z[index][1]), child_right_name] + hier_topics.loc[len(hier_topics), :] = [ + parent_id, + parent_name, + clustered_topics, + int(Z[index][0]), + child_left_name, + int(Z[index][1]), + child_right_name, + ] hier_topics["Distance"] = Z[:, 2] hier_topics = hier_topics.sort_values("Parent_ID", ascending=False) - hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]].astype(str) + hier_topics[["Parent_ID", "Child_Left_ID", "Child_Right_ID"]] = hier_topics[ + ["Parent_ID", "Child_Left_ID", "Child_Right_ID"] + ].astype(str) return hier_topics - def approximate_distribution(self, - documents: Union[str, List[str]], - window: int = 4, - stride: int = 1, - min_similarity: float = 0.1, - batch_size: int = 1000, - padding: bool = False, - use_embedding_model: bool = False, - calculate_tokens: bool = False, - separator: str = " ") -> Tuple[np.ndarray, - Union[List[np.ndarray], None]]: - """ A post-hoc approximation of topic distributions across documents. + def approximate_distribution( + self, + documents: Union[str, List[str]], + window: int = 4, + stride: int = 1, + min_similarity: float = 0.1, + batch_size: int = 1000, + padding: bool = False, + use_embedding_model: bool = False, + calculate_tokens: bool = False, + separator: str = " ", + ) -> Tuple[np.ndarray, Union[List[np.ndarray], None]]: + """A post-hoc approximation of topic distributions across documents. In order to perform this approximation, each document is split into tokens according to the provided tokenizer in the `CountVectorizer`. Then, a @@ -1139,7 +1278,6 @@ def approximate_distribution(self, and `m` the topics. 
Examples: - After fitting the model, the topic distributions can be calculated regardless of the clustering model and regardless of whether the documents were previously seen or not: @@ -1167,13 +1305,13 @@ def approximate_distribution(self, batch_size = len(documents) batches = 1 else: - batches = math.ceil(len(documents)/batch_size) + batches = math.ceil(len(documents) / batch_size) topic_distributions = [] topic_token_distributions = [] for i in tqdm(range(batches), disable=not self.verbose): - doc_set = documents[i*batch_size: (i+1) * batch_size] + doc_set = documents[i * batch_size : (i + 1) * batch_size] # Extract tokens analyzer = self.vectorizer_model.build_tokenizer() @@ -1189,17 +1327,23 @@ def approximate_distribution(self, token_sets = [tokenset] token_sets_ids = [list(range(len(tokenset)))] else: - # Extract tokensets using window and stride parameters stride_indices = list(range(len(tokenset)))[::stride] token_sets = [] token_sets_ids = [] for stride_index in stride_indices: - selected_tokens = tokenset[stride_index: stride_index+window] + selected_tokens = tokenset[stride_index : stride_index + window] if padding or len(selected_tokens) == window: token_sets.append(selected_tokens) - token_sets_ids.append(list(range(stride_index, stride_index+len(selected_tokens)))) + token_sets_ids.append( + list( + range( + stride_index, + stride_index + len(selected_tokens), + ) + ) + ) # Add empty tokens at the beginning and end of a document if padding: @@ -1207,8 +1351,10 @@ def approximate_distribution(self, padded_ids = [] t = math.ceil(window / stride) - 1 for i in range(math.ceil(window / stride) - 1): - padded.append(tokenset[:window - ((t-i) * stride)]) - padded_ids.append(list(range(0, window - ((t-i) * stride)))) + padded.append(tokenset[: window - ((t - i) * stride)]) + padded_ids.append( + list(range(0, window - ((t - i) * stride))) + ) token_sets = padded + token_sets token_sets_ids = padded_ids + token_sets_ids @@ -1221,14 +1367,20 @@ def approximate_distribution(self, # Calculate similarity between embeddings of token sets and the topics if use_embedding_model: - embeddings = self._extract_embeddings(all_sentences, method="document", verbose=True) - similarity = cosine_similarity(embeddings, self.topic_embeddings_[self._outliers:]) + embeddings = self._extract_embeddings( + all_sentences, method="document", verbose=True + ) + similarity = cosine_similarity( + embeddings, self.topic_embeddings_[self._outliers :] + ) # Calculate similarity between c-TF-IDF of token sets and the topics else: bow_doc = self.vectorizer_model.transform(all_sentences) c_tf_idf_doc = self.ctfidf_model.transform(bow_doc) - similarity = cosine_similarity(c_tf_idf_doc, self.c_tf_idf_[self._outliers:]) + similarity = cosine_similarity( + c_tf_idf_doc, self.c_tf_idf_[self._outliers :] + ) # Only keep similarities that exceed the minimum similarity[similarity < min_similarity] = 0 @@ -1239,7 +1391,7 @@ def approximate_distribution(self, topic_token_distribution = [] for index, token in enumerate(tokens): start = all_indices[index] - end = all_indices[index+1] + end = all_indices[index + 1] if start == end: end = end + 1 @@ -1247,7 +1399,9 @@ def approximate_distribution(self, # Assign topics to individual tokens token_id = [i for i in range(len(token))] token_val = {index: [] for index in token_id} - for sim, token_set in zip(similarity[start:end], all_token_sets_ids[start:end]): + for sim, token_set in zip( + similarity[start:end], all_token_sets_ids[start:end] + ): for token in token_set: if token in 
token_val: token_val[token].append(sim) @@ -1264,20 +1418,22 @@ def approximate_distribution(self, topic_token_distribution.append(np.array(matrix)) topic_distribution.append(np.add.reduce(matrix)) - topic_distribution = normalize(topic_distribution, norm='l1', axis=1) + topic_distribution = normalize(topic_distribution, norm="l1", axis=1) # Aggregate on a tokenset level indicated by the window and stride else: topic_distribution = [] - for index in range(len(all_indices)-1): + for index in range(len(all_indices) - 1): start = all_indices[index] - end = all_indices[index+1] + end = all_indices[index + 1] if start == end: end = end + 1 group = similarity[start:end].sum(axis=0) topic_distribution.append(group) - topic_distribution = normalize(np.array(topic_distribution), norm='l1', axis=1) + topic_distribution = normalize( + np.array(topic_distribution), norm="l1", axis=1 + ) topic_token_distribution = None # Combine results @@ -1291,22 +1447,24 @@ def approximate_distribution(self, return topic_distributions, topic_token_distributions - def find_topics(self, - search_term: str = None, - image: str = None, - top_n: int = 5) -> Tuple[List[int], List[float]]: - """ Find topics most similar to a search_term + def find_topics( + self, search_term: str = None, image: str = None, top_n: int = 5 + ) -> Tuple[List[int], List[float]]: + """Find topics most similar to a search_term. - Creates an embedding for search_term and compares that with + Creates an embedding for a search query and compares that with the topic embeddings. The most similar topics are returned along with their similarity values. + The query is specified using search_term for text queries or image for image queries. + The search_term can be of any size but since it is compared with the topic representation it is advised to keep it below 5 words. Arguments: search_term: the term you want to use to search for topics. + image: path to the image you want to use to search for topics. top_n: the number of topics to return Returns: @@ -1314,7 +1472,6 @@ def find_topics(self, similarity: the similarity scores from high to low Examples: - You can use the underlying embedding model to find topics that best represent the search term: @@ -1326,22 +1483,25 @@ def find_topics(self, search_term consists of a phrase or multiple words. """ if self.embedding_model is None: - raise Exception("This method can only be used if you did not use custom embeddings.") + raise Exception( + "This method can only be used if you did not use custom embeddings." 
+ ) topic_list = list(self.topic_representations_.keys()) topic_list.sort() # Extract search_term embeddings and compare with topic embeddings if search_term is not None: - search_embedding = self._extract_embeddings([search_term], - method="word", - verbose=False).flatten() + search_embedding = self._extract_embeddings( + [search_term], method="word", verbose=False + ).flatten() elif image is not None: - search_embedding = self._extract_embeddings([None], - images=[image], - method="document", - verbose=False).flatten() - sims = cosine_similarity(search_embedding.reshape(1, -1), self.topic_embeddings_).flatten() + search_embedding = self._extract_embeddings( + [None], images=[image], method="document", verbose=False + ).flatten() + sims = cosine_similarity( + search_embedding.reshape(1, -1), self.topic_embeddings_ + ).flatten() # Extract topics most similar to search_term ids = np.argsort(sims)[-top_n:] @@ -1350,16 +1510,18 @@ def find_topics(self, return similar_topics, similarity - def update_topics(self, - docs: List[str], - images: List[str] = None, - topics: List[int] = None, - top_n_words: int = 10, - n_gram_range: Tuple[int, int] = None, - vectorizer_model: CountVectorizer = None, - ctfidf_model: ClassTfidfTransformer = None, - representation_model: BaseRepresentation = None): - """ Updates the topic representation by recalculating c-TF-IDF with the new + def update_topics( + self, + docs: List[str], + images: List[str] = None, + topics: List[int] = None, + top_n_words: int = 10, + n_gram_range: Tuple[int, int] = None, + vectorizer_model: CountVectorizer = None, + ctfidf_model: ClassTfidfTransformer = None, + representation_model: BaseRepresentation = None, + ): + """Updates the topic representation by recalculating c-TF-IDF with the new parameters as defined in this function. When you have trained a model and viewed the topics and the words that represent them, @@ -1386,7 +1548,6 @@ def update_topics(self, are supported. Examples: - In order to update the topic representation, you will need to first fit the topic model and extract topics from them. Based on these, you can update the representation: @@ -1415,48 +1576,64 @@ def update_topics(self, n_gram_range = self.n_gram_range if top_n_words > 100: - logger.warning("Note that extracting more than 100 words from a sparse " - "can slow down computation quite a bit.") + logger.warning( + "Note that extracting more than 100 words from a sparse " + "can slow down computation quite a bit." + ) self.top_n_words = top_n_words - self.vectorizer_model = vectorizer_model or CountVectorizer(ngram_range=n_gram_range) + self.vectorizer_model = vectorizer_model or CountVectorizer( + ngram_range=n_gram_range + ) self.ctfidf_model = ctfidf_model or ClassTfidfTransformer() self.representation_model = representation_model if topics is None: topics = self.topics_ else: - logger.warning("Using a custom list of topic assignments may lead to errors if " - "topic reduction techniques are used afterwards. Make sure that " - "manually assigning topics is the last step in the pipeline." - "Note that topic embeddings will also be created through weighted" - "c-TF-IDF embeddings instead of centroid embeddings.") + logger.warning( + "Using a custom list of topic assignments may lead to errors if " + "topic reduction techniques are used afterwards. Make sure that " + "manually assigning topics is the last step in the pipeline." + "Note that topic embeddings will also be created through weighted" + "c-TF-IDF embeddings instead of centroid embeddings." 
+ ) self._outliers = 1 if -1 in set(topics) else 0 # Extract words - documents = pd.DataFrame({"Document": docs, "Topic": topics, "ID": range(len(docs)), "Image": images}) - documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join}) + documents = pd.DataFrame( + {"Document": docs, "Topic": topics, "ID": range(len(docs)), "Image": images} + ) + documents_per_topic = documents.groupby(["Topic"], as_index=False).agg( + {"Document": " ".join} + ) self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic) self.topic_representations_ = self._extract_words_per_topic(words, documents) # Update topic vectors if set(topics) != self.topics_: - # Remove outlier topic embedding if all that has changed is the outlier class - same_position = all([True if old_topic == new_topic else False for old_topic, new_topic in zip(self.topics_, topics) if old_topic != -1]) + same_position = all( + [ + True if old_topic == new_topic else False + for old_topic, new_topic in zip(self.topics_, topics) + if old_topic != -1 + ] + ) if same_position and -1 not in topics and -1 in self.topics_: self.topic_embeddings_ = self.topic_embeddings_[1:] else: self._create_topic_vectors() # Update topic labels - self.topic_labels_ = {key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in - self.topic_representations_.items()} + self.topic_labels_ = { + key: f"{key}_" + "_".join([word[0] for word in values[:4]]) + for key, values in self.topic_representations_.items() + } self._update_topic_size(documents) def get_topics(self, full: bool = False) -> Mapping[str, Tuple[str, float]]: - """ Return topics with top n words and their c-TF-IDF score + """Return topics with top n words and their c-TF-IDF score. Arguments: full: If True, returns all different forms of topic representations @@ -1466,7 +1643,6 @@ def get_topics(self, full: bool = False) -> Mapping[str, Tuple[str, float]]: self.topic_representations_: The top n words per topic and the corresponding c-TF-IDF score Examples: - ```python all_topics = topic_model.get_topics() ``` @@ -1480,8 +1656,10 @@ def get_topics(self, full: bool = False) -> Mapping[str, Tuple[str, float]]: else: return self.topic_representations_ - def get_topic(self, topic: int, full: bool = False) -> Union[Mapping[str, Tuple[str, float]], bool]: - """ Return top n words for a specific topic and their c-TF-IDF scores + def get_topic( + self, topic: int, full: bool = False + ) -> Union[Mapping[str, Tuple[str, float]], bool]: + """Return top n words for a specific topic and their c-TF-IDF scores. 
Arguments: topic: A specific topic for which you want its representation @@ -1492,7 +1670,6 @@ def get_topic(self, topic: int, full: bool = False) -> Union[Mapping[str, Tuple[ The top n words for a specific word and its respective c-TF-IDF scores Examples: - ```python topic = topic_model.get_topic(12) ``` @@ -1501,7 +1678,10 @@ def get_topic(self, topic: int, full: bool = False) -> Union[Mapping[str, Tuple[ if topic in self.topic_representations_: if full: representations = {"Main": self.topic_representations_[topic]} - aspects = {aspect: representations[topic] for aspect, representations in self.topic_aspects_.items()} + aspects = { + aspect: representations[topic] + for aspect, representations in self.topic_aspects_.items() + } representations.update(aspects) return representations else: @@ -1510,7 +1690,7 @@ def get_topic(self, topic: int, full: bool = False) -> Union[Mapping[str, Tuple[ return False def get_topic_info(self, topic: int = None) -> pd.DataFrame: - """ Get information about each topic including its ID, frequency, and name. + """Get information about each topic including its ID, frequency, and name. Arguments: topic: A specific topic for which you want the frequency @@ -1519,41 +1699,58 @@ def get_topic_info(self, topic: int = None) -> pd.DataFrame: info: The information relating to either a single topic or all topics Examples: - ```python info_df = topic_model.get_topic_info() ``` """ check_is_fitted(self) - info = pd.DataFrame(self.topic_sizes_.items(), columns=["Topic", "Count"]).sort_values("Topic") + info = pd.DataFrame( + self.topic_sizes_.items(), columns=["Topic", "Count"] + ).sort_values("Topic") info["Name"] = info.Topic.map(self.topic_labels_) # Custom label if self.custom_labels_ is not None: if len(self.custom_labels_) == len(info): - labels = {topic - self._outliers: label for topic, label in enumerate(self.custom_labels_)} + labels = { + topic - self._outliers: label + for topic, label in enumerate(self.custom_labels_) + } info["CustomName"] = info["Topic"].map(labels) # Main Keywords - values = {topic: list(list(zip(*values))[0]) for topic, values in self.topic_representations_.items()} + values = { + topic: list(list(zip(*values))[0]) + for topic, values in self.topic_representations_.items() + } info["Representation"] = info["Topic"].map(values) # Extract all topic aspects if self.topic_aspects_: for aspect, values in self.topic_aspects_.items(): if isinstance(list(values.values())[-1], list): - if isinstance(list(values.values())[-1][0], tuple) or isinstance(list(values.values())[-1][0], list): - values = {topic: list(list(zip(*value))[0]) for topic, value in values.items()} + if isinstance(list(values.values())[-1][0], tuple) or isinstance( + list(values.values())[-1][0], list + ): + values = { + topic: list(list(zip(*value))[0]) + for topic, value in values.items() + } elif isinstance(list(values.values())[-1][0], str): - values = {topic: " ".join(value).strip() for topic, value in values.items()} + values = { + topic: " ".join(value).strip() + for topic, value in values.items() + } info[aspect] = info["Topic"].map(values) # Representative Docs / Images if self.representative_docs_ is not None: info["Representative_Docs"] = info["Topic"].map(self.representative_docs_) if self.representative_images_ is not None: - info["Representative_Images"] = info["Topic"].map(self.representative_images_) + info["Representative_Images"] = info["Topic"].map( + self.representative_images_ + ) # Select specific topic to return if topic is not None: @@ -1562,7 +1759,7 @@ 
def get_topic_info(self, topic: int = None) -> pd.DataFrame: return info.reset_index(drop=True) def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]: - """ Return the size of topics (descending order) + """Return the size of topics (descending order). Arguments: topic: A specific topic for which you want the frequency @@ -1572,7 +1769,6 @@ def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]: the frequencies of all topics Examples: - To extract the frequency of all topics: ```python @@ -1589,14 +1785,17 @@ def get_topic_freq(self, topic: int = None) -> Union[pd.DataFrame, int]: if isinstance(topic, int): return self.topic_sizes_[topic] else: - return pd.DataFrame(self.topic_sizes_.items(), columns=['Topic', 'Count']).sort_values("Count", - ascending=False) - - def get_document_info(self, - docs: List[str], - df: pd.DataFrame = None, - metadata: Mapping[str, Any] = None) -> pd.DataFrame: - """ Get information about the documents on which the topic was trained + return pd.DataFrame( + self.topic_sizes_.items(), columns=["Topic", "Count"] + ).sort_values("Count", ascending=False) + + def get_document_info( + self, + docs: List[str], + df: pd.DataFrame = None, + metadata: Mapping[str, Any] = None, + ) -> pd.DataFrame: + """Get information about the documents on which the topic was trained including the documents themselves, their respective topics, the name of each topic, the top n words of each topic, whether it is a representative document, and probability of the clustering if the cluster @@ -1659,7 +1858,10 @@ def get_document_info(self, document_info = pd.merge(document_info, topic_info, on="Topic", how="left") # Add top n words - top_n_words = {topic: " - ".join(list(zip(*self.get_topic(topic)))[0]) for topic in set(self.topics_)} + top_n_words = { + topic: " - ".join(list(zip(*self.get_topic(topic)))[0]) + for topic in set(self.topics_) + } document_info["Top_n_words"] = document_info.Topic.map(top_n_words) # Add flat probabilities @@ -1667,13 +1869,21 @@ def get_document_info(self, if len(self.probabilities_.shape) == 1: document_info["Probability"] = self.probabilities_ else: - document_info["Probability"] = [max(probs) if topic != -1 else 1-sum(probs) - for topic, probs in zip(self.topics_, self.probabilities_)] + document_info["Probability"] = [ + max(probs) if topic != -1 else 1 - sum(probs) + for topic, probs in zip(self.topics_, self.probabilities_) + ] # Add representative document labels - repr_docs = [repr_doc for repr_docs in self.representative_docs_.values() for repr_doc in repr_docs] + repr_docs = [ + repr_doc + for repr_docs in self.representative_docs_.values() + for repr_doc in repr_docs + ] document_info["Representative_document"] = False - document_info.loc[document_info.Document.isin(repr_docs), "Representative_document"] = True + document_info.loc[ + document_info.Document.isin(repr_docs), "Representative_document" + ] = True # Add custom meta data provided by the user if metadata is not None: @@ -1682,9 +1892,9 @@ def get_document_info(self, return document_info def get_representative_docs(self, topic: int = None) -> List[str]: - """ Extract the best representing documents per topic. + """Extract the best representing documents per topic. - NOTE: + Note: This does not extract all documents per topic as all documents are not saved within BERTopic. 
To get all documents, please run the following: @@ -1705,7 +1915,6 @@ def get_representative_docs(self, topic: int = None) -> List[str]: Representative documents of the chosen topic Examples: - To extract the representative docs of all topics: ```python @@ -1728,10 +1937,12 @@ def get_representative_docs(self, topic: int = None) -> List[str]: return self.representative_docs_ @staticmethod - def get_topic_tree(hier_topics: pd.DataFrame, - max_distance: float = None, - tight_layout: bool = False) -> str: - """ Extract the topic tree such that it can be printed + def get_topic_tree( + hier_topics: pd.DataFrame, + max_distance: float = None, + tight_layout: bool = False, + ) -> str: + """Extract the topic tree such that it can be printed. Arguments: hier_topics: A dataframe containing the structure of the topic tree. @@ -1757,7 +1968,6 @@ def get_topic_tree(hier_topics: pd.DataFrame, from `topic_model.get_topic`. In other words, they are the original un-grouped topics. Examples: - ```python # Train model from bertopic import BERTopic @@ -1777,22 +1987,33 @@ def get_topic_tree(hier_topics: pd.DataFrame, max_original_topic = hier_topics.Parent_ID.astype(int).min() - 1 # Extract mapping from ID to name - topic_to_name = dict(zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name)) - topic_to_name.update(dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name))) + topic_to_name = dict( + zip(hier_topics.Child_Left_ID, hier_topics.Child_Left_Name) + ) + topic_to_name.update( + dict(zip(hier_topics.Child_Right_ID, hier_topics.Child_Right_Name)) + ) topic_to_name = {topic: name[:100] for topic, name in topic_to_name.items()} # Create tree - tree = {str(row[1].Parent_ID): [str(row[1].Child_Left_ID), str(row[1].Child_Right_ID)] - for row in hier_topics.iterrows()} + tree = { + str(row[1].Parent_ID): [ + str(row[1].Child_Left_ID), + str(row[1].Child_Right_ID), + ] + for row in hier_topics.iterrows() + } def get_tree(start, tree): - """ Based on: https://stackoverflow.com/a/51920869/10532563 """ + """Based on: https://stackoverflow.com/a/51920869/10532563.""" def _tree(to_print, start, parent, tree, grandpa=None, indent=""): - # Get distance between merged topics - distance = hier_topics.loc[(hier_topics.Child_Left_ID == parent) | - (hier_topics.Child_Right_ID == parent), "Distance"] + distance = hier_topics.loc[ + (hier_topics.Child_Left_ID == parent) + | (hier_topics.Child_Right_ID == parent), + "Distance", + ] distance = distance.values[0] if len(distance) > 0 else 10 if parent != start: @@ -1800,10 +2021,14 @@ def _tree(to_print, start, parent, tree, grandpa=None, indent=""): to_print += topic_to_name[parent] else: if int(parent) <= max_original_topic: - # Do not append topic ID if they are not merged if distance < max_distance: - to_print += "■──" + topic_to_name[parent] + f" ── Topic: {parent}" + "\n" + to_print += ( + "■──" + + topic_to_name[parent] + + f" ── Topic: {parent}" + + "\n" + ) else: to_print += "O \n" else: @@ -1814,11 +2039,15 @@ def _tree(to_print, start, parent, tree, grandpa=None, indent=""): for child in tree[parent][:-1]: to_print += indent + "├" + "─" - to_print = _tree(to_print, start, child, tree, parent, indent + "│" + " " * width) + to_print = _tree( + to_print, start, child, tree, parent, indent + "│" + " " * width + ) child = tree[parent][-1] to_print += indent + "└" + "─" - to_print = _tree(to_print, start, child, tree, parent, indent + " " * (width+1)) + to_print = _tree( + to_print, start, child, tree, parent, indent + " " * (width + 1) + ) return to_print 
@@ -1829,8 +2058,10 @@ def _tree(to_print, start, parent, tree, grandpa=None, indent=""): start = str(hier_topics.Parent_ID.astype(int).max()) return get_tree(start, tree) - def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> None: - """ Set custom topic labels in your fitted BERTopic model + def set_topic_labels( + self, topic_labels: Union[List[str], Mapping[int, str]] + ) -> None: + """Set custom topic labels in your fitted BERTopic model. Arguments: topic_labels: If a list of topic labels, it should contain the same number @@ -1842,7 +2073,6 @@ def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> in the dictionary. Examples: - First, we define our topic labels with `.generate_topic_labels` in which we can customize our topic labels: @@ -1874,28 +2104,40 @@ def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) -> if isinstance(topic_labels, dict): if self.custom_labels_ is not None: - original_labels = {topic: label for topic, label in zip(unique_topics, self.custom_labels_)} + original_labels = { + topic: label + for topic, label in zip(unique_topics, self.custom_labels_) + } else: info = self.get_topic_info() original_labels = dict(zip(info.Topic, info.Name)) - custom_labels = [topic_labels.get(topic) if topic_labels.get(topic) else original_labels[topic] for topic in unique_topics] + custom_labels = [ + topic_labels.get(topic) + if topic_labels.get(topic) + else original_labels[topic] + for topic in unique_topics + ] elif isinstance(topic_labels, list): if len(topic_labels) == len(unique_topics): custom_labels = topic_labels else: - raise ValueError("Make sure that `topic_labels` contains the same number " - "of labels as there are topics.") + raise ValueError( + "Make sure that `topic_labels` contains the same number " + "of labels as there are topics." + ) self.custom_labels_ = custom_labels - def generate_topic_labels(self, - nr_words: int = 3, - topic_prefix: bool = True, - word_length: int = None, - separator: str = "_", - aspect: str = None) -> List[str]: - """ Get labels for each topic in a user-defined format + def generate_topic_labels( + self, + nr_words: int = 3, + topic_prefix: bool = True, + word_length: int = None, + separator: str = "_", + aspect: str = None, + ) -> List[str]: + """Get labels for each topic in a user-defined format. Arguments: nr_words: Top `n` words per topic to use @@ -1917,7 +2159,6 @@ def generate_topic_labels(self, otherwise it is 0. Examples: - To create our custom topic labels, usage is rather straightforward: ```python @@ -1947,12 +2188,13 @@ def generate_topic_labels(self, return topic_labels - def merge_topics(self, - docs: List[str], - topics_to_merge: List[Union[Iterable[int], int]], - images: List[str] = None) -> None: - """ - Arguments: + def merge_topics( + self, + docs: List[str], + topics_to_merge: List[Union[Iterable[int], int]], + images: List[str] = None, + ) -> None: + """Arguments: docs: The documents you used when calling either `fit` or `fit_transform` topics_to_merge: Either a list of topics or a list of list of topics to merge. For example: @@ -1960,10 +2202,9 @@ def merge_topics(self, [[1, 2], [3, 4]] will merge topics 1 and 2, and separately merge topics 3 and 4. images: A list of paths to the images used when calling either - `fit` or `fit_transform` + `fit` or `fit_transform`. 
Examples: - If you want to merge topics 1, 2, and 3: ```python @@ -1982,7 +2223,14 @@ def merge_topics(self, """ check_is_fitted(self) check_documents_type(docs) - documents = pd.DataFrame({"Document": docs, "Topic": self.topics_, "Image": images, "ID": range(len(docs))}) + documents = pd.DataFrame( + { + "Document": docs, + "Topic": self.topics_, + "Image": images, + "ID": range(len(docs)), + } + ) mapping = {topic: topic for topic in set(self.topics_)} if isinstance(topics_to_merge[0], int): @@ -1993,17 +2241,22 @@ def merge_topics(self, for topic in topic_group: mapping[topic] = topic_group[0] else: - raise ValueError("Make sure that `topics_to_merge` is either" - "a list of topics or a list of list of topics.") + raise ValueError( + "Make sure that `topics_to_merge` is either" + "a list of topics or a list of list of topics." + ) # Track mappings and sizes of topics for merging topic embeddings mappings = defaultdict(list) for key, val in sorted(mapping.items()): mappings[val].append(key) - mappings = {topic_from: - {"topics_to": topics_to, - "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to]} - for topic_from, topics_to in mappings.items()} + mappings = { + topic_from: { + "topics_to": topics_to, + "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to], + } + for topic_from, topics_to in mappings.items() + } # Update topics documents.Topic = documents.Topic.map(mapping) @@ -2014,13 +2267,14 @@ def merge_topics(self, self._save_representative_docs(documents) self.probabilities_ = self._map_probabilities(self.probabilities_) - def reduce_topics(self, - docs: List[str], - nr_topics: Union[int, str] = 20, - images: List[str] = None, - use_ctfidf: bool = False, - ) -> None: - """ Reduce the number of topics to a fixed number of topics + def reduce_topics( + self, + docs: List[str], + nr_topics: Union[int, str] = 20, + images: List[str] = None, + use_ctfidf: bool = False, + ) -> None: + """Reduce the number of topics to a fixed number of topics or automatically. If nr_topics is an integer, then the number of topics is reduced @@ -2045,7 +2299,6 @@ def reduce_topics(self, probabilities_ : Assigns probabilities to their merged representations. 
Examples: - You can further reduce the topics by passing the documents with their topics and probabilities (if they were calculated): @@ -2064,7 +2317,14 @@ def reduce_topics(self, check_documents_type(docs) self.nr_topics = nr_topics - documents = pd.DataFrame({"Document": docs, "Topic": self.topics_, "Image": images, "ID": range(len(docs))}) + documents = pd.DataFrame( + { + "Document": docs, + "Topic": self.topics_, + "Image": images, + "ID": range(len(docs)), + } + ) # Reduce number of topics documents = self._reduce_topics(documents, use_ctfidf) @@ -2074,16 +2334,18 @@ def reduce_topics(self, return self - def reduce_outliers(self, - documents: List[str], - topics: List[int], - images: List[str] = None, - strategy: str = "distributions", - probabilities: np.ndarray = None, - threshold: float = 0, - embeddings: np.ndarray = None, - distributions_params: Mapping[str, Any] = {}) -> List[int]: - """ Reduce outliers by merging them with their nearest topic according + def reduce_outliers( + self, + documents: List[str], + topics: List[int], + images: List[str] = None, + strategy: str = "distributions", + probabilities: np.ndarray = None, + threshold: float = 0, + embeddings: np.ndarray = None, + distributions_params: Mapping[str, Any] = {}, + ) -> List[int]: + """Reduce outliers by merging them with their nearest topic according to one of several strategies. When using HDBSCAN, DBSCAN, or OPTICS, a number of outlier documents might be created @@ -2130,6 +2392,7 @@ def reduce_outliers(self, * "embeddings" Calculate the embeddings for outlier documents and find the best matching topic embedding. + probabilities: Probabilities generated by HDBSCAN for each document when using the strategy `"probabilities"`. threshold: The threshold for assigning topics to outlier documents. This value represents the minimum probability when `strategy="probabilities"`. For all other strategies, it represents the minimum similarity. 
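Reviewer note: as a quick check of the reformatted `reduce_outliers` signature above, a typical call reassigns the `-1` documents and then feeds the result back into the model. A minimal usage sketch, assuming a fitted `topic_model` and the original `docs`; the threshold value is illustrative:

```python
from bertopic import BERTopic

# Assumes `docs` is the list of documents the model was fitted on
topic_model = BERTopic()
topics, probs = topic_model.fit_transform(docs)

# Reassign outlier (-1) documents via the c-TF-IDF strategy with a similarity floor
new_topics = topic_model.reduce_outliers(docs, topics, strategy="c-tf-idf", threshold=0.1)

# Persist the new assignments so the topic representations are recomputed
topic_model.update_topics(docs, topics=new_topics)
```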
@@ -2165,20 +2428,30 @@ def reduce_outliers(self, # Check correct use of parameters if strategy.lower() == "probabilities" and probabilities is None: - raise ValueError("Make sure to pass in `probabilities` in order to use the probabilities strategy") + raise ValueError( + "Make sure to pass in `probabilities` in order to use the probabilities strategy" + ) # Reduce outliers by extracting most likely topics through the topic-term probability matrix if strategy.lower() == "probabilities": - new_topics = [np.argmax(prob) if np.max(prob) >= threshold and topic == -1 else topic - for topic, prob in zip(topics, probabilities)] + new_topics = [ + np.argmax(prob) if np.max(prob) >= threshold and topic == -1 else topic + for topic, prob in zip(topics, probabilities) + ] # Reduce outliers by extracting most frequent topics through calculating of Topic Distributions elif strategy.lower() == "distributions": outlier_ids = [index for index, topic in enumerate(topics) if topic == -1] outlier_docs = [documents[index] for index in outlier_ids] - topic_distr, _ = self.approximate_distribution(outlier_docs, min_similarity=threshold, **distributions_params) - outlier_topics = iter([np.argmax(prob) if sum(prob) > 0 else -1 for prob in topic_distr]) - new_topics = [topic if topic != -1 else next(outlier_topics) for topic in topics] + topic_distr, _ = self.approximate_distribution( + outlier_docs, min_similarity=threshold, **distributions_params + ) + outlier_topics = iter( + [np.argmax(prob) if sum(prob) > 0 else -1 for prob in topic_distr] + ) + new_topics = [ + topic if topic != -1 else next(outlier_topics) for topic in topics + ] # Reduce outliers by finding the most similar c-TF-IDF representations elif strategy.lower() == "c-tf-idf": @@ -2188,18 +2461,26 @@ def reduce_outliers(self, # Calculate c-TF-IDF of outlier documents with all topics bow_doc = self.vectorizer_model.transform(outlier_docs) c_tf_idf_doc = self.ctfidf_model.transform(bow_doc) - similarity = cosine_similarity(c_tf_idf_doc, self.c_tf_idf_[self._outliers:]) + similarity = cosine_similarity( + c_tf_idf_doc, self.c_tf_idf_[self._outliers :] + ) # Update topics similarity[similarity < threshold] = 0 - outlier_topics = iter([np.argmax(sim) if sum(sim) > 0 else -1 for sim in similarity]) - new_topics = [topic if topic != -1 else next(outlier_topics) for topic in topics] + outlier_topics = iter( + [np.argmax(sim) if sum(sim) > 0 else -1 for sim in similarity] + ) + new_topics = [ + topic if topic != -1 else next(outlier_topics) for topic in topics + ] # Reduce outliers by finding the most similar topic embeddings elif strategy.lower() == "embeddings": if self.embedding_model is None and embeddings is None: - raise ValueError("To use this strategy, you will need to pass a model to `embedding_model`" - "when instantiating BERTopic.") + raise ValueError( + "To use this strategy, you will need to pass a model to `embedding_model`" + "when instantiating BERTopic." 
+ ) outlier_ids = [index for index, topic in enumerate(topics) if topic == -1] if images is not None: outlier_docs = [images[index] for index in outlier_ids] @@ -2208,29 +2489,41 @@ def reduce_outliers(self, # Extract or calculate embeddings for outlier documents if embeddings is not None: - outlier_embeddings = np.array([embeddings[index] for index in outlier_ids]) + outlier_embeddings = np.array( + [embeddings[index] for index in outlier_ids] + ) elif images is not None: outlier_images = [images[index] for index in outlier_ids] - outlier_embeddings = self.embedding_model.embed_images(outlier_images, verbose=self.verbose) + outlier_embeddings = self.embedding_model.embed_images( + outlier_images, verbose=self.verbose + ) else: outlier_embeddings = self.embedding_model.embed_documents(outlier_docs) - similarity = cosine_similarity(outlier_embeddings, self.topic_embeddings_[self._outliers:]) + similarity = cosine_similarity( + outlier_embeddings, self.topic_embeddings_[self._outliers :] + ) # Update topics similarity[similarity < threshold] = 0 - outlier_topics = iter([np.argmax(sim) if sum(sim) > 0 else -1 for sim in similarity]) - new_topics = [topic if topic != -1 else next(outlier_topics) for topic in topics] + outlier_topics = iter( + [np.argmax(sim) if sum(sim) > 0 else -1 for sim in similarity] + ) + new_topics = [ + topic if topic != -1 else next(outlier_topics) for topic in topics + ] return new_topics - def visualize_topics(self, - topics: List[int] = None, - top_n_topics: int = None, - custom_labels: bool = False, - title: str = "Intertopic Distance Map", - width: int = 650, - height: int = 650) -> go.Figure: - """ Visualize topics, their sizes, and their corresponding words + def visualize_topics( + self, + topics: List[int] = None, + top_n_topics: int = None, + custom_labels: bool = False, + title: str = "Intertopic Distance Map", + width: int = 650, + height: int = 650, + ) -> go.Figure: + """Visualize topics, their sizes, and their corresponding words. This visualization is highly inspired by LDAvis, a great visualization technique typically reserved for LDA. @@ -2248,7 +2541,6 @@ def visualize_topics(self, height: The height of the figure. Examples: - To visualize the topics simply run: ```python @@ -2263,27 +2555,31 @@ def visualize_topics(self, ``` """ check_is_fitted(self) - return plotting.visualize_topics(self, - topics=topics, - top_n_topics=top_n_topics, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_documents(self, - docs: List[str], - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - sample: float = None, - hide_annotations: bool = False, - hide_document_hover: bool = False, - custom_labels: bool = False, - title: str = "Documents and Topics", - width: int = 1200, - height: int = 750) -> go.Figure: - """ Visualize documents and their topics in 2D + return plotting.visualize_topics( + self, + topics=topics, + top_n_topics=top_n_topics, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_documents( + self, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: float = None, + hide_annotations: bool = False, + hide_document_hover: bool = False, + custom_labels: bool = False, + title: str = "Documents and Topics", + width: int = 1200, + height: int = 750, + ) -> go.Figure: + """Visualize documents and their topics in 2D. 
Arguments: topic_model: A fitted BERTopic instance. @@ -2308,7 +2604,6 @@ def visualize_documents(self, height: The height of the figure. Examples: - To visualize the topics simply run: ```python @@ -2354,37 +2649,43 @@ def visualize_documents(self, """ check_is_fitted(self) check_documents_type(docs) - return plotting.visualize_documents(self, - docs=docs, - topics=topics, - embeddings=embeddings, - reduced_embeddings=reduced_embeddings, - sample=sample, - hide_annotations=hide_annotations, - hide_document_hover=hide_document_hover, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_document_datamap(self, - docs: List[str], - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - custom_labels: Union[bool, str] = False, - title: str = "Documents and Topics", - sub_title: Union[str, None] = None, - width: int = 1200, - height: int = 1200, - **datamap_kwds): - """ Visualize documents and their topics in 2D as a static plot for publication using + return plotting.visualize_documents( + self, + docs=docs, + topics=topics, + embeddings=embeddings, + reduced_embeddings=reduced_embeddings, + sample=sample, + hide_annotations=hide_annotations, + hide_document_hover=hide_document_hover, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_document_datamap( + self, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + custom_labels: Union[bool, str] = False, + title: str = "Documents and Topics", + sub_title: Union[str, None] = None, + width: int = 1200, + height: int = 1200, + **datamap_kwds, + ): + """Visualize documents and their topics in 2D as a static plot for publication using DataMapPlot. This works best if there are between 5 and 60 topics. It is therefore best to use a sufficiently large `min_topic_size` or set `nr_topics` when building the model. Arguments: topic_model: A fitted BERTopic instance. docs: The documents you used when calling either `fit` or `fit_transform` + topics: A selection of topics to visualize. + Not to be confused with the topics that you get from .fit_transform. For example, if you want to visualize only topics 1 through 5: topics = [1, 2, 3, 4, 5]. Documents not in these topics will be shown as noise points. embeddings: The embeddings of all documents in `docs`. reduced_embeddings: The 2D reduced embeddings of all documents in `docs`. custom_labels: If bool, whether to use custom topic labels that were defined using @@ -2402,7 +2703,6 @@ def visualize_document_datamap(self, figure: A Matplotlib Figure object. 
Examples: - To visualize the topics simply run: ```python @@ -2445,33 +2745,38 @@ def visualize_document_datamap(self, """ check_is_fitted(self) check_documents_type(docs) - return plotting.visualize_document_datamap(self, - docs, - topics, - embeddings, - reduced_embeddings, - custom_labels, - title, - sub_title, - width, - height, - **datamap_kwds) - def visualize_hierarchical_documents(self, - docs: List[str], - hierarchical_topics: pd.DataFrame, - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - sample: Union[float, int] = None, - hide_annotations: bool = False, - hide_document_hover: bool = True, - nr_levels: int = 10, - level_scale: str = 'linear', - custom_labels: bool = False, - title: str = "Hierarchical Documents and Topics", - width: int = 1200, - height: int = 750) -> go.Figure: - """ Visualize documents and their topics in 2D at different levels of hierarchy + return plotting.visualize_document_datamap( + self, + docs, + topics, + embeddings, + reduced_embeddings, + custom_labels, + title, + sub_title, + width, + height, + **datamap_kwds, + ) + + def visualize_hierarchical_documents( + self, + docs: List[str], + hierarchical_topics: pd.DataFrame, + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: Union[float, int] = None, + hide_annotations: bool = False, + hide_document_hover: bool = True, + nr_levels: int = 10, + level_scale: str = "linear", + custom_labels: bool = False, + title: str = "Hierarchical Documents and Topics", + width: int = 1200, + height: int = 750, + ) -> go.Figure: + """Visualize documents and their topics in 2D at different levels of hierarchy. Arguments: docs: The documents you used when calling either `fit` or `fit_transform` @@ -2492,7 +2797,7 @@ def visualize_hierarchical_documents(self, specific points. Helps to speed up generation of visualizations. nr_levels: The number of levels to be visualized in the hierarchy. First, the distances in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances with - equal length. Then, for each list of distances, the merged topics, that have + equal length. Then, for each list of distances, the merged topics, that have a distance less or equal to the maximum distance of the selected list of distances, are selected. NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to the length of `hierarchical_topics`. @@ -2510,7 +2815,6 @@ def visualize_hierarchical_documents(self, height: The height of the figure. 
Examples: - To visualize the topics simply run: ```python @@ -2557,30 +2861,34 @@ def visualize_hierarchical_documents(self, """ check_is_fitted(self) check_documents_type(docs) - return plotting.visualize_hierarchical_documents(self, - docs=docs, - hierarchical_topics=hierarchical_topics, - topics=topics, - embeddings=embeddings, - reduced_embeddings=reduced_embeddings, - sample=sample, - hide_annotations=hide_annotations, - hide_document_hover=hide_document_hover, - nr_levels=nr_levels, - level_scale=level_scale, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_term_rank(self, - topics: List[int] = None, - log_scale: bool = False, - custom_labels: bool = False, - title: str = "Term score decline per Topic", - width: int = 800, - height: int = 500) -> go.Figure: - """ Visualize the ranks of all terms across all topics + return plotting.visualize_hierarchical_documents( + self, + docs=docs, + hierarchical_topics=hierarchical_topics, + topics=topics, + embeddings=embeddings, + reduced_embeddings=reduced_embeddings, + sample=sample, + hide_annotations=hide_annotations, + hide_document_hover=hide_document_hover, + nr_levels=nr_levels, + level_scale=level_scale, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_term_rank( + self, + topics: List[int] = None, + log_scale: bool = False, + custom_labels: bool = False, + title: str = "Term score decline per Topic", + width: int = 800, + height: int = 500, + ) -> go.Figure: + """Visualize the ranks of all terms across all topics. Each topic is represented by a set of words. These words, however, do not all equally represent the topic. This visualization shows @@ -2601,7 +2909,6 @@ def visualize_term_rank(self, fig: A plotly figure Examples: - To visualize the ranks of all words across all topics simply run: @@ -2625,24 +2932,28 @@ def visualize_term_rank(self, [here](https://wzbsocialsciencecenter.github.io/tm_corona/tm_analysis.html). """ check_is_fitted(self) - return plotting.visualize_term_rank(self, - topics=topics, - log_scale=log_scale, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_topics_over_time(self, - topics_over_time: pd.DataFrame, - top_n_topics: int = None, - topics: List[int] = None, - normalize_frequency: bool = False, - custom_labels: bool = False, - title: str = "Topics over Time", - width: int = 1250, - height: int = 450) -> go.Figure: - """ Visualize topics over time + return plotting.visualize_term_rank( + self, + topics=topics, + log_scale=log_scale, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_topics_over_time( + self, + topics_over_time: pd.DataFrame, + top_n_topics: int = None, + topics: List[int] = None, + normalize_frequency: bool = False, + custom_labels: bool = False, + title: str = "Topics over Time", + width: int = 1250, + height: int = 450, + ) -> go.Figure: + """Visualize topics over time. 
Arguments: topics_over_time: The topics you would like to be visualized with the @@ -2660,7 +2971,6 @@ def visualize_topics_over_time(self, A plotly.graph_objects.Figure including all traces Examples: - To visualize the topics over time, simply run: ```python @@ -2676,26 +2986,30 @@ def visualize_topics_over_time(self, ``` """ check_is_fitted(self) - return plotting.visualize_topics_over_time(self, - topics_over_time=topics_over_time, - top_n_topics=top_n_topics, - topics=topics, - normalize_frequency=normalize_frequency, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_topics_per_class(self, - topics_per_class: pd.DataFrame, - top_n_topics: int = 10, - topics: List[int] = None, - normalize_frequency: bool = False, - custom_labels: bool = False, - title: str = "Topics per Class", - width: int = 1250, - height: int = 900) -> go.Figure: - """ Visualize topics per class + return plotting.visualize_topics_over_time( + self, + topics_over_time=topics_over_time, + top_n_topics=top_n_topics, + topics=topics, + normalize_frequency=normalize_frequency, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_topics_per_class( + self, + topics_per_class: pd.DataFrame, + top_n_topics: int = 10, + topics: List[int] = None, + normalize_frequency: bool = False, + custom_labels: bool = False, + title: str = "Topics per Class", + width: int = 1250, + height: int = 900, + ) -> go.Figure: + """Visualize topics per class. Arguments: topics_per_class: The topics you would like to be visualized with the @@ -2713,7 +3027,6 @@ def visualize_topics_per_class(self, A plotly.graph_objects.Figure including all traces Examples: - To visualize the topics per class, simply run: ```python @@ -2729,24 +3042,28 @@ def visualize_topics_per_class(self, ``` """ check_is_fitted(self) - return plotting.visualize_topics_per_class(self, - topics_per_class=topics_per_class, - top_n_topics=top_n_topics, - topics=topics, - normalize_frequency=normalize_frequency, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_distribution(self, - probabilities: np.ndarray, - min_probability: float = 0.015, - custom_labels: bool = False, - title: str = "Topic Probability Distribution", - width: int = 800, - height: int = 600) -> go.Figure: - """ Visualize the distribution of topic probabilities + return plotting.visualize_topics_per_class( + self, + topics_per_class=topics_per_class, + top_n_topics=top_n_topics, + topics=topics, + normalize_frequency=normalize_frequency, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_distribution( + self, + probabilities: np.ndarray, + min_probability: float = 0.015, + custom_labels: bool = False, + title: str = "Topic Probability Distribution", + width: int = 800, + height: int = 600, + ) -> go.Figure: + """Visualize the distribution of topic probabilities. Arguments: probabilities: An array of probability scores @@ -2759,7 +3076,6 @@ def visualize_distribution(self, height: The height of the figure. 
Examples: - Make sure to fit the model before and only input the probabilities of a single document: @@ -2775,19 +3091,23 @@ def visualize_distribution(self, ``` """ check_is_fitted(self) - return plotting.visualize_distribution(self, - probabilities=probabilities, - min_probability=min_probability, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_approximate_distribution(self, - document: str, - topic_token_distribution: np.ndarray, - normalize: bool = False): - """ Visualize the topic distribution calculated by `.approximate_topic_distribution` + return plotting.visualize_distribution( + self, + probabilities=probabilities, + min_probability=min_probability, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_approximate_distribution( + self, + document: str, + topic_token_distribution: np.ndarray, + normalize: bool = False, + ): + """Visualize the topic distribution calculated by `.approximate_topic_distribution` on a token level. Thereby indicating the extent to which a certain word or phrase belongs to a specific topic. The assumption here is that a single word can belong to multiple similar topics and as such can give information about the broader set of topics within @@ -2807,7 +3127,6 @@ def visualize_approximate_distribution(self, for each token. Examples: - ```python # Calculate the topic distributions on a token level # Note that we need to have `calculate_token_level=True` @@ -2829,25 +3148,29 @@ def visualize_approximate_distribution(self, ``` """ check_is_fitted(self) - return plotting.visualize_approximate_distribution(self, - document=document, - topic_token_distribution=topic_token_distribution, - normalize=normalize) - - def visualize_hierarchy(self, - orientation: str = "left", - topics: List[int] = None, - top_n_topics: int = None, - use_ctfidf: bool = True, - custom_labels: bool = False, - title: str = "Hierarchical Clustering", - width: int = 1000, - height: int = 600, - hierarchical_topics: pd.DataFrame = None, - linkage_function: Callable[[csr_matrix], np.ndarray] = None, - distance_function: Callable[[csr_matrix], csr_matrix] = None, - color_threshold: int = 1) -> go.Figure: - """ Visualize a hierarchical structure of the topics + return plotting.visualize_approximate_distribution( + self, + document=document, + topic_token_distribution=topic_token_distribution, + normalize=normalize, + ) + + def visualize_hierarchy( + self, + orientation: str = "left", + topics: List[int] = None, + top_n_topics: int = None, + use_ctfidf: bool = True, + custom_labels: bool = False, + title: str = "Hierarchical Clustering", + width: int = 1000, + height: int = 600, + hierarchical_topics: pd.DataFrame = None, + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None, + color_threshold: int = 1, + ) -> go.Figure: + """Visualize a hierarchical structure of the topics. 
A ward linkage function is used to perform the hierarchical clustering based on the cosine distance @@ -2888,7 +3211,6 @@ def visualize_hierarchy(self, fig: A plotly figure Examples: - To visualize the hierarchical structure of topics simply run: @@ -2917,31 +3239,34 @@ def visualize_hierarchy(self, style="width:1000px; height: 680px; border: 0px;""> """ check_is_fitted(self) - return plotting.visualize_hierarchy(self, - orientation=orientation, - topics=topics, - top_n_topics=top_n_topics, - use_ctfidf=use_ctfidf, - custom_labels=custom_labels, - title=title, - width=width, - height=height, - hierarchical_topics=hierarchical_topics, - linkage_function=linkage_function, - distance_function=distance_function, - color_threshold=color_threshold - ) - - def visualize_heatmap(self, - topics: List[int] = None, - top_n_topics: int = None, - n_clusters: int = None, - use_ctfidf: bool = False, - custom_labels: bool = False, - title: str = "Similarity Matrix", - width: int = 800, - height: int = 800) -> go.Figure: - """ Visualize a heatmap of the topic's similarity matrix + return plotting.visualize_hierarchy( + self, + orientation=orientation, + topics=topics, + top_n_topics=top_n_topics, + use_ctfidf=use_ctfidf, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + hierarchical_topics=hierarchical_topics, + linkage_function=linkage_function, + distance_function=distance_function, + color_threshold=color_threshold, + ) + + def visualize_heatmap( + self, + topics: List[int] = None, + top_n_topics: int = None, + n_clusters: int = None, + use_ctfidf: bool = False, + custom_labels: bool = False, + title: str = "Similarity Matrix", + width: int = 800, + height: int = 800, + ) -> go.Figure: + """Visualize a heatmap of the topic's similarity matrix. Based on the cosine similarity matrix between c-TF-IDFs or semantic embeddings of the topics, a heatmap is created showing the similarity between topics. @@ -2963,7 +3288,6 @@ def visualize_heatmap(self, fig: A plotly figure Examples: - To visualize the similarity matrix of topics simply run: @@ -2979,26 +3303,30 @@ def visualize_heatmap(self, ``` """ check_is_fitted(self) - return plotting.visualize_heatmap(self, - topics=topics, - top_n_topics=top_n_topics, - n_clusters=n_clusters, - use_ctfidf=use_ctfidf, - custom_labels=custom_labels, - title=title, - width=width, - height=height) - - def visualize_barchart(self, - topics: List[int] = None, - top_n_topics: int = 8, - n_words: int = 5, - custom_labels: bool = False, - title: str = "Topic Word Scores", - width: int = 250, - height: int = 250, - autoscale: bool=False) -> go.Figure: - """ Visualize a barchart of selected topics + return plotting.visualize_heatmap( + self, + topics=topics, + top_n_topics=top_n_topics, + n_clusters=n_clusters, + use_ctfidf=use_ctfidf, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + ) + + def visualize_barchart( + self, + topics: List[int] = None, + top_n_topics: int = 8, + n_words: int = 5, + custom_labels: bool = False, + title: str = "Topic Word Scores", + width: int = 250, + height: int = 250, + autoscale: bool = False, + ) -> go.Figure: + """Visualize a barchart of selected topics. Arguments: topics: A selection of topics to visualize. 
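Reviewer note: the visualization wrappers being reformatted here all return `plotly.graph_objects.Figure` objects, so they can be exported the same way. A short sketch, assuming a fitted `topic_model`; `write_html` is the standard Plotly export method:

```python
# Each visualize_* wrapper returns a Plotly figure that can be saved as-is
fig_hierarchy = topic_model.visualize_hierarchy(top_n_topics=30)
fig_heatmap = topic_model.visualize_heatmap(n_clusters=10)

fig_hierarchy.write_html("hierarchy.html")
fig_heatmap.write_html("heatmap.html")
```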
@@ -3015,7 +3343,6 @@ def visualize_barchart(self, fig: A plotly figure Examples: - To visualize the barchart of selected topics simply run: @@ -3031,22 +3358,26 @@ def visualize_barchart(self, ``` """ check_is_fitted(self) - return plotting.visualize_barchart(self, - topics=topics, - top_n_topics=top_n_topics, - n_words=n_words, - custom_labels=custom_labels, - title=title, - width=width, - height=height, - autoscale=autoscale) - - def save(self, - path, - serialization: Literal["safetensors", "pickle", "pytorch"] = "pickle", - save_embedding_model: Union[bool, str] = True, - save_ctfidf: bool = False): - """ Saves the model to the specified path or folder + return plotting.visualize_barchart( + self, + topics=topics, + top_n_topics=top_n_topics, + n_words=n_words, + custom_labels=custom_labels, + title=title, + width=width, + height=height, + autoscale=autoscale, + ) + + def save( + self, + path, + serialization: Literal["safetensors", "pickle", "pytorch"] = "pickle", + save_embedding_model: Union[bool, str] = True, + save_ctfidf: bool = False, + ): + """Saves the model to the specified path or folder. When saving the model, make sure to also keep track of the versions of dependencies and Python used. Loading and saving the model should @@ -3068,7 +3399,6 @@ def save(self, or `pytorch` Examples: - To save the model in an efficient and safe format (safetensors) with c-TF-IDF information: ```python @@ -3093,13 +3423,14 @@ def save(self, safetensors. """ if serialization == "pickle": - logger.warning("When you use `pickle` to save/load a BERTopic model," - "please make sure that the environments in which you save" - "and load the model are **exactly** the same. The version of BERTopic," - "its dependencies, and python need to remain the same.") - - with open(path, 'wb') as file: + logger.warning( + "When you use `pickle` to save/load a BERTopic model," + "please make sure that the environments in which you save" + "and load the model are **exactly** the same. The version of BERTopic," + "its dependencies, and python need to remain the same." + ) + with open(path, "wb") as file: # This prevents the vectorizer from being too large in size if `min_df` was # set to a value higher than 1 self.vectorizer_model.stop_words_ = None @@ -3112,36 +3443,51 @@ def save(self, else: joblib.dump(self, file) elif serialization == "safetensors" or serialization == "pytorch": - # Directory save_directory = Path(path) save_directory.mkdir(exist_ok=True, parents=True) # Check embedding model - if save_embedding_model and hasattr(self.embedding_model, '_hf_model') and not isinstance(save_embedding_model, str): + if ( + save_embedding_model + and hasattr(self.embedding_model, "_hf_model") + and not isinstance(save_embedding_model, str) + ): save_embedding_model = self.embedding_model._hf_model elif not save_embedding_model: - logger.warning("You are saving a BERTopic model without explicitly defining an embedding model." - "If you are using a sentence-transformers model or a HuggingFace model supported" - "by sentence-transformers, please save the model by using a pointer towards that model." - "For example, `save_embedding_model='sentence-transformers/all-mpnet-base-v2'`") + logger.warning( + "You are saving a BERTopic model without explicitly defining an embedding model." + "If you are using a sentence-transformers model or a HuggingFace model supported" + "by sentence-transformers, please save the model by using a pointer towards that model." 
+ "For example, `save_embedding_model='sentence-transformers/all-mpnet-base-v2'`" + ) # Minimal - save_utils.save_hf(model=self, save_directory=save_directory, serialization=serialization) + save_utils.save_hf( + model=self, save_directory=save_directory, serialization=serialization + ) save_utils.save_topics(model=self, path=save_directory / "topics.json") save_utils.save_images(model=self, path=save_directory / "images") - save_utils.save_config(model=self, path=save_directory / 'config.json', embedding_model=save_embedding_model) + save_utils.save_config( + model=self, + path=save_directory / "config.json", + embedding_model=save_embedding_model, + ) # Additional if save_ctfidf: - save_utils.save_ctfidf(model=self, save_directory=save_directory, serialization=serialization) - save_utils.save_ctfidf_config(model=self, path=save_directory / 'ctfidf_config.json') + save_utils.save_ctfidf( + model=self, + save_directory=save_directory, + serialization=serialization, + ) + save_utils.save_ctfidf_config( + model=self, path=save_directory / "ctfidf_config.json" + ) @classmethod - def load(cls, - path: str, - embedding_model=None): - """ Loads the model from the specified path or directory + def load(cls, path: str, embedding_model=None): + """Loads the model from the specified path or directory. Arguments: path: Either load a BERTopic model from a file (`.pickle`) or a folder containing @@ -3150,7 +3496,6 @@ def load(cls, in the BERTopic model file or directory. Examples: - ```python BERTopic.load("model_dir") ``` @@ -3165,33 +3510,48 @@ def load(cls, # Load from Pickle if file_or_dir.is_file(): - with open(file_or_dir, 'rb') as file: + with open(file_or_dir, "rb") as file: if embedding_model: topic_model = joblib.load(file) - topic_model.embedding_model = select_backend(embedding_model, verbose=topic_model.verbose) + topic_model.embedding_model = select_backend( + embedding_model, verbose=topic_model.verbose + ) else: topic_model = joblib.load(file) return topic_model # Load from directory or HF if file_or_dir.is_dir(): - topics, params, tensors, ctfidf_tensors, ctfidf_config, images = save_utils.load_local_files(file_or_dir) + topics, params, tensors, ctfidf_tensors, ctfidf_config, images = ( + save_utils.load_local_files(file_or_dir) + ) elif "/" in str(path): - topics, params, tensors, ctfidf_tensors, ctfidf_config, images = save_utils.load_files_from_hf(path) + topics, params, tensors, ctfidf_tensors, ctfidf_config, images = ( + save_utils.load_files_from_hf(path) + ) else: raise ValueError("Make sure to either pass a valid directory or HF model.") - topic_model = _create_model_from_files(topics, params, tensors, ctfidf_tensors, ctfidf_config, images, - warn_no_backend=(embedding_model is None)) + topic_model = _create_model_from_files( + topics, + params, + tensors, + ctfidf_tensors, + ctfidf_config, + images, + warn_no_backend=(embedding_model is None), + ) # Replace embedding model if one is specifically chosen if embedding_model is not None: - topic_model.embedding_model = select_backend(embedding_model, verbose=topic_model.verbose) + topic_model.embedding_model = select_backend( + embedding_model, verbose=topic_model.verbose + ) return topic_model @classmethod - def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): - """ Merge multiple pre-trained BERTopic models into a single model. + def merge_models(cls, models, min_similarity: float = 0.7, embedding_model=None): + """Merge multiple pre-trained BERTopic models into a single model. 
The models are merged as if they were all saved using pytorch or safetensors, so a minimal version without c-TF-IDF. @@ -3218,7 +3578,6 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): loading a model from the HuggingFace Hub without c-TF-IDF Examples: - ```python from bertopic import BERTopic from sklearn.datasets import fetch_20newsgroups @@ -3238,12 +3597,13 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): # Temporarily save model and push to HF with TemporaryDirectory() as tmpdir: - # Save model weights and config. all_topics, all_params, all_tensors = [], [], [] for index, model in enumerate(models): model.save(tmpdir, serialization="pytorch") - topics, params, tensors, _, _, _ = save_utils.load_local_files(Path(tmpdir)) + topics, params, tensors, _, _, _ = save_utils.load_local_files( + Path(tmpdir) + ) all_topics.append(topics) all_params.append(params) all_tensors.append(np.array(tensors["topic_embeddings"])) @@ -3261,7 +3621,13 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): sims = np.max(sim_matrix, axis=1) # Extract new topics - new_topics = sorted([index - selected_topics["_outliers"] for index, sim in enumerate(sims) if sim < min_similarity]) + new_topics = sorted( + [ + index - selected_topics["_outliers"] + for index, sim in enumerate(sims) + if sim < min_similarity + ] + ) max_topic = max(set(merged_topics["topics"])) # Merge Topic Representations @@ -3270,8 +3636,12 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): if new_topic != -1: max_topic += 1 new_topics_dict[new_topic] = max_topic - merged_topics["topic_representations"][str(max_topic)] = selected_topics["topic_representations"][str(new_topic)] - merged_topics["topic_labels"][str(max_topic)] = selected_topics["topic_labels"][str(new_topic)] + merged_topics["topic_representations"][str(max_topic)] = ( + selected_topics["topic_representations"][str(new_topic)] + ) + merged_topics["topic_labels"][str(max_topic)] = selected_topics[ + "topic_labels" + ][str(new_topic)] # Add new aspects if selected_topics["topic_aspects"]: @@ -3284,26 +3654,34 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): # If the original model does not have topic aspects but the to be added model does if not merged_topics.get("topic_aspects"): - merged_topics["topic_aspects"] = selected_topics["topic_aspects"] + merged_topics["topic_aspects"] = selected_topics[ + "topic_aspects" + ] # If they both contain topic aspects, add to the existing set of aspects else: - for aspect, values in selected_topics["topic_aspects"].items(): - merged_topics["topic_aspects"][aspect][str(max_topic)] = values[str(new_topic)] + for aspect, values in selected_topics[ + "topic_aspects" + ].items(): + merged_topics["topic_aspects"][aspect][ + str(max_topic) + ] = values[str(new_topic)] # Add new embeddings new_tensors = tensors[new_topic + selected_topics["_outliers"]] merged_tensors = np.vstack([merged_tensors, new_tensors]) # Topic Mapper - merged_topics["topic_mapper"] = TopicMapper(list(range(-1, max_topic+1, 1))).mappings_ + merged_topics["topic_mapper"] = TopicMapper( + list(range(-1, max_topic + 1, 1)) + ).mappings_ # Find similar topics and re-assign those from the new models sims_idx = np.argmax(sim_matrix, axis=1) sims = np.max(sim_matrix, axis=1) to_merge = { - a - selected_topics["_outliers"]: - b - merged_topics["_outliers"] for a, (b, val) in enumerate(zip(sims_idx, sims)) + a - 
selected_topics["_outliers"]: b - merged_topics["_outliers"] + for a, (b, val) in enumerate(zip(sims_idx, sims)) if val >= min_similarity } to_merge.update(new_topics_dict) @@ -3314,29 +3692,42 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None): # Create a new model from the merged parameters merged_tensors = {"topic_embeddings": torch.from_numpy(merged_tensors)} - merged_model = _create_model_from_files(merged_topics, merged_params, merged_tensors, None, None, None, warn_no_backend=False) + merged_model = _create_model_from_files( + merged_topics, + merged_params, + merged_tensors, + None, + None, + None, + warn_no_backend=False, + ) merged_model.embedding_model = models[0].embedding_model # Replace embedding model if one is specifically chosen verbose = any([model.verbose for model in models]) - if embedding_model is not None and type(merged_model.embedding_model) == BaseEmbedder: - merged_model.embedding_model = select_backend(embedding_model, verbose=verbose) + if ( + embedding_model is not None + and type(merged_model.embedding_model) == BaseEmbedder + ): + merged_model.embedding_model = select_backend( + embedding_model, verbose=verbose + ) return merged_model def push_to_hf_hub( - self, - repo_id: str, - commit_message: str = 'Add BERTopic model', - token: str = None, - revision: str = None, - private: bool = False, - create_pr: bool = False, - model_card: bool = True, - serialization: str = "safetensors", - save_embedding_model: Union[str, bool] = True, - save_ctfidf: bool = False, - ): - """ Push your BERTopic model to a HuggingFace Hub + self, + repo_id: str, + commit_message: str = "Add BERTopic model", + token: str = None, + revision: str = None, + private: bool = False, + create_pr: bool = False, + model_card: bool = True, + serialization: str = "safetensors", + save_embedding_model: Union[str, bool] = True, + save_ctfidf: bool = False, + ): + """Push your BERTopic model to a HuggingFace Hub. Whenever you want to upload files to the Hub, you need to log in to your HuggingFace account: @@ -3371,7 +3762,6 @@ def push_to_hf_hub( Examples: - ```python topic_model.push_to_hf_hub( repo_id="ArXiv", @@ -3380,13 +3770,22 @@ def push_to_hf_hub( ) ``` """ - return save_utils.push_to_hf_hub(model=self, repo_id=repo_id, commit_message=commit_message, - token=token, revision=revision, private=private, create_pr=create_pr, - model_card=model_card, serialization=serialization, - save_embedding_model=save_embedding_model, save_ctfidf=save_ctfidf) + return save_utils.push_to_hf_hub( + model=self, + repo_id=repo_id, + commit_message=commit_message, + token=token, + revision=revision, + private=private, + create_pr=create_pr, + model_card=model_card, + serialization=serialization, + save_embedding_model=save_embedding_model, + save_ctfidf=save_ctfidf, + ) def get_params(self, deep: bool = False) -> Mapping[str, Any]: - """ Get parameters for this estimator. + """Get parameters for this estimator. 
Adapted from: https://github.com/scikit-learn/scikit-learn/blob/b3ea3ed6a/sklearn/base.py#L178 @@ -3402,19 +3801,21 @@ def get_params(self, deep: bool = False) -> Mapping[str, Any]: out = dict() for key in self._get_param_names(): value = getattr(self, key) - if deep and hasattr(value, 'get_params'): + if deep and hasattr(value, "get_params"): deep_items = value.get_params().items() - out.update((key + '__' + k, val) for k, val in deep_items) + out.update((key + "__" + k, val) for k, val in deep_items) out[key] = value return out - def _extract_embeddings(self, - documents: Union[List[str], str], - images: List[str] = None, - method: str = "document", - verbose: bool = None) -> np.ndarray: - """ Extract sentence/document embeddings through pre-trained embeddings - For an overview of pre-trained models: https://www.sbert.net/docs/pretrained_models.html + def _extract_embeddings( + self, + documents: Union[List[str], str], + images: List[str] = None, + method: str = "document", + verbose: bool = None, + ) -> np.ndarray: + """Extract sentence/document embeddings through pre-trained embeddings + For an overview of pre-trained models: https://www.sbert.net/docs/pretrained_models.html. Arguments: documents: Dataframe with documents and their corresponding IDs @@ -3429,50 +3830,66 @@ def _extract_embeddings(self, documents = [documents] if images is not None and hasattr(self.embedding_model, "embed_images"): - embeddings = self.embedding_model.embed(documents=documents, images=images, verbose=verbose) + embeddings = self.embedding_model.embed( + documents=documents, images=images, verbose=verbose + ) elif method == "word": - embeddings = self.embedding_model.embed_words(words=documents, verbose=verbose) + embeddings = self.embedding_model.embed_words( + words=documents, verbose=verbose + ) elif method == "document": - embeddings = self.embedding_model.embed_documents(documents, verbose=verbose) + embeddings = self.embedding_model.embed_documents( + documents, verbose=verbose + ) elif documents[0] is None and images is None: - raise ValueError("Make sure to use an embedding model that can either embed documents" - "or images depending on which you want to embed.") + raise ValueError( + "Make sure to use an embedding model that can either embed documents" + "or images depending on which you want to embed." + ) else: - raise ValueError("Wrong method for extracting document/word embeddings. " - "Either choose 'word' or 'document' as the method. ") + raise ValueError( + "Wrong method for extracting document/word embeddings. " + "Either choose 'word' or 'document' as the method. " + ) return embeddings - def _images_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame: - """ Convert images to text """ + def _images_to_text( + self, documents: pd.DataFrame, embeddings: np.ndarray + ) -> pd.DataFrame: + """Convert images to text.""" logger.info("Images - Converting images to text. 
This might take a while.") if isinstance(self.representation_model, dict): for tuner in self.representation_model.values(): - if getattr(tuner, 'image_to_text_model', False): + if getattr(tuner, "image_to_text_model", False): documents = tuner.image_to_text(documents, embeddings) elif isinstance(self.representation_model, list): for tuner in self.representation_model: - if getattr(tuner, 'image_to_text_model', False): + if getattr(tuner, "image_to_text_model", False): documents = tuner.image_to_text(documents, embeddings) elif isinstance(self.representation_model, BaseRepresentation): - if getattr(self.representation_model, 'image_to_text_model', False): - documents = self.representation_model.image_to_text(documents, embeddings) + if getattr(self.representation_model, "image_to_text_model", False): + documents = self.representation_model.image_to_text( + documents, embeddings + ) logger.info("Images - Completed \u2713") return documents def _map_predictions(self, predictions: List[int]) -> List[int]: - """ Map predictions to the correct topics if topics were reduced """ + """Map predictions to the correct topics if topics were reduced.""" mappings = self.topic_mapper_.get_mappings(original_topics=True) - mapped_predictions = [mappings[prediction] - if prediction in mappings - else -1 - for prediction in predictions] + mapped_predictions = [ + mappings[prediction] if prediction in mappings else -1 + for prediction in predictions + ] return mapped_predictions - def _reduce_dimensionality(self, - embeddings: Union[np.ndarray, csr_matrix], - y: Union[List[int], np.ndarray] = None, - partial_fit: bool = False) -> np.ndarray: - """ Reduce dimensionality of embeddings using UMAP and train a UMAP model + def _reduce_dimensionality( + self, + embeddings: Union[np.ndarray, csr_matrix], + y: Union[List[int], np.ndarray] = None, + partial_fit: bool = False, + ) -> np.ndarray: + """Reduce dimensionality of embeddings using UMAP and train a UMAP model. Arguments: embeddings: The extracted embeddings using the sentence transformer module. @@ -3497,25 +3914,26 @@ def _reduce_dimensionality(self, y = np.array(y) if y is not None else None self.umap_model.fit(embeddings, y=y) except TypeError: - self.umap_model.fit(embeddings) umap_embeddings = self.umap_model.transform(embeddings) logger.info("Dimensionality - Completed \u2713") return np.nan_to_num(umap_embeddings) - def _cluster_embeddings(self, - umap_embeddings: np.ndarray, - documents: pd.DataFrame, - partial_fit: bool = False, - y: np.ndarray = None) -> Tuple[pd.DataFrame, - np.ndarray]: - """ Cluster UMAP embeddings with HDBSCAN + def _cluster_embeddings( + self, + umap_embeddings: np.ndarray, + documents: pd.DataFrame, + partial_fit: bool = False, + y: np.ndarray = None, + ) -> Tuple[pd.DataFrame, np.ndarray]: + """Cluster UMAP embeddings with HDBSCAN. 
Arguments: umap_embeddings: The reduced sentence embeddings with UMAP documents: Dataframe with documents and their corresponding IDs partial_fit: Whether to run `partial_fit` for online learning + y: Array of topics to use Returns: documents: Updated dataframe with documents and their corresponding IDs @@ -3526,7 +3944,7 @@ def _cluster_embeddings(self, if partial_fit: self.hdbscan_model = self.hdbscan_model.partial_fit(umap_embeddings) labels = self.hdbscan_model.labels_ - documents['Topic'] = labels + documents["Topic"] = labels self.topics_ = labels else: try: @@ -3538,7 +3956,7 @@ def _cluster_embeddings(self, labels = self.hdbscan_model.labels_ except AttributeError: labels = y - documents['Topic'] = labels + documents["Topic"] = labels self._update_topic_size(documents) # Some algorithms have outlier labels (-1) that can be tricky to work @@ -3551,17 +3969,22 @@ def _cluster_embeddings(self, if hasattr(self.hdbscan_model, "probabilities_"): probabilities = self.hdbscan_model.probabilities_ - if self.calculate_probabilities and is_supported_hdbscan(self.hdbscan_model): - probabilities = hdbscan_delegator(self.hdbscan_model, "all_points_membership_vectors") + if self.calculate_probabilities and is_supported_hdbscan( + self.hdbscan_model + ): + probabilities = hdbscan_delegator( + self.hdbscan_model, "all_points_membership_vectors" + ) if not partial_fit: self.topic_mapper_ = TopicMapper(self.topics_) logger.info("Cluster - Completed \u2713") return documents, probabilities - def _zeroshot_topic_modeling(self, documents: pd.DataFrame, embeddings: np.ndarray) -> Tuple[pd.DataFrame, np.array, - pd.DataFrame, np.array]: - """ Find documents that could be assigned to either one of the topics in self.zeroshot_topic_list + def _zeroshot_topic_modeling( + self, documents: pd.DataFrame, embeddings: np.ndarray + ) -> Tuple[pd.DataFrame, np.array, pd.DataFrame, np.array]: + """Find documents that could be assigned to either one of the topics in self.zeroshot_topic_list. We transform the topics in `self.zeroshot_topic_list` to embeddings and compare them through cosine similarity with the document embeddings. 
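Reviewer note: the zero-shot assignment described in the docstring above reduces to a thresholded argmax over a document-by-topic cosine-similarity matrix, which is what the reformatted code in the next hunk does. A self-contained sketch of that split; the toy arrays and threshold are illustrative only:

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy embeddings: 4 documents and 2 zero-shot topics (values illustrative)
doc_embeddings = np.random.RandomState(0).rand(4, 384)
zeroshot_embeddings = np.random.RandomState(1).rand(2, 384)
zeroshot_min_similarity = 0.7

sims = cosine_similarity(doc_embeddings, zeroshot_embeddings)
assignment = np.argmax(sims, 1)    # best zero-shot topic per document
assignment_vals = np.max(sims, 1)  # similarity to that best topic

# Documents above the threshold keep the zero-shot topic; the rest are clustered as usual
assigned_ids = [i for i, v in enumerate(assignment_vals) if v >= zeroshot_min_similarity]
non_assigned_ids = [i for i, v in enumerate(assignment_vals) if v < zeroshot_min_similarity]
```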
@@ -3575,14 +3998,24 @@ def _zeroshot_topic_modeling(self, documents: pd.DataFrame, embeddings: np.ndarr documents: The leftover documents that were not assigned to any topic embeddings: The leftover embeddings that were not assigned to any topic """ - logger.info("Zeroshot Step 1 - Finding documents that could be assigned to either one of the zero-shot topics") + logger.info( + "Zeroshot Step 1 - Finding documents that could be assigned to either one of the zero-shot topics" + ) # Similarity between document and zero-shot topic embeddings zeroshot_embeddings = self._extract_embeddings(self.zeroshot_topic_list) cosine_similarities = cosine_similarity(embeddings, zeroshot_embeddings) assignment = np.argmax(cosine_similarities, 1) assignment_vals = np.max(cosine_similarities, 1) - assigned_ids = [index for index, value in enumerate(assignment_vals) if value >= self.zeroshot_min_similarity] - non_assigned_ids = [index for index, value in enumerate(assignment_vals) if value < self.zeroshot_min_similarity] + assigned_ids = [ + index + for index, value in enumerate(assignment_vals) + if value >= self.zeroshot_min_similarity + ] + non_assigned_ids = [ + index + for index, value in enumerate(assignment_vals) + if value < self.zeroshot_min_similarity + ] # Assign topics assigned_documents = documents.iloc[assigned_ids] @@ -3604,21 +4037,27 @@ def _zeroshot_topic_modeling(self, documents: pd.DataFrame, embeddings: np.ndarr return documents, embeddings, assigned_documents, assigned_embeddings def _is_zeroshot(self): - """ Check whether zero-shot topic modeling is possible + """Check whether zero-shot topic modeling is possible. * There should be a cluster model used * Embedding model is necessary to convert zero-shot topics to embeddings * Zero-shot topics should be defined """ - if self.zeroshot_topic_list is not None and self.embedding_model is not None and type(self.hdbscan_model) != BaseCluster: + if ( + self.zeroshot_topic_list is not None + and self.embedding_model is not None + and type(self.hdbscan_model) != BaseCluster + ): return True return False - def _combine_zeroshot_topics(self, - documents: pd.DataFrame, - assigned_documents: pd.DataFrame, - embeddings: np.ndarray) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: - """ Combine the zero-shot topics with the clustered topics + def _combine_zeroshot_topics( + self, + documents: pd.DataFrame, + assigned_documents: pd.DataFrame, + embeddings: np.ndarray, + ) -> Union[Tuple[np.ndarray, np.ndarray], np.ndarray]: + """Combine the zero-shot topics with the clustered topics. There are three cases considered: * Only zero-shot topics were found which will only return the zero-shot topic model @@ -3637,7 +4076,9 @@ def _combine_zeroshot_topics(self, topics: The topics for each document probabilities: The probabilities for each document """ - logger.info("Zeroshot Step 2 - Clustering documents that were not found in the zero-shot model...") + logger.info( + "Zeroshot Step 2 - Clustering documents that were not found in the zero-shot model..." 
+ ) # Fit BERTopic without actually performing any clustering docs = assigned_documents.Document.tolist() @@ -3645,19 +4086,21 @@ def _combine_zeroshot_topics(self, empty_dimensionality_model = BaseDimensionalityReduction() empty_cluster_model = BaseCluster() zeroshot_model = BERTopic( - n_gram_range=self.n_gram_range, - low_memory=self.low_memory, - calculate_probabilities=self.calculate_probabilities, - embedding_model=self.embedding_model, - umap_model=empty_dimensionality_model, - hdbscan_model=empty_cluster_model, - vectorizer_model=self.vectorizer_model, - ctfidf_model=self.ctfidf_model, - representation_model=self.representation_model, - verbose=self.verbose + n_gram_range=self.n_gram_range, + low_memory=self.low_memory, + calculate_probabilities=self.calculate_probabilities, + embedding_model=self.embedding_model, + umap_model=empty_dimensionality_model, + hdbscan_model=empty_cluster_model, + vectorizer_model=self.vectorizer_model, + ctfidf_model=self.ctfidf_model, + representation_model=self.representation_model, + verbose=self.verbose, ).fit(docs, embeddings=embeddings, y=y) logger.info("Zeroshot Step 2 - Completed \u2713") - logger.info("Zeroshot Step 3 - Combining clustered topics with the zeroshot model") + logger.info( + "Zeroshot Step 3 - Combining clustered topics with the zeroshot model" + ) # Update model self.umap_model = BaseDimensionalityReduction() @@ -3666,15 +4109,23 @@ def _combine_zeroshot_topics(self, # Update topic label assigned_topics = assigned_documents.groupby("Topic").first().reset_index() indices, topics = assigned_topics.ID.values, assigned_topics.Topic.values - labels = [zeroshot_model.topic_labels_[zeroshot_model.topics_[index]] for index in indices] - labels = {label: self.zeroshot_topic_list[topic] for label, topic in zip(labels, topics)} + labels = [ + zeroshot_model.topic_labels_[zeroshot_model.topics_[index]] + for index in indices + ] + labels = { + label: self.zeroshot_topic_list[topic] + for label, topic in zip(labels, topics) + } # If only zero-shot matches were found and clustering was not performed if documents is None: for topic in range(len(set(y))): if zeroshot_model.topic_labels_.get(topic): if labels.get(zeroshot_model.topic_labels_[topic]): - zeroshot_model.topic_labels_[topic] = labels[zeroshot_model.topic_labels_[topic]] + zeroshot_model.topic_labels_[topic] = labels[ + zeroshot_model.topic_labels_[topic] + ] self.__dict__.clear() self.__dict__.update(zeroshot_model.__dict__) return self.topics_, self.probabilities_ @@ -3688,11 +4139,15 @@ def _combine_zeroshot_topics(self, if labels.get(merged_model.topic_labels_[topic]): label = labels[merged_model.topic_labels_[topic]] merged_model.topic_labels_[topic] = label - merged_model.representative_docs_[topic] = zeroshot_model.representative_docs_[topic] + merged_model.representative_docs_[topic] = ( + zeroshot_model.representative_docs_[topic] + ) # Add representative docs of the clustered model for topic in set(self.topics_): - merged_model.representative_docs_[topic + self._outliers + len(set(y))] = self.representative_docs_[topic] + merged_model.representative_docs_[topic + self._outliers + len(set(y))] = ( + self.representative_docs_[topic] + ) if self._outliers and merged_model.topic_sizes_.get(-1): merged_model.topic_sizes_[len(set(y))] = merged_model.topic_sizes_[-1] @@ -3701,20 +4156,29 @@ def _combine_zeroshot_topics(self, # Update topic assignment by finding the documents with the # correct updated topics zeroshot_indices = list(assigned_documents.Old_ID.values) - 
zeroshot_topics = [self.zeroshot_topic_list[topic] for topic in assigned_documents.Topic.values] + zeroshot_topics = [ + self.zeroshot_topic_list[topic] for topic in assigned_documents.Topic.values + ] cluster_indices = list(documents.Old_ID.values) - cluster_names = list(merged_model.topic_labels_.values())[len(set(y)):] + cluster_names = list(merged_model.topic_labels_.values())[len(set(y)) :] if self._outliers: - cluster_topics = [cluster_names[topic] if topic != -1 else "Outliers" for topic in documents.Topic.values] + cluster_topics = [ + cluster_names[topic] if topic != -1 else "Outliers" + for topic in documents.Topic.values + ] else: cluster_topics = [cluster_names[topic] for topic in documents.Topic.values] - df = pd.DataFrame({ - "Indices": zeroshot_indices + cluster_indices, - "Label": zeroshot_topics + cluster_topics} + df = pd.DataFrame( + { + "Indices": zeroshot_indices + cluster_indices, + "Label": zeroshot_topics + cluster_topics, + } ).sort_values("Indices") - reverse_topic_labels = dict((v, k) for k, v in merged_model.topic_labels_.items()) + reverse_topic_labels = dict( + (v, k) for k, v in merged_model.topic_labels_.items() + ) if self._outliers: reverse_topic_labels["Outliers"] = -1 df.Label = df.Label.map(reverse_topic_labels) @@ -3742,20 +4206,30 @@ def _combine_zeroshot_topics(self, # Re-map the topics including all representations (labels, sizes, embeddings, etc.) self.topics_ = [new_mappings[topic] for topic in self.topics_] - self.topic_representations_ = {new_mappings[topic]: repr for topic, repr in self.topic_representations_.items()} - self.topic_labels_ = {new_mappings[topic]: label for topic, label in self.topic_labels_.items()} + self.topic_representations_ = { + new_mappings[topic]: repr + for topic, repr in self.topic_representations_.items() + } + self.topic_labels_ = { + new_mappings[topic]: label + for topic, label in self.topic_labels_.items() + } self.topic_sizes_ = collections.Counter(self.topics_) - self.topic_embeddings_ = np.vstack([ - self.topic_embeddings_[nr_zeroshot_topics], - self.topic_embeddings_[:nr_zeroshot_topics], - self.topic_embeddings_[nr_zeroshot_topics+1:] - ]) + self.topic_embeddings_ = np.vstack( + [ + self.topic_embeddings_[nr_zeroshot_topics], + self.topic_embeddings_[:nr_zeroshot_topics], + self.topic_embeddings_[nr_zeroshot_topics + 1 :], + ] + ) self._outliers = 1 return self.topics_ - def _guided_topic_modeling(self, embeddings: np.ndarray) -> Tuple[List[int], np.array]: - """ Apply Guided Topic Modeling + def _guided_topic_modeling( + self, embeddings: np.ndarray + ) -> Tuple[List[int], np.array]: + """Apply Guided Topic Modeling. We transform the seeded topics to embeddings using the same embedder as used for generating document embeddings. @@ -3771,15 +4245,19 @@ def _guided_topic_modeling(self, embeddings: np.ndarray) -> Tuple[List[int], np. 
Arguments: embeddings: The document embeddings - Returns + Returns: y: The labels for each seeded topic embeddings: Updated embeddings """ logger.info("Guided - Find embeddings highly related to seeded topics.") # Create embeddings from the seeded topics seed_topic_list = [" ".join(seed_topic) for seed_topic in self.seed_topic_list] - seed_topic_embeddings = self._extract_embeddings(seed_topic_list, verbose=self.verbose) - seed_topic_embeddings = np.vstack([seed_topic_embeddings, embeddings.mean(axis=0)]) + seed_topic_embeddings = self._extract_embeddings( + seed_topic_list, verbose=self.verbose + ) + seed_topic_embeddings = np.vstack( + [seed_topic_embeddings, embeddings.mean(axis=0)] + ) # Label documents that are most similar to one of the seeded topics sim_matrix = cosine_similarity(embeddings, seed_topic_embeddings) @@ -3790,12 +4268,20 @@ def _guided_topic_modeling(self, embeddings: np.ndarray) -> Tuple[List[int], np. # embedding of the seeded topic to force the documents in a cluster for seed_topic in range(len(seed_topic_list)): indices = [index for index, topic in enumerate(y) if topic == seed_topic] - embeddings[indices] = np.average([embeddings[indices], seed_topic_embeddings[seed_topic]], weights=[3, 1]) + embeddings[indices] = np.average( + [embeddings[indices], seed_topic_embeddings[seed_topic]], weights=[3, 1] + ) logger.info("Guided - Completed \u2713") return y, embeddings - def _extract_topics(self, documents: pd.DataFrame, embeddings: np.ndarray = None, mappings=None, verbose: bool = False): - """ Extract topics from the clusters using a class-based TF-IDF + def _extract_topics( + self, + documents: pd.DataFrame, + embeddings: np.ndarray = None, + mappings=None, + verbose: bool = False, + ): + """Extract topics from the clusters using a class-based TF-IDF. Arguments: documents: Dataframe with documents and their corresponding IDs @@ -3807,19 +4293,26 @@ def _extract_topics(self, documents: pd.DataFrame, embeddings: np.ndarray = None c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic """ if verbose: - logger.info("Representation - Extracting topics from clusters using representation models.") - documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join}) + logger.info( + "Representation - Extracting topics from clusters using representation models." + ) + documents_per_topic = documents.groupby(["Topic"], as_index=False).agg( + {"Document": " ".join} + ) self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic) self.topic_representations_ = self._extract_words_per_topic(words, documents) - self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings) - self.topic_labels_ = {key: f"{key}_" + "_".join([word[0] for word in values[:4]]) - for key, values in - self.topic_representations_.items()} + self._create_topic_vectors( + documents=documents, embeddings=embeddings, mappings=mappings + ) + self.topic_labels_ = { + key: f"{key}_" + "_".join([word[0] for word in values[:4]]) + for key, values in self.topic_representations_.items() + } if verbose: logger.info("Representation - Completed \u2713") def _save_representative_docs(self, documents: pd.DataFrame): - """ Save the 3 most representative docs per topic + """Save the 3 most representative docs per topic. 
Arguments: documents: Dataframe with documents and their corresponding IDs @@ -3832,19 +4325,20 @@ def _save_representative_docs(self, documents: pd.DataFrame): documents, self.topic_representations_, nr_samples=500, - nr_repr_docs=3 + nr_repr_docs=3, ) self.representative_docs_ = repr_docs - def _extract_representative_docs(self, - c_tf_idf: csr_matrix, - documents: pd.DataFrame, - topics: Mapping[str, List[Tuple[str, float]]], - nr_samples: int = 500, - nr_repr_docs: int = 5, - diversity: float = None - ) -> Union[List[str], List[List[int]]]: - """ Approximate most representative documents per topic by sampling + def _extract_representative_docs( + self, + c_tf_idf: csr_matrix, + documents: pd.DataFrame, + topics: Mapping[str, List[Tuple[str, float]]], + nr_samples: int = 500, + nr_repr_docs: int = 5, + diversity: float = None, + ) -> Union[List[str], List[List[int]]]: + """Approximate most representative documents per topic by sampling a subset of the documents in each topic and calculating which are most representative to their topic based on the cosine similarity between c-TF-IDF representations. @@ -3869,9 +4363,9 @@ def _extract_representative_docs(self, # Sample documents per topic documents_per_topic = ( documents.drop("Image", axis=1, errors="ignore") - .groupby('Topic') - .sample(n=nr_samples, replace=True, random_state=42) - .drop_duplicates() + .groupby("Topic") + .sample(n=nr_samples, replace=True, random_state=42) + .drop_duplicates() ) # Find and extract documents that are most similar to the topic @@ -3881,37 +4375,65 @@ def _extract_representative_docs(self, repr_docs_ids = [] labels = sorted(list(topics.keys())) for index, topic in enumerate(labels): - # Slice data selection = documents_per_topic.loc[documents_per_topic.Topic == topic, :] selected_docs = selection["Document"].values selected_docs_ids = selection.index.tolist() # Calculate similarity - nr_docs = nr_repr_docs if len(selected_docs) > nr_repr_docs else len(selected_docs) + nr_docs = ( + nr_repr_docs + if len(selected_docs) > nr_repr_docs + else len(selected_docs) + ) bow = self.vectorizer_model.transform(selected_docs) ctfidf = self.ctfidf_model.transform(bow) sim_matrix = cosine_similarity(ctfidf, c_tf_idf[index]) # Use MMR to find representative but diverse documents if diversity: - docs = mmr(c_tf_idf[index], ctfidf, selected_docs, top_n=nr_docs, diversity=diversity) + docs = mmr( + c_tf_idf[index], + ctfidf, + selected_docs, + top_n=nr_docs, + diversity=diversity, + ) # Extract top n most representative documents else: - indices = np.argpartition(sim_matrix.reshape(1, -1)[0], -nr_docs)[-nr_docs:] + indices = np.argpartition(sim_matrix.reshape(1, -1)[0], -nr_docs)[ + -nr_docs: + ] docs = [selected_docs[index] for index in indices] - doc_ids = [selected_docs_ids[index] for index, doc in enumerate(selected_docs) if doc in docs] + doc_ids = [ + selected_docs_ids[index] + for index, doc in enumerate(selected_docs) + if doc in docs + ] repr_docs_ids.append(doc_ids) repr_docs.extend(docs) - repr_docs_indices.append([repr_docs_indices[-1][-1] + i + 1 if index != 0 else i for i in range(nr_docs)]) - repr_docs_mappings = {topic: repr_docs[i[0]:i[-1]+1] for topic, i in zip(topics.keys(), repr_docs_indices)} + repr_docs_indices.append( + [ + repr_docs_indices[-1][-1] + i + 1 if index != 0 else i + for i in range(nr_docs) + ] + ) + repr_docs_mappings = { + topic: repr_docs[i[0] : i[-1] + 1] + for topic, i in zip(topics.keys(), repr_docs_indices) + } return repr_docs_mappings, repr_docs, repr_docs_indices, 
repr_docs_ids - def _create_topic_vectors(self, documents: pd.DataFrame = None, embeddings: np.ndarray = None, mappings=None): - """ Creates embeddings per topics based on their topic representation + def _create_topic_vectors( + self, + documents: pd.DataFrame = None, + embeddings: np.ndarray = None, + mappings=None, + ): + """Creates embeddings per topics based on their topic representation. As a default, topic vectors (topic embeddings) are created by taking the average of all document embeddings within a topic. If topics are @@ -3942,20 +4464,30 @@ def _create_topic_vectors(self, documents: pd.DataFrame = None, embeddings: np.n topic_ids = topics_to["topics_to"] topic_sizes = topics_to["topic_sizes"] if topic_ids: - embds = np.array(self.topic_embeddings_)[np.array(topic_ids) + self._outliers] + embds = np.array(self.topic_embeddings_)[ + np.array(topic_ids) + self._outliers + ] topic_embedding = np.average(embds, axis=0, weights=topic_sizes) topic_embeddings_dict[topic_from] = topic_embedding # Re-order topic embeddings - topics_to_map = {topic_mapping[0]: topic_mapping[1] for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:]} + topics_to_map = { + topic_mapping[0]: topic_mapping[1] + for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:] + } topic_embeddings = {} for topic, embds in topic_embeddings_dict.items(): topic_embeddings[topics_to_map[topic]] = embds unique_topics = sorted(list(topic_embeddings.keys())) - self.topic_embeddings_ = np.array([topic_embeddings[topic] for topic in unique_topics]) + self.topic_embeddings_ = np.array( + [topic_embeddings[topic] for topic in unique_topics] + ) # Topic embeddings based on keyword representations - elif self.embedding_model is not None and type(self.embedding_model) is not BaseEmbedder: + elif ( + self.embedding_model is not None + and type(self.embedding_model) is not BaseEmbedder + ): topic_list = list(self.topic_representations_.keys()) topic_list.sort() @@ -3968,9 +4500,7 @@ def _create_topic_vectors(self, documents: pd.DataFrame = None, embeddings: np.n topic_words = [self.get_topic(topic) for topic in topic_list] topic_words = [word[0] for topic in topic_words for word in topic] word_embeddings = self._extract_embeddings( - topic_words, - method="word", - verbose=False + topic_words, method="word", verbose=False ) # Take the weighted average of word embeddings in a topic based on their c-TF-IDF value @@ -3981,16 +4511,22 @@ def _create_topic_vectors(self, documents: pd.DataFrame = None, embeddings: np.n word_importance = [val[1] for val in self.get_topic(topic)] if sum(word_importance) == 0: word_importance = [1 for _ in range(len(self.get_topic(topic)))] - topic_embedding = np.average(word_embeddings[i * n: n + (i * n)], weights=word_importance, axis=0) + topic_embedding = np.average( + word_embeddings[i * n : n + (i * n)], + weights=word_importance, + axis=0, + ) topic_embeddings.append(topic_embedding) self.topic_embeddings_ = np.array(topic_embeddings) - def _c_tf_idf(self, - documents_per_topic: pd.DataFrame, - fit: bool = True, - partial_fit: bool = False) -> Tuple[csr_matrix, List[str]]: - """ Calculate a class-based TF-IDF where m is the number of total documents. + def _c_tf_idf( + self, + documents_per_topic: pd.DataFrame, + fit: bool = True, + partial_fit: bool = False, + ) -> Tuple[csr_matrix, List[str]]: + """Calculate a class-based TF-IDF where m is the number of total documents. 
Arguments: documents_per_topic: The joined documents per topic such that each topic has a single @@ -4022,13 +4558,34 @@ def _c_tf_idf(self, multiplier = None if self.ctfidf_model.seed_words and self.seed_topic_list: seed_topic_list = [seed for seeds in self.seed_topic_list for seed in seeds] - multiplier = np.array([self.ctfidf_model.seed_multiplier if word in self.ctfidf_model.seed_words else 1 for word in words]) - multiplier = np.array([1.2 if word in seed_topic_list else value for value, word in zip(multiplier, words)]) + multiplier = np.array( + [ + self.ctfidf_model.seed_multiplier + if word in self.ctfidf_model.seed_words + else 1 + for word in words + ] + ) + multiplier = np.array( + [ + 1.2 if word in seed_topic_list else value + for value, word in zip(multiplier, words) + ] + ) elif self.ctfidf_model.seed_words: - multiplier = np.array([self.ctfidf_model.seed_multiplier if word in self.ctfidf_model.seed_words else 1 for word in words]) + multiplier = np.array( + [ + self.ctfidf_model.seed_multiplier + if word in self.ctfidf_model.seed_words + else 1 + for word in words + ] + ) elif self.seed_topic_list: seed_topic_list = [seed for seeds in self.seed_topic_list for seed in seeds] - multiplier = np.array([1.2 if word in seed_topic_list else 1 for word in words]) + multiplier = np.array( + [1.2 if word in seed_topic_list else 1 for word in words] + ) if fit: self.ctfidf_model = self.ctfidf_model.fit(X, multiplier=multiplier) @@ -4038,7 +4595,7 @@ def _c_tf_idf(self, return c_tf_idf, words def _update_topic_size(self, documents: pd.DataFrame): - """ Calculate the topic sizes + """Calculate the topic sizes. Arguments: documents: Updated dataframe with documents and their corresponding IDs and newly added Topics @@ -4046,13 +4603,14 @@ def _update_topic_size(self, documents: pd.DataFrame): self.topic_sizes_ = collections.Counter(documents.Topic.values.tolist()) self.topics_ = documents.Topic.astype(int).tolist() - def _extract_words_per_topic(self, - words: List[str], - documents: pd.DataFrame, - c_tf_idf: csr_matrix = None, - calculate_aspects: bool = True) -> Mapping[str, - List[Tuple[str, float]]]: - """ Based on tf_idf scores per topic, extract the top n words per topic + def _extract_words_per_topic( + self, + words: List[str], + documents: pd.DataFrame, + c_tf_idf: csr_matrix = None, + calculate_aspects: bool = True, + ) -> Mapping[str, List[Tuple[str, float]]]: + """Based on tf_idf scores per topic, extract the top n words per topic. If the top words per topic need to be extracted, then only the `words` parameter needs to be passed. 
If the top words per topic in a specific timestamp, then it @@ -4063,6 +4621,7 @@ def _extract_words_per_topic(self, words: List of all words (sorted according to tf_idf matrix position) documents: DataFrame with documents and their topic IDs c_tf_idf: A c-TF-IDF matrix from which to calculate the top words + calculate_aspects: Whether to calculate additional topic aspects Returns: topics: The top words per topic @@ -4082,40 +4641,54 @@ def _extract_words_per_topic(self, scores = np.take_along_axis(scores, sorted_indices, axis=1) # Get top 30 words per topic based on c-TF-IDF score - base_topics = {label: [(words[word_index], score) - if word_index is not None and score > 0 - else ("", 0.00001) - for word_index, score in zip(indices[index][::-1], scores[index][::-1]) - ] - for index, label in enumerate(labels)} + base_topics = { + label: [ + (words[word_index], score) + if word_index is not None and score > 0 + else ("", 0.00001) + for word_index, score in zip(indices[index][::-1], scores[index][::-1]) + ] + for index, label in enumerate(labels) + } # Fine-tune the topic representations topics = base_topics.copy() if not self.representation_model: # Default representation: c_tf_idf + top_n_words - topics = {label: values[:self.top_n_words] for label, values in topics.items()} + topics = { + label: values[: self.top_n_words] for label, values in topics.items() + } elif isinstance(self.representation_model, list): for tuner in self.representation_model: topics = tuner.extract_topics(self, documents, c_tf_idf, topics) elif isinstance(self.representation_model, BaseRepresentation): - topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics) + topics = self.representation_model.extract_topics( + self, documents, c_tf_idf, topics + ) elif isinstance(self.representation_model, dict): if self.representation_model.get("Main"): main_model = self.representation_model["Main"] if isinstance(main_model, BaseRepresentation): - topics = main_model.extract_topics(self, documents, c_tf_idf, topics) + topics = main_model.extract_topics( + self, documents, c_tf_idf, topics + ) elif isinstance(main_model, list): for tuner in main_model: topics = tuner.extract_topics(self, documents, c_tf_idf, topics) else: raise TypeError( - f"unsupported type {type(main_model).__name__} for representation_model['Main']") + f"unsupported type {type(main_model).__name__} for representation_model['Main']" + ) else: # Default representation: c_tf_idf + top_n_words - topics = {label: values[:self.top_n_words] for label, values in topics.items()} + topics = { + label: values[: self.top_n_words] + for label, values in topics.items() + } else: raise TypeError( - f"unsupported type {type(self.representation_model).__name__} for representation_model") + f"unsupported type {type(self.representation_model).__name__} for representation_model" + ) # Extract additional topic aspects if calculate_aspects and isinstance(self.representation_model, dict): @@ -4124,21 +4697,31 @@ def _extract_words_per_topic(self, aspects = base_topics.copy() if not aspect_model: # Default representation: c_tf_idf + top_n_words - aspects = {label: values[:self.top_n_words] for label, values in aspects.items()} + aspects = { + label: values[: self.top_n_words] + for label, values in aspects.items() + } if isinstance(aspect_model, list): for tuner in aspect_model: - aspects = tuner.extract_topics(self, documents, c_tf_idf, aspects) + aspects = tuner.extract_topics( + self, documents, c_tf_idf, aspects + ) elif isinstance(aspect_model, 
BaseRepresentation): - aspects = aspect_model.extract_topics(self, documents, c_tf_idf, aspects) + aspects = aspect_model.extract_topics( + self, documents, c_tf_idf, aspects + ) else: raise TypeError( - f"unsupported type {type(aspect_model).__name__} for representation_model[{repr(aspect)}]") + f"unsupported type {type(aspect_model).__name__} for representation_model[{repr(aspect)}]" + ) self.topic_aspects_[aspect] = aspects return topics - def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> pd.DataFrame: - """ Reduce topics to self.nr_topics + def _reduce_topics( + self, documents: pd.DataFrame, use_ctfidf: bool = False + ) -> pd.DataFrame: + """Reduce topics to self.nr_topics. Arguments: documents: Dataframe with documents and their corresponding IDs and Topics @@ -4148,7 +4731,6 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p Returns: documents: Updated dataframe with documents and the reduced number of Topics """ - logger.info("Topic reduction - Reducing number of topics") initial_nr_topics = len(self.get_topics()) @@ -4160,11 +4742,15 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p else: raise ValueError("nr_topics needs to be an int or 'auto'! ") - logger.info(f"Topic reduction - Reduced number of topics from {initial_nr_topics} to {len(self.get_topic_freq())}") + logger.info( + f"Topic reduction - Reduced number of topics from {initial_nr_topics} to {len(self.get_topic_freq())}" + ) return documents - def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> pd.DataFrame: - """ Reduce topics to self.nr_topics + def _reduce_to_n_topics( + self, documents: pd.DataFrame, use_ctfidf: bool = False + ) -> pd.DataFrame: + """Reduce topics to self.nr_topics. 
Arguments: documents: Dataframe with documents and their corresponding IDs and Topics @@ -4179,27 +4765,38 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) # Create topic distance matrix topic_embeddings = select_topic_representation( self.c_tf_idf_, self.topic_embeddings_, use_ctfidf, output_ndarray=True - )[0][self._outliers:] - distance_matrix = 1-cosine_similarity(topic_embeddings) + )[0][self._outliers :] + distance_matrix = 1 - cosine_similarity(topic_embeddings) np.fill_diagonal(distance_matrix, 0) # Cluster the topic embeddings using AgglomerativeClustering if version.parse(sklearn_version) >= version.parse("1.4.0"): - cluster = AgglomerativeClustering(self.nr_topics - self._outliers, metric="precomputed", linkage="average") + cluster = AgglomerativeClustering( + self.nr_topics - self._outliers, metric="precomputed", linkage="average" + ) else: - cluster = AgglomerativeClustering(self.nr_topics - self._outliers, affinity="precomputed", linkage="average") + cluster = AgglomerativeClustering( + self.nr_topics - self._outliers, + affinity="precomputed", + linkage="average", + ) cluster.fit(distance_matrix) new_topics = [cluster.labels_[topic] if topic != -1 else -1 for topic in topics] # Track mappings and sizes of topics for merging topic embeddings - mapped_topics = {from_topic: to_topic for from_topic, to_topic in zip(topics, new_topics)} + mapped_topics = { + from_topic: to_topic for from_topic, to_topic in zip(topics, new_topics) + } mappings = defaultdict(list) for key, val in sorted(mapped_topics.items()): mappings[val].append(key) - mappings = {topic_from: - {"topics_to": topics_to, - "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to]} - for topic_from, topics_to in mappings.items()} + mappings = { + topic_from: { + "topics_to": topics_to, + "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to], + } + for topic_from, topics_to in mappings.items() + } # Map topics documents.Topic = new_topics @@ -4212,8 +4809,10 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) self._update_topic_size(documents) return documents - def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> pd.DataFrame: - """ Reduce the number of topics automatically using HDBSCAN + def _auto_reduce_topics( + self, documents: pd.DataFrame, use_ctfidf: bool = False + ) -> pd.DataFrame: + """Reduce the number of topics automatically using HDBSCAN. 
Arguments: documents: Dataframe with documents and their corresponding IDs and Topics @@ -4224,34 +4823,46 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) documents: Updated dataframe with documents and the reduced number of Topics """ topics = documents.Topic.tolist().copy() - unique_topics = sorted(list(documents.Topic.unique()))[self._outliers:] + unique_topics = sorted(list(documents.Topic.unique()))[self._outliers :] max_topic = unique_topics[-1] # Find similar topics embeddings = select_topic_representation( self.c_tf_idf_, self.topic_embeddings_, use_ctfidf, output_ndarray=True )[0] - norm_data = normalize(embeddings, norm='l2') - predictions = hdbscan.HDBSCAN(min_cluster_size=2, - metric='euclidean', - cluster_selection_method='eom', - prediction_data=True).fit_predict(norm_data[self._outliers:]) + norm_data = normalize(embeddings, norm="l2") + predictions = hdbscan.HDBSCAN( + min_cluster_size=2, + metric="euclidean", + cluster_selection_method="eom", + prediction_data=True, + ).fit_predict(norm_data[self._outliers :]) # Map similar topics - mapped_topics = {unique_topics[index]: prediction + max_topic - for index, prediction in enumerate(predictions) - if prediction != -1} - documents.Topic = documents.Topic.map(mapped_topics).fillna(documents.Topic).astype(int) - mapped_topics = {from_topic: to_topic for from_topic, to_topic in zip(topics, documents.Topic.tolist())} + mapped_topics = { + unique_topics[index]: prediction + max_topic + for index, prediction in enumerate(predictions) + if prediction != -1 + } + documents.Topic = ( + documents.Topic.map(mapped_topics).fillna(documents.Topic).astype(int) + ) + mapped_topics = { + from_topic: to_topic + for from_topic, to_topic in zip(topics, documents.Topic.tolist()) + } # Track mappings and sizes of topics for merging topic embeddings mappings = defaultdict(list) for key, val in sorted(mapped_topics.items()): mappings[val].append(key) - mappings = {topic_from: - {"topics_to": topics_to, - "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to]} - for topic_from, topics_to in mappings.items()} + mappings = { + topic_from: { + "topics_to": topics_to, + "topic_sizes": [self.topic_sizes_[topic] for topic in topics_to], + } + for topic_from, topics_to in mappings.items() + } # Update documents and topics self.topic_mapper_.add_mappings(mapped_topics) @@ -4261,7 +4872,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) return documents def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame: - """ Reorder mappings by their frequency. + """Reorder mappings by their frequency. 
For example, if topic 88 was mapped to topic 5 and topic 5 turns out to be the largest topic, @@ -4287,20 +4898,24 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame: self._update_topic_size(documents) # Map topics based on frequency - df = pd.DataFrame(self.topic_sizes_.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False) + df = pd.DataFrame( + self.topic_sizes_.items(), columns=["Old_Topic", "Size"] + ).sort_values("Size", ascending=False) df = df[df.Old_Topic != -1] sorted_topics = {**{-1: -1}, **dict(zip(df.Old_Topic, range(len(df))))} self.topic_mapper_.add_mappings(sorted_topics) # Map documents - documents.Topic = documents.Topic.map(sorted_topics).fillna(documents.Topic).astype(int) + documents.Topic = ( + documents.Topic.map(sorted_topics).fillna(documents.Topic).astype(int) + ) self._update_topic_size(documents) return documents - def _map_probabilities(self, - probabilities: Union[np.ndarray, None], - original_topics: bool = False) -> Union[np.ndarray, None]: - """ Map the probabilities to the reduced topics. + def _map_probabilities( + self, probabilities: Union[np.ndarray, None], original_topics: bool = False + ) -> Union[np.ndarray, None]: + """Map the probabilities to the reduced topics. This is achieved by adding together the probabilities of all topics that are mapped to the same topic. Then, the topics that were mapped from are set to 0 as they @@ -4320,18 +4935,24 @@ def _map_probabilities(self, # Map array of probabilities (probability for assigned topic per document) if probabilities is not None: if len(probabilities.shape) == 2: - mapped_probabilities = np.zeros((probabilities.shape[0], - len(set(mappings.values())) - self._outliers)) + mapped_probabilities = np.zeros( + ( + probabilities.shape[0], + len(set(mappings.values())) - self._outliers, + ) + ) for from_topic, to_topic in mappings.items(): if to_topic != -1 and from_topic != -1: - mapped_probabilities[:, to_topic] += probabilities[:, from_topic] + mapped_probabilities[:, to_topic] += probabilities[ + :, from_topic + ] return mapped_probabilities return probabilities def _preprocess_text(self, documents: np.ndarray) -> List[str]: - """ Basic preprocessing of text + r"""Basic preprocessing of text. Steps: * Replace \n and \t with whitespace @@ -4340,13 +4961,17 @@ def _preprocess_text(self, documents: np.ndarray) -> List[str]: cleaned_documents = [doc.replace("\n", " ") for doc in documents] cleaned_documents = [doc.replace("\t", " ") for doc in cleaned_documents] if self.language == "english": - cleaned_documents = [re.sub(r'[^A-Za-z0-9 ]+', '', doc) for doc in cleaned_documents] - cleaned_documents = [doc if doc != "" else "emptydoc" for doc in cleaned_documents] + cleaned_documents = [ + re.sub(r"[^A-Za-z0-9 ]+", "", doc) for doc in cleaned_documents + ] + cleaned_documents = [ + doc if doc != "" else "emptydoc" for doc in cleaned_documents + ] return cleaned_documents @staticmethod def _top_n_idx_sparse(matrix: csr_matrix, n: int) -> np.ndarray: - """ Return indices of top n values in each row of a sparse matrix + """Return indices of top n values in each row of a sparse matrix. 
Retrieved from: https://stackoverflow.com/questions/49207275/finding-the-top-n-values-in-a-row-of-a-scipy-sparse-matrix @@ -4361,14 +4986,19 @@ def _top_n_idx_sparse(matrix: csr_matrix, n: int) -> np.ndarray: indices = [] for le, ri in zip(matrix.indptr[:-1], matrix.indptr[1:]): n_row_pick = min(n, ri - le) - values = matrix.indices[le + np.argpartition(matrix.data[le:ri], -n_row_pick)[-n_row_pick:]] - values = [values[index] if len(values) >= index + 1 else None for index in range(n)] + values = matrix.indices[ + le + np.argpartition(matrix.data[le:ri], -n_row_pick)[-n_row_pick:] + ] + values = [ + values[index] if len(values) >= index + 1 else None + for index in range(n) + ] indices.append(values) return np.array(indices) @staticmethod def _top_n_values_sparse(matrix: csr_matrix, indices: np.ndarray) -> np.ndarray: - """ Return the top n values for each row in a sparse matrix + """Return the top n values for each row in a sparse matrix. Arguments: matrix: The sparse matrix from which to get the top n indices per row @@ -4379,20 +5009,27 @@ def _top_n_values_sparse(matrix: csr_matrix, indices: np.ndarray) -> np.ndarray: """ top_values = [] for row, values in enumerate(indices): - scores = np.array([matrix[row, value] if value is not None else 0 for value in values]) + scores = np.array( + [matrix[row, value] if value is not None else 0 for value in values] + ) top_values.append(scores) return np.array(top_values) @classmethod def _get_param_names(cls): - """Get parameter names for the estimator + """Get parameter names for the estimator. Adapted from: https://github.com/scikit-learn/scikit-learn/blob/b3ea3ed6a/sklearn/base.py#L178 """ init_signature = inspect.signature(cls.__init__) - parameters = sorted([p.name for p in init_signature.parameters.values() - if p.name != 'self' and p.kind != p.VAR_KEYWORD]) + parameters = sorted( + [ + p.name + for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] + ) return parameters def __str__(self): @@ -4413,7 +5050,7 @@ def __str__(self): class TopicMapper: - """ Keep track of Topic Mappings + """Keep track of Topic Mappings. The number of topics can be reduced by merging them together. This mapping @@ -4437,8 +5074,9 @@ class TopicMapper: of topics and the first column represents the initial state of topics. """ + def __init__(self, topics: List[int]): - """ Initialization of Topic Mapper + """Initialization of Topic Mapper. Arguments: topics: A list of topics per document @@ -4448,8 +5086,8 @@ def __init__(self, topics: List[int]): self.mappings_ = np.hstack([topics.copy(), topics.copy()]).tolist() def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]: - """ Get mappings from either the original topics or - the second-most recent topics to the current topics + """Get mappings from either the original topics or + the second-most recent topics to the current topics. Arguments: original_topics: Whether we want to map from the @@ -4460,7 +5098,6 @@ def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]: mappings: The mappings from old topics to new topics Examples: - To get mappings, simply call: ```python mapper = TopicMapper(topics) @@ -4476,7 +5113,7 @@ def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]: return mappings def add_mappings(self, mappings: Mapping[int, int]): - """ Add new column(s) of topic mappings + """Add new column(s) of topic mappings. 
Arguments: mappings: The mappings to add @@ -4489,26 +5126,27 @@ def add_mappings(self, mappings: Mapping[int, int]): topics.append(-1) def add_new_topics(self, mappings: Mapping[int, int]): - """ Add new row(s) of topic mappings + """Add new row(s) of topic mappings. Arguments: mappings: The mappings to add """ length = len(self.mappings_[0]) for key, value in mappings.items(): - to_append = [key] + ([None] * (length-2)) + [value] + to_append = [key] + ([None] * (length - 2)) + [value] self.mappings_.append(to_append) def _create_model_from_files( - topics: Mapping[str, Any], - params: Mapping[str, Any], - tensors: Mapping[str, np.array], - ctfidf_tensors: Mapping[str, Any] = None, - ctfidf_config: Mapping[str, Any] = None, - images: Mapping[int, Any] = None, - warn_no_backend: bool = True): - """ Create a BERTopic model from a variety of inputs + topics: Mapping[str, Any], + params: Mapping[str, Any], + tensors: Mapping[str, np.array], + ctfidf_tensors: Mapping[str, Any] = None, + ctfidf_config: Mapping[str, Any] = None, + images: Mapping[int, Any] = None, + warn_no_backend: bool = True, +): + """Create a BERTopic model from a variety of inputs. Arguments: topics: A dictionary containing topic metadata, including: @@ -4522,6 +5160,7 @@ def _create_model_from_files( warn_no_backend: Whether to warn the user if no backend is given """ from sentence_transformers import SentenceTransformer + params["n_gram_range"] = tuple(params["n_gram_range"]) if ctfidf_config is not None: @@ -4533,17 +5172,19 @@ def _create_model_from_files( # Select HF model through SentenceTransformers try: - embedding_model = select_backend(SentenceTransformer(params['embedding_model'])) - except: + embedding_model = select_backend(SentenceTransformer(params["embedding_model"])) + except: # noqa: E722 embedding_model = BaseEmbedder() if warn_no_backend: - logger.warning("You are loading a BERTopic model without explicitly defining an embedding model." - " If you want to also load in an embedding model, make sure to use" - " `BERTopic.load(my_model, embedding_model=my_embedding_model)`.") + logger.warning( + "You are loading a BERTopic model without explicitly defining an embedding model." + " If you want to also load in an embedding model, make sure to use" + " `BERTopic.load(my_model, embedding_model=my_embedding_model)`." 
+ ) if params.get("embedding_model") is not None: - del params['embedding_model'] + del params["embedding_model"] # Prepare our empty sub-models empty_dimensionality_model = BaseDimensionalityReduction() @@ -4551,16 +5192,22 @@ def _create_model_from_files( # Fit BERTopic without actually performing any clustering topic_model = BERTopic( - embedding_model=embedding_model, - umap_model=empty_dimensionality_model, - hdbscan_model=empty_cluster_model, - **params + embedding_model=embedding_model, + umap_model=empty_dimensionality_model, + hdbscan_model=empty_cluster_model, + **params, ) topic_model.topic_embeddings_ = tensors["topic_embeddings"].numpy() - topic_model.topic_representations_ = {int(key): val for key, val in topics["topic_representations"].items()} + topic_model.topic_representations_ = { + int(key): val for key, val in topics["topic_representations"].items() + } topic_model.topics_ = topics["topics"] - topic_model.topic_sizes_ = {int(key): val for key, val in topics["topic_sizes"].items()} - topic_model.topic_labels_ = {int(key): val for key, val in topics["topic_labels"].items()} + topic_model.topic_sizes_ = { + int(key): val for key, val in topics["topic_sizes"].items() + } + topic_model.topic_labels_ = { + int(key): val for key, val in topics["topic_labels"].items() + } topic_model.custom_labels_ = topics["custom_labels"] topic_model._outliers = topics["_outliers"] @@ -4568,7 +5215,9 @@ def _create_model_from_files( topic_aspects = {} for aspect, values in topics["topic_aspects"].items(): if aspect != "Visual_Aspect": - topic_aspects[aspect] = {int(topic): value for topic, value in values.items()} + topic_aspects[aspect] = { + int(topic): value for topic, value in values.items() + } topic_model.topic_aspects_ = topic_aspects if images is not None: @@ -4579,15 +5228,32 @@ def _create_model_from_files( topic_model.topic_mapper_.mappings_ = topics["topic_mapper"] if ctfidf_tensors is not None: - topic_model.c_tf_idf_ = csr_matrix((ctfidf_tensors["data"], ctfidf_tensors["indices"], ctfidf_tensors["indptr"]), shape=ctfidf_tensors["shape"]) + topic_model.c_tf_idf_ = csr_matrix( + ( + ctfidf_tensors["data"], + ctfidf_tensors["indices"], + ctfidf_tensors["indptr"], + ), + shape=ctfidf_tensors["shape"], + ) # CountVectorizer - topic_model.vectorizer_model = CountVectorizer(**ctfidf_config["vectorizer_model"]["params"]) - topic_model.vectorizer_model.vocabulary_ = ctfidf_config["vectorizer_model"]["vocab"] + topic_model.vectorizer_model = CountVectorizer( + **ctfidf_config["vectorizer_model"]["params"] + ) + topic_model.vectorizer_model.vocabulary_ = ctfidf_config["vectorizer_model"][ + "vocab" + ] # ClassTfidfTransformer - topic_model.ctfidf_model.reduce_frequent_words = ctfidf_config["ctfidf_model"]["reduce_frequent_words"] - topic_model.ctfidf_model.bm25_weighting = ctfidf_config["ctfidf_model"]["bm25_weighting"] + topic_model.ctfidf_model.reduce_frequent_words = ctfidf_config["ctfidf_model"][ + "reduce_frequent_words" + ] + topic_model.ctfidf_model.bm25_weighting = ctfidf_config["ctfidf_model"][ + "bm25_weighting" + ] idf = ctfidf_tensors["diag"].numpy() - topic_model.ctfidf_model._idf_diag = sp.diags(idf, offsets=0, shape=(len(idf), len(idf)), format='csr', dtype=np.float64) + topic_model.ctfidf_model._idf_diag = sp.diags( + idf, offsets=0, shape=(len(idf), len(idf)), format="csr", dtype=np.float64 + ) return topic_model diff --git a/bertopic/_save_utils.py b/bertopic/_save_utils.py index 39e20d41..a01ba691 100644 --- a/bertopic/_save_utils.py +++ b/bertopic/_save_utils.py @@ 
-1,5 +1,4 @@ import os -import sys import json import numpy as np @@ -10,23 +9,25 @@ # HuggingFace Hub try: from huggingface_hub import ( - create_repo, get_hf_file_metadata, - hf_hub_download, hf_hub_url, - repo_type_and_id_from_hf_id, upload_folder) + create_repo, + get_hf_file_metadata, + hf_hub_download, + hf_hub_url, + repo_type_and_id_from_hf_id, + upload_folder, + ) + _has_hf_hub = True except ImportError: _has_hf_hub = False # Typing -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal -from typing import Union, Mapping, Any +from typing import Union # Pytorch check try: import torch + _has_torch = True except ImportError: _has_torch = False @@ -34,8 +35,9 @@ # Image check try: from PIL import Image + _has_vision = True -except: +except ImportError: _has_vision = False @@ -101,23 +103,23 @@ """ - def push_to_hf_hub( - model, - repo_id: str, - commit_message: str = 'Add BERTopic model', - token: str = None, - revision: str = None, - private: bool = False, - create_pr: bool = False, - model_card: bool = True, - serialization: str = "safetensors", - save_embedding_model: Union[str, bool] = True, - save_ctfidf: bool = False, - ): - """ Push your BERTopic model to a HuggingFace Hub + model, + repo_id: str, + commit_message: str = "Add BERTopic model", + token: str = None, + revision: str = None, + private: bool = False, + create_pr: bool = False, + model_card: bool = True, + serialization: str = "safetensors", + save_embedding_model: Union[str, bool] = True, + save_ctfidf: bool = False, +): + """Push your BERTopic model to a HuggingFace Hub. Arguments: + model: The BERTopic model to push repo_id: The name of your HuggingFace repository commit_message: A commit message token: Token to add if not already logged in @@ -133,7 +135,9 @@ def push_to_hf_hub( save_ctfidf: Whether to save c-TF-IDF information """ if not _has_hf_hub: - raise ValueError("Make sure you have the huggingface hub installed via `pip install --upgrade huggingface_hub`") + raise ValueError( + "Make sure you have the huggingface hub installed via `pip install --upgrade huggingface_hub`" + ) # Create repo if it doesn't exist yet and infer complete repo_id repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) @@ -142,26 +146,37 @@ def push_to_hf_hub( # Temporarily save model and push to HF with TemporaryDirectory() as tmpdir: - # Save model weights and config. 
- model.save(tmpdir, serialization=serialization, save_embedding_model=save_embedding_model, save_ctfidf=save_ctfidf) + model.save( + tmpdir, + serialization=serialization, + save_embedding_model=save_embedding_model, + save_ctfidf=save_ctfidf, + ) # Add README if it does not exist try: - get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) - except: + get_hf_file_metadata( + hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision) + ) + except: # noqa: E722 if model_card: readme_text = generate_readme(model, repo_id) readme_path = Path(tmpdir) / "README.md" - readme_path.write_text(readme_text, encoding='utf8') + readme_path.write_text(readme_text, encoding="utf8") # Upload model - return upload_folder(repo_id=repo_id, folder_path=tmpdir, revision=revision, - create_pr=create_pr, commit_message=commit_message) + return upload_folder( + repo_id=repo_id, + folder_path=tmpdir, + revision=revision, + create_pr=create_pr, + commit_message=commit_message, + ) def load_local_files(path): - """ Load local BERTopic files """ + """Load local BERTopic files.""" # Load json configs topics = load_cfg_from_json(path / TOPICS_NAME) params = load_cfg_from_json(path / CONFIG_NAME) @@ -186,7 +201,7 @@ def load_local_files(path): if torch_path.is_file(): ctfidf_tensors = torch.load(torch_path, map_location="cpu") ctfidf_config = load_cfg_from_json(path / CTFIDF_CFG_NAME) - except: + except: # noqa: E722 ctfidf_config, ctfidf_tensors = None, None # Load images @@ -195,7 +210,7 @@ def load_local_files(path): try: Image.open(path / "images/0.jpg") _has_images = True - except: + except: # noqa: E722 _has_images = False if _has_images: @@ -209,7 +224,7 @@ def load_local_files(path): def load_files_from_hf(path): - """ Load files from HuggingFace. 
""" + """Load files from HuggingFace.""" path = str(path) # Configs @@ -220,20 +235,24 @@ def load_files_from_hf(path): try: tensors = hf_hub_download(path, HF_SAFE_WEIGHTS_NAME, revision=None) tensors = load_safetensors(tensors) - except: + except: # noqa: E722 tensors = hf_hub_download(path, HF_WEIGHTS_NAME, revision=None) tensors = torch.load(tensors, map_location="cpu") # c-TF-IDF try: - ctfidf_config = load_cfg_from_json(hf_hub_download(path, CTFIDF_CFG_NAME, revision=None)) + ctfidf_config = load_cfg_from_json( + hf_hub_download(path, CTFIDF_CFG_NAME, revision=None) + ) try: - ctfidf_tensors = hf_hub_download(path, CTFIDF_SAFE_WEIGHTS_NAME, revision=None) + ctfidf_tensors = hf_hub_download( + path, CTFIDF_SAFE_WEIGHTS_NAME, revision=None + ) ctfidf_tensors = load_safetensors(ctfidf_tensors) - except: + except: # noqa: E722 ctfidf_tensors = hf_hub_download(path, CTFIDF_WEIGHTS_NAME, revision=None) ctfidf_tensors = torch.load(ctfidf_tensors, map_location="cpu") - except: + except: # noqa: E722 ctfidf_config, ctfidf_tensors = None, None # Load images if they exist @@ -242,27 +261,33 @@ def load_files_from_hf(path): try: hf_hub_download(path, "images/0.jpg", revision=None) _has_images = True - except: + except: # noqa: E722 _has_images = False if _has_images: topic_list = list(topics["topic_representations"].keys()) images = {} for topic in topic_list: - image = Image.open(hf_hub_download(path, f"images/{topic}.jpg", revision=None)) + image = Image.open( + hf_hub_download(path, f"images/{topic}.jpg", revision=None) + ) images[int(topic)] = image return topics, params, tensors, ctfidf_tensors, ctfidf_config, images def generate_readme(model, repo_id: str): - """ Generate README for HuggingFace model card """ + """Generate README for HuggingFace model card.""" model_card = MODEL_CARD_TEMPLATE topic_table_head = "| Topic ID | Topic Keywords | Topic Frequency | Label | \n|----------|----------------|-----------------|-------| \n" # Get Statistics model_name = repo_id.split("/")[-1] - params = {param: value for param, value in model.get_params().items() if "model" not in param} + params = { + param: value + for param, value in model.get_params().items() + if "model" not in param + } params = "\n".join([f"* {param}: {value}" for param, value in params.items()]) topics = sorted(list(set(model.topics_))) nr_topics = str(len(set(model.topics_))) @@ -273,34 +298,47 @@ def generate_readme(model, repo_id: str): nr_documents = "" # Topic information - topic_keywords = [" - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics] + topic_keywords = [ + " - ".join(list(zip(*model.get_topic(topic)))[0][:5]) for topic in topics + ] topic_freq = [model.get_topic_freq(topic) for topic in topics] - topic_labels = model.custom_labels_ if model.custom_labels_ else [model.topic_labels_[topic] for topic in topics] - topics = [f"| {topic} | {topic_keywords[index]} | {topic_freq[topic]} | {topic_labels[index]} | \n" for index, topic in enumerate(topics)] + topic_labels = ( + model.custom_labels_ + if model.custom_labels_ + else [model.topic_labels_[topic] for topic in topics] + ) + topics = [ + f"| {topic} | {topic_keywords[index]} | {topic_freq[topic]} | {topic_labels[index]} | \n" + for index, topic in enumerate(topics) + ] topics = topic_table_head + "".join(topics) - frameworks = "\n".join([f"* {param}: {value}" for param, value in get_package_versions().items()]) + frameworks = "\n".join( + [f"* {param}: {value}" for param, value in get_package_versions().items()] + ) # Fill Statistics into 
model card model_card = model_card.replace("{MODEL_NAME}", model_name) model_card = model_card.replace("{PATH}", repo_id) - model_card = model_card.replace("{NR_TOPICS}", nr_topics) - model_card = model_card.replace("{TOPICS}", topics.strip()) + model_card = model_card.replace("{NR_TOPICS}", nr_topics) + model_card = model_card.replace("{TOPICS}", topics.strip()) model_card = model_card.replace("{NR_DOCUMENTS}", nr_documents) model_card = model_card.replace("{HYPERPARAMS}", params) model_card = model_card.replace("{FRAMEWORKS}", frameworks) - + # Fill Pipeline tag has_visual_aspect = check_has_visual_aspect(model) if not has_visual_aspect: model_card = model_card.replace("{PIPELINE_TAG}", "text-classification") else: - model_card = model_card.replace("pipeline_tag: {PIPELINE_TAG}\n","") # TODO add proper tag for this instance - + model_card = model_card.replace( + "pipeline_tag: {PIPELINE_TAG}\n", "" + ) # TODO add proper tag for this instance + return model_card def save_hf(model, save_directory, serialization: str): - """ Save topic embeddings, either safely (using safetensors) or using legacy pytorch """ + """Save topic embeddings, either safely (using safetensors) or using legacy pytorch.""" tensors = torch.from_numpy(np.array(model.topic_embeddings_, dtype=np.float32)) tensors = {"topic_embeddings": tensors} @@ -311,10 +349,8 @@ def save_hf(model, save_directory, serialization: str): torch.save(tensors, save_directory / HF_WEIGHTS_NAME) -def save_ctfidf(model, - save_directory: str, - serialization: str): - """ Save c-TF-IDF sparse matrix """ +def save_ctfidf(model, save_directory: str, serialization: str): + """Save c-TF-IDF sparse matrix.""" indptr = torch.from_numpy(model.c_tf_idf_.indptr) indices = torch.from_numpy(model.c_tf_idf_.indices) data = torch.from_numpy(model.c_tf_idf_.data) @@ -325,7 +361,7 @@ def save_ctfidf(model, "indices": indices, "data": data, "shape": shape, - "diag": diag + "diag": diag, } if serialization == "safetensors": @@ -336,13 +372,13 @@ def save_ctfidf(model, def save_ctfidf_config(model, path): - """ Save parameters to recreate CountVectorizer and c-TF-IDF """ + """Save parameters to recreate CountVectorizer and c-TF-IDF.""" config = {} # Recreate ClassTfidfTransformer config["ctfidf_model"] = { "bm25_weighting": model.ctfidf_model.bm25_weighting, - "reduce_frequent_words": model.ctfidf_model.reduce_frequent_words + "reduce_frequent_words": model.ctfidf_model.reduce_frequent_words, } # Recreate CountVectorizer @@ -353,15 +389,15 @@ def save_ctfidf_config(model, path): config["vectorizer_model"] = { "params": cv_params, - "vocab": model.vectorizer_model.vocabulary_ + "vocab": model.vectorizer_model.vocabulary_, } - with path.open('w') as f: + with path.open("w") as f: json.dump(config, f, indent=2) def save_config(model, path: str, embedding_model): - """ Save BERTopic configuration """ + """Save BERTopic configuration.""" path = Path(path) params = model.get_params() config = {param: value for param, value in params.items() if "model" not in param} @@ -370,28 +406,29 @@ def save_config(model, path: str, embedding_model): if isinstance(embedding_model, str): config["embedding_model"] = embedding_model - with path.open('w') as f: + with path.open("w") as f: json.dump(config, f, indent=2) return config + def check_has_visual_aspect(model): - """Check if model has visual aspect""" + """Check if model has visual aspect.""" if _has_vision: for aspect, value in model.topic_aspects_.items(): if isinstance(value[0], Image.Image): - visual_aspects = 
model.topic_aspects_[aspect] return True + def save_images(model, path: str): - """ Save topic images """ + """Save topic images.""" if _has_vision: visual_aspects = None for aspect, value in model.topic_aspects_.items(): if isinstance(value[0], Image.Image): visual_aspects = model.topic_aspects_[aspect] break - + if visual_aspects is not None: path.mkdir(exist_ok=True, parents=True) for topic, image in visual_aspects.items(): @@ -399,7 +436,7 @@ def save_images(model, path: str): def save_topics(model, path: str): - """ Save Topic-specific information """ + """Save Topic-specific information.""" path = Path(path) if _has_vision: @@ -420,15 +457,15 @@ def save_topics(model, path: str): "topic_labels": model.topic_labels_, "custom_labels": model.custom_labels_, "_outliers": int(model._outliers), - "topic_aspects": selected_topic_aspects + "topic_aspects": selected_topic_aspects, } - with path.open('w') as f: + with path.open("w") as f: json.dump(topics, f, indent=2, cls=NumpyEncoder) def load_cfg_from_json(json_file: Union[str, os.PathLike]): - """ Load configuration from json """ + """Load configuration from json.""" with open(json_file, "r", encoding="utf-8") as reader: text = reader.read() return json.loads(text) @@ -443,17 +480,17 @@ def default(self, obj): return super(NumpyEncoder, self).default(obj) - def get_package_versions(): - """ Get versions of main dependencies of BERTopic """ + """Get versions of main dependencies of BERTopic.""" try: import platform from numpy import __version__ as np_version - + try: from importlib.metadata import version - hdbscan_version = version('hdbscan') - except: + + hdbscan_version = version("hdbscan") + except: # noqa: E722 hdbscan_version = None from umap import __version__ as umap_version @@ -462,31 +499,42 @@ def get_package_versions(): from sentence_transformers import __version__ as sbert_version from numba import __version__ as numba_version from transformers import __version__ as transformers_version - + from plotly import __version__ as plotly_version - return {"Numpy": np_version, "HDBSCAN": hdbscan_version, "UMAP": umap_version, - "Pandas": pandas_version, "Scikit-Learn": sklearn_version, - "Sentence-transformers": sbert_version, "Transformers": transformers_version, - "Numba": numba_version, "Plotly": plotly_version, "Python": platform.python_version()} + + return { + "Numpy": np_version, + "HDBSCAN": hdbscan_version, + "UMAP": umap_version, + "Pandas": pandas_version, + "Scikit-Learn": sklearn_version, + "Sentence-transformers": sbert_version, + "Transformers": transformers_version, + "Numba": numba_version, + "Plotly": plotly_version, + "Python": platform.python_version(), + } except Exception as e: return e - + def load_safetensors(path): - """ Load safetensors and check whether it is installed """ + """Load safetensors and check whether it is installed.""" try: import safetensors.torch import safetensors + return safetensors.torch.load_file(path, device="cpu") except ImportError: raise ValueError("`pip install safetensors` to load .safetensors") def save_safetensors(path, tensors): - """ Save safetensors and check whether it is installed """ + """Save safetensors and check whether it is installed.""" try: import safetensors.torch import safetensors + safetensors.torch.save_file(tensors, path) except ImportError: raise ValueError("`pip install safetensors` to save as .safetensors") diff --git a/bertopic/_utils.py b/bertopic/_utils.py index f8a88f11..0695b7cf 100644 --- a/bertopic/_utils.py +++ b/bertopic/_utils.py @@ -9,7 +9,7 @@ 
class MyLogger: def __init__(self): - self.logger = logging.getLogger('BERTopic') + self.logger = logging.getLogger("BERTopic") def configure(self, level): self.set_level(level) @@ -29,7 +29,7 @@ def set_level(self, level): def _add_handler(self): sh = logging.StreamHandler() - sh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(message)s')) + sh.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(message)s")) self.logger.addHandler(sh) # Remove duplicate handlers @@ -38,33 +38,41 @@ def _add_handler(self): def check_documents_type(documents): - """ Check whether the input documents are indeed a list of strings """ + """Check whether the input documents are indeed a list of strings.""" if isinstance(documents, pd.DataFrame): raise TypeError("Make sure to supply a list of strings, not a dataframe.") elif isinstance(documents, Iterable) and not isinstance(documents, str): if not any([isinstance(doc, str) for doc in documents]): raise TypeError("Make sure that the iterable only contains strings.") else: - raise TypeError("Make sure that the documents variable is an iterable containing strings only.") + raise TypeError( + "Make sure that the documents variable is an iterable containing strings only." + ) def check_embeddings_shape(embeddings, docs): - """ Check if the embeddings have the correct shape """ + """Check if the embeddings have the correct shape.""" if embeddings is not None: - if not any([isinstance(embeddings, np.ndarray), isinstance(embeddings, csr_matrix)]): - raise ValueError("Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix. ") + if not any( + [isinstance(embeddings, np.ndarray), isinstance(embeddings, csr_matrix)] + ): + raise ValueError( + "Make sure to input embeddings as a numpy array or scipy.sparse.csr.csr_matrix. " + ) else: if embeddings.shape[0] != len(docs): - raise ValueError("Make sure that the embeddings are a numpy array with shape: " - "(len(docs), vector_dim) where vector_dim is the dimensionality " - "of the vector embeddings. ") + raise ValueError( + "Make sure that the embeddings are a numpy array with shape: " + "(len(docs), vector_dim) where vector_dim is the dimensionality " + "of the vector embeddings. " + ) def check_is_fitted(topic_model): - """ Checks if the model was fitted by verifying the presence of self.matches + """Checks if the model was fitted by verifying the presence of self.matches. Arguments: - model: BERTopic instance for which the check is performed. + topic_model: BERTopic instance for which the check is performed. Returns: None @@ -72,16 +80,17 @@ def check_is_fitted(topic_model): Raises: ValueError: If the matches were not found. """ - msg = ("This %(name)s instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this estimator.") + msg = ( + "This %(name)s instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this estimator." + ) if topic_model.topics_ is None: - raise ValueError(msg % {'name': type(topic_model).__name__}) + raise ValueError(msg % {"name": type(topic_model).__name__}) class NotInstalled: - """ - This object is used to notify the user that additional dependencies need to be + """This object is used to notify the user that additional dependencies need to be installed in order to use the string matching model. 
""" @@ -104,7 +113,7 @@ def __call__(self, *args, **kwargs): def validate_distance_matrix(X, n_samples): - """ Validate the distance matrix and convert it to a condensed distance matrix + """Validate the distance matrix and convert it to a condensed distance matrix if necessary. A valid distance matrix is either a square matrix of shape (n_samples, n_samples) @@ -128,22 +137,27 @@ def validate_distance_matrix(X, n_samples): # check it has correct size n = s[0] if n != (n_samples * (n_samples - 1) / 2): - raise ValueError("The condensed distance matrix must have " - "shape (n*(n-1)/2,).") + raise ValueError( + "The condensed distance matrix must have " "shape (n*(n-1)/2,)." + ) elif len(s) == 2: # check it has correct size if (s[0] != n_samples) or (s[1] != n_samples): - raise ValueError("The distance matrix must be of shape " - "(n, n) where n is the number of samples.") + raise ValueError( + "The distance matrix must be of shape " + "(n, n) where n is the number of samples." + ) # force zero diagonal and convert to condensed np.fill_diagonal(X, 0) X = squareform(X) else: - raise ValueError("The distance matrix must be either a 1-D condensed " - "distance matrix of shape (n*(n-1)/2,) or a " - "2-D square distance matrix of shape (n, n)." - "where n is the number of documents." - "Got a distance matrix of shape %s" % str(s)) + raise ValueError( + "The distance matrix must be either a 1-D condensed " + "distance matrix of shape (n*(n-1)/2,) or a " + "2-D square distance matrix of shape (n, n)." + "where n is the number of documents." + "Got a distance matrix of shape %s" % str(s) + ) # Make sure its entries are non-negative if np.any(X < 0): @@ -152,7 +166,6 @@ def validate_distance_matrix(X, n_samples): return X - def get_unique_distances(dists: np.array, noise_max=1e-7) -> np.array: """Check if the consecutive elements in the distance array are the same. If so, a small noise is added to one of the elements to make sure that the array does not contain duplicates. 
@@ -169,14 +182,18 @@ def get_unique_distances(dists: np.array, noise_max=1e-7) -> np.array: for i in range(dists.shape[0] - 1): if dists[i] == dists[i + 1]: # returns the next unique distance or the current distance with the added noise - next_unique_dist = next((d for d in dists[i + 1:] if d != dists[i]), dists[i] + noise_max) + next_unique_dist = next( + (d for d in dists[i + 1 :] if d != dists[i]), dists[i] + noise_max + ) # the noise can never be large then the difference between the next unique distance and the current one curr_max_noise = min(noise_max, next_unique_dist - dists_cp[i]) - dists_cp[i + 1] = np.random.uniform(low=dists_cp[i] + curr_max_noise / 2, high=dists_cp[i] + curr_max_noise) + dists_cp[i + 1] = np.random.uniform( + low=dists_cp[i] + curr_max_noise / 2, high=dists_cp[i] + curr_max_noise + ) return dists_cp - + def select_topic_representation( ctfidf_embeddings: Optional[Union[np.ndarray, csr_matrix]] = None, embeddings: Optional[Union[np.ndarray, csr_matrix]] = None, diff --git a/bertopic/backend/__init__.py b/bertopic/backend/__init__.py index 3d331ee8..df123b8b 100644 --- a/bertopic/backend/__init__.py +++ b/bertopic/backend/__init__.py @@ -31,5 +31,5 @@ "OpenAIBackend", "CohereBackend", "MultiModalBackend", - "languages" + "languages", ] diff --git a/bertopic/backend/_base.py b/bertopic/backend/_base.py index 81f8a061..97809b15 100644 --- a/bertopic/backend/_base.py +++ b/bertopic/backend/_base.py @@ -3,7 +3,7 @@ class BaseEmbedder: - """ The Base Embedder used for creating embedding models + """The Base Embedder used for creating embedding models. Arguments: embedding_model: The main embedding model to be used for extracting @@ -13,17 +13,14 @@ class BaseEmbedder: then the `embedding_model` is purely used for creating document embeddings. """ - def __init__(self, - embedding_model=None, - word_embedding_model=None): + + def __init__(self, embedding_model=None, word_embedding_model=None): self.embedding_model = embedding_model self.word_embedding_model = word_embedding_model - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -35,11 +32,9 @@ def embed(self, """ pass - def embed_words(self, - words: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n words into an n-dimensional - matrix of embeddings + def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n words into an n-dimensional + matrix of embeddings. Arguments: words: A list of words to be embedded @@ -52,11 +47,9 @@ def embed_words(self, """ return self.embed(words, verbose) - def embed_documents(self, - document: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n words into an n-dimensional - matrix of embeddings + def embed_documents(self, document: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n words into an n-dimensional + matrix of embeddings. 
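The `get_unique_distances` helper above nudges consecutive duplicate distances apart with a tiny noise term so that downstream hierarchical clustering sees strictly distinct merge heights. A small usage sketch, assuming BERTopic is installed so the helper can be imported:

```python
import numpy as np
from bertopic._utils import get_unique_distances

dists = np.array([0.1, 0.1, 0.1, 0.4])        # consecutive duplicates
unique = get_unique_distances(dists)

print(unique)                                  # duplicates shifted by at most ~1e-7
assert len(np.unique(unique)) == len(unique)   # no duplicates remain
assert np.allclose(unique, dists, atol=1e-6)   # values are essentially unchanged
```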
Arguments: document: A list of documents to be embedded diff --git a/bertopic/backend/_cohere.py b/bertopic/backend/_cohere.py index 5a7b9a81..a7b618f2 100644 --- a/bertopic/backend/_cohere.py +++ b/bertopic/backend/_cohere.py @@ -6,7 +6,7 @@ class CohereBackend(BaseEmbedder): - """ Cohere Embedding Model + """Cohere Embedding Model. Arguments: client: A `cohere` client. @@ -21,7 +21,6 @@ class CohereBackend(BaseEmbedder): such as `input_type` Examples: - ```python import cohere from bertopic.backend import CohereBackend @@ -40,12 +39,15 @@ class CohereBackend(BaseEmbedder): ) ``` """ - def __init__(self, - client, - embedding_model: str = "large", - delay_in_seconds: float = None, - batch_size: int = None, - embed_kwargs: Mapping[str, Any] = {}): + + def __init__( + self, + client, + embedding_model: str = "large", + delay_in_seconds: float = None, + batch_size: int = None, + embed_kwargs: Mapping[str, Any] = {}, + ): super().__init__() self.client = client self.embedding_model = embedding_model @@ -58,11 +60,9 @@ def __init__(self, else: self.embed_kwargs["model"] = self.embedding_model - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -91,4 +91,4 @@ def embed(self, def _chunks(self, documents): for i in range(0, len(documents), self.batch_size): - yield documents[i:i + self.batch_size] + yield documents[i : i + self.batch_size] diff --git a/bertopic/backend/_flair.py b/bertopic/backend/_flair.py index 7cd6cfc7..2abeec49 100644 --- a/bertopic/backend/_flair.py +++ b/bertopic/backend/_flair.py @@ -8,7 +8,7 @@ class FlairBackend(BaseEmbedder): - """ Flair Embedding Model + """Flair Embedding Model. The Flair embedding model used for generating document and word embeddings. @@ -17,7 +17,6 @@ class FlairBackend(BaseEmbedder): embedding_model: A Flair embedding model Examples: - ```python from bertopic.backend import FlairBackend from flair.embeddings import WordEmbeddings, DocumentPoolEmbeddings @@ -30,6 +29,7 @@ class FlairBackend(BaseEmbedder): flair_embedder = FlairBackend(document_glove_embeddings) ``` """ + def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]): super().__init__() @@ -45,16 +45,16 @@ def __init__(self, embedding_model: Union[TokenEmbeddings, DocumentEmbeddings]): self.embedding_model = embedding_model else: - raise ValueError("Please select a correct Flair model by either using preparing a token or document " - "embedding model: \n" - "`from flair.embeddings import TransformerDocumentEmbeddings` \n" - "`roberta = TransformerDocumentEmbeddings('roberta-base')`") + raise ValueError( + "Please select a correct Flair model by either using preparing a token or document " + "embedding model: \n" + "`from flair.embeddings import TransformerDocumentEmbeddings` \n" + "`roberta = TransformerDocumentEmbeddings('roberta-base')`" + ) - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. 
Arguments: documents: A list of documents or words to be embedded @@ -67,7 +67,9 @@ def embed(self, embeddings = [] for document in tqdm(documents, disable=not verbose): try: - sentence = Sentence(document) if document else Sentence("an empty document") + sentence = ( + Sentence(document) if document else Sentence("an empty document") + ) self.embedding_model.embed(sentence) except RuntimeError: sentence = Sentence("an empty document") diff --git a/bertopic/backend/_gensim.py b/bertopic/backend/_gensim.py index 7ceb603d..3727e04d 100644 --- a/bertopic/backend/_gensim.py +++ b/bertopic/backend/_gensim.py @@ -6,7 +6,7 @@ class GensimBackend(BaseEmbedder): - """ Gensim Embedding Model + """Gensim Embedding Model. The Gensim embedding model is typically used for word embeddings with GloVe, Word2Vec or FastText. @@ -15,7 +15,6 @@ class GensimBackend(BaseEmbedder): embedding_model: A Gensim embedding model Examples: - ```python from bertopic.backend import GensimBackend import gensim.downloader as api @@ -24,21 +23,22 @@ class GensimBackend(BaseEmbedder): ft_embedder = GensimBackend(ft) ``` """ + def __init__(self, embedding_model: Word2VecKeyedVectors): super().__init__() if isinstance(embedding_model, Word2VecKeyedVectors): self.embedding_model = embedding_model else: - raise ValueError("Please select a correct Gensim model: \n" - "`import gensim.downloader as api` \n" - "`ft = api.load('fasttext-wiki-news-subwords-300')`") + raise ValueError( + "Please select a correct Gensim model: \n" + "`import gensim.downloader as api` \n" + "`ft = api.load('fasttext-wiki-news-subwords-300')`" + ) - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -48,14 +48,19 @@ def embed(self, Document/words embeddings with shape (n, m) with `n` documents/words that each have an embeddings size of `m` """ - vector_shape = self.embedding_model.get_vector(list(self.embedding_model.index_to_key)[0]).shape[0] + vector_shape = self.embedding_model.get_vector( + list(self.embedding_model.index_to_key)[0] + ).shape[0] empty_vector = np.zeros(vector_shape) # Extract word embeddings and pool to document-level embeddings = [] for doc in tqdm(documents, disable=not verbose, position=0, leave=True): - embedding = [self.embedding_model.get_vector(word) for word in doc.split() - if word in self.embedding_model.key_to_index] + embedding = [ + self.embedding_model.get_vector(word) + for word in doc.split() + if word in self.embedding_model.key_to_index + ] if len(embedding) > 0: embeddings.append(np.mean(embedding, axis=0)) diff --git a/bertopic/backend/_hftransformers.py b/bertopic/backend/_hftransformers.py index 9d77c7dd..8de9cc2a 100644 --- a/bertopic/backend/_hftransformers.py +++ b/bertopic/backend/_hftransformers.py @@ -10,17 +10,16 @@ class HFTransformerBackend(BaseEmbedder): - """ Hugging Face transformers model + """Hugging Face transformers model. - This uses the `transformers.pipelines.pipeline` to define and create - a feature generation pipeline from which embeddings can be extracted. + This uses the `transformers.pipelines.pipeline` to define and create + a feature generation pipeline from which embeddings can be extracted. 
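Earlier in this hunk, `GensimBackend.embed` pools word vectors into a document vector: it looks up every in-vocabulary token, averages the vectors, and falls back to a zero vector when nothing is in vocabulary. A self-contained sketch of that pooling logic, with a plain dict standing in for the gensim `KeyedVectors` object:

```python
import numpy as np

# Toy lookup table standing in for gensim's KeyedVectors
word_vectors = {
    "topic": np.array([0.1, 0.3]),
    "model": np.array([0.2, 0.1]),
}
empty_vector = np.zeros(2)   # fallback when no token is in vocabulary


def embed_document(doc: str) -> np.ndarray:
    vectors = [word_vectors[w] for w in doc.split() if w in word_vectors]
    return np.mean(vectors, axis=0) if vectors else empty_vector


print(embed_document("topic model"))         # mean of the two known vectors
print(embed_document("completely unknown"))  # zero-vector fallback
```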
Arguments: embedding_model: A Hugging Face feature extraction pipeline Examples: - - To use a Hugging Face transformers model, load in a pipeline and point + To use a Hugging Face transformers model, load in a pipeline and point to any model found on their model hub (https://huggingface.co/models): ```python @@ -31,20 +30,21 @@ class HFTransformerBackend(BaseEmbedder): embedding_model = HFTransformerBackend(hf_model) ``` """ + def __init__(self, embedding_model: Pipeline): super().__init__() if isinstance(embedding_model, Pipeline): self.embedding_model = embedding_model else: - raise ValueError("Please select a correct transformers pipeline. For example: " - "pipeline('feature-extraction', model='distilbert-base-cased', device=0)") + raise ValueError( + "Please select a correct transformers pipeline. For example: " + "pipeline('feature-extraction', model='distilbert-base-cased', device=0)" + ) - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -57,16 +57,19 @@ def embed(self, dataset = MyDataset(documents) embeddings = [] - for document, features in tqdm(zip(documents, self.embedding_model(dataset, truncation=True, padding=True)), - total=len(dataset), disable=not verbose): + for document, features in tqdm( + zip( + documents, self.embedding_model(dataset, truncation=True, padding=True) + ), + total=len(dataset), + disable=not verbose, + ): embeddings.append(self._embed(document, features)) return np.array(embeddings) - def _embed(self, - document: str, - features: np.ndarray) -> np.ndarray: - """ Mean pooling + def _embed(self, document: str, features: np.ndarray) -> np.ndarray: + """Mean pooling. 
Arguments: document: The document for which to extract the attention mask @@ -76,16 +79,25 @@ def _embed(self, https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2#usage-huggingface-transformers """ token_embeddings = np.array(features) - attention_mask = self.embedding_model.tokenizer(document, truncation=True, padding=True, return_tensors="np")["attention_mask"] - input_mask_expanded = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape) + attention_mask = self.embedding_model.tokenizer( + document, truncation=True, padding=True, return_tensors="np" + )["attention_mask"] + input_mask_expanded = np.broadcast_to( + np.expand_dims(attention_mask, -1), token_embeddings.shape + ) sum_embeddings = np.sum(token_embeddings * input_mask_expanded, 1) - sum_mask = np.clip(input_mask_expanded.sum(1), a_min=1e-9, a_max=input_mask_expanded.sum(1).max()) + sum_mask = np.clip( + input_mask_expanded.sum(1), + a_min=1e-9, + a_max=input_mask_expanded.sum(1).max(), + ) embedding = normalize(sum_embeddings / sum_mask)[0] return embedding class MyDataset(Dataset): - """ Dataset to pass to `transformers.pipelines.pipeline` """ + """Dataset to pass to `transformers.pipelines.pipeline`.""" + def __init__(self, docs): self.docs = docs diff --git a/bertopic/backend/_multimodal.py b/bertopic/backend/_multimodal.py index 053919f0..846efc41 100644 --- a/bertopic/backend/_multimodal.py +++ b/bertopic/backend/_multimodal.py @@ -1,4 +1,3 @@ - import numpy as np from PIL import Image from tqdm import tqdm @@ -9,22 +8,21 @@ class MultiModalBackend(BaseEmbedder): - """ Multimodal backend using Sentence-transformers + """Multimodal backend using Sentence-transformers. - The sentence-transformers embedding model used for - generating word, document, and image embeddings. + The sentence-transformers embedding model used for + generating word, document, and image embeddings. Arguments: - embedding_model: A sentence-transformers embedding model that + embedding_model: A sentence-transformers embedding model that can either embed both images and text or only text. - If it only embeds text, then `image_model` needs + If it only embeds text, then `image_model` needs to be used to embed the images. image_model: A sentence-transformers embedding model that is used to embed only images. 
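The `_embed` method above performs attention-masked mean pooling: token embeddings are summed only where the attention mask is 1, divided by the number of real tokens, and L2-normalized. A simplified numpy sketch of the same idea (the clipping upper bound is simplified here to `a_max=None`):

```python
import numpy as np
from sklearn.preprocessing import normalize

# Toy token embeddings for one document: (1, seq_len, hidden_dim) plus its mask
token_embeddings = np.random.rand(1, 4, 8)
attention_mask = np.array([[1, 1, 1, 0]])      # last position is padding

# Broadcast the mask over the hidden dimension and average only the real tokens
mask = np.broadcast_to(np.expand_dims(attention_mask, -1), token_embeddings.shape)
summed = np.sum(token_embeddings * mask, axis=1)
counts = np.clip(mask.sum(axis=1), a_min=1e-9, a_max=None)
doc_embedding = normalize(summed / counts)[0]  # L2-normalized mean-pooled vector

print(doc_embedding.shape)                     # (8,)
```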
batch_size: The sizes of image batches to pass Examples: - To create a model, you can load in a string pointing to a sentence-transformers model: @@ -43,22 +41,27 @@ class MultiModalBackend(BaseEmbedder): sentence_model = MultiModalBackend(embedding_model) ``` """ - def __init__(self, - embedding_model: Union[str, SentenceTransformer], - image_model: Union[str, SentenceTransformer] = None, - batch_size: int = 32): + + def __init__( + self, + embedding_model: Union[str, SentenceTransformer], + image_model: Union[str, SentenceTransformer] = None, + batch_size: int = 32, + ): super().__init__() self.batch_size = batch_size - + # Text or Text+Image model if isinstance(embedding_model, SentenceTransformer): self.embedding_model = embedding_model elif isinstance(embedding_model, str): self.embedding_model = SentenceTransformer(embedding_model) else: - raise ValueError("Please select a correct SentenceTransformers model: \n" - "`from sentence_transformers import SentenceTransformer` \n" - "`model = SentenceTransformer('clip-ViT-B-32')`") + raise ValueError( + "Please select a correct SentenceTransformers model: \n" + "`from sentence_transformers import SentenceTransformer` \n" + "`model = SentenceTransformer('clip-ViT-B-32')`" + ) # Image Model self.image_model = None @@ -68,26 +71,31 @@ def __init__(self, elif isinstance(image_model, str): self.image_model = SentenceTransformer(image_model) else: - raise ValueError("Please select a correct SentenceTransformers model: \n" - "`from sentence_transformers import SentenceTransformer` \n" - "`model = SentenceTransformer('clip-ViT-B-32')`") - + raise ValueError( + "Please select a correct SentenceTransformers model: \n" + "`from sentence_transformers import SentenceTransformer` \n" + "`model = SentenceTransformer('clip-ViT-B-32')`" + ) + try: self.tokenizer = self.embedding_model._first_module().processor.tokenizer except AttributeError: self.tokenizer = self.embedding_model.tokenizer - except: + except: # noqa: E722 self.tokenizer = None - def embed(self, - documents: List[str], - images: List[str] = None, - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed( + self, documents: List[str], images: List[str] = None, verbose: bool = False + ) -> np.ndarray: + """Embed a list of n documents/words or images into an n-dimensional + matrix of embeddings. + + Either documents, images, or both can be provided. If both are provided, + then the embeddings are averaged. Arguments: documents: A list of documents or words to be embedded + images: A list of image paths to be embedded verbose: Controls the verbosity of the process Returns: @@ -115,12 +123,12 @@ def embed(self, return doc_embeddings elif image_embeddings is not None: return image_embeddings - - def embed_documents(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + + def embed_documents( + self, documents: List[str], verbose: bool = False + ) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. 
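Per the updated docstring above, `MultiModalBackend.embed` averages the document and image embeddings when both modalities are supplied. The combination itself is a plain element-wise mean, sketched here with random stand-in embeddings:

```python
import numpy as np

doc_embeddings = np.random.rand(3, 512)    # stand-in text embeddings
image_embeddings = np.random.rand(3, 512)  # stand-in image embeddings

combined = np.mean([doc_embeddings, image_embeddings], axis=0)
print(combined.shape)  # (3, 512): one averaged vector per document/image pair
```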
Arguments: documents: A list of documents or words to be embedded @@ -131,12 +139,14 @@ def embed_documents(self, that each have an embeddings size of `m` """ truncated_docs = [self._truncate_document(doc) for doc in documents] - embeddings = self.embedding_model.encode(truncated_docs, show_progress_bar=verbose) + embeddings = self.embedding_model.encode( + truncated_docs, show_progress_bar=verbose + ) return embeddings - + def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray: - """ Embed a list of n words into an n-dimensional - matrix of embeddings + """Embed a list of n words into an n-dimensional + matrix of embeddings. Arguments: words: A list of words to be embedded @@ -146,9 +156,9 @@ def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray: Document/words embeddings with shape (n, m) with `n` documents/words that each have an embeddings size of `m` """ - embeddings = self.embedding_model.encode(words, show_progress_bar=verbose) + embeddings = self.embedding_model.encode(words, show_progress_bar=verbose) return embeddings - + def embed_images(self, images, verbose): if self.batch_size: nr_iterations = int(np.ceil(len(images) / self.batch_size)) @@ -159,11 +169,16 @@ def embed_images(self, images, verbose): start_index = i * self.batch_size end_index = (i * self.batch_size) + self.batch_size - images_to_embed = [Image.open(image) if isinstance(image, str) else image for image in images[start_index:end_index]] + images_to_embed = [ + Image.open(image) if isinstance(image, str) else image + for image in images[start_index:end_index] + ] if self.image_model is not None: img_emb = self.image_model.encode(images_to_embed) else: - img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False) + img_emb = self.embedding_model.encode( + images_to_embed, show_progress_bar=False + ) embeddings.extend(img_emb.tolist()) # Close images @@ -176,9 +191,11 @@ def embed_images(self, images, verbose): if self.image_model is not None: embeddings = self.image_model.encode(images_to_embed) else: - embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False) + embeddings = self.embedding_model.encode( + images_to_embed, show_progress_bar=False + ) return embeddings - + def _truncate_document(self, document): if self.tokenizer: tokens = self.tokenizer.encode(document) diff --git a/bertopic/backend/_openai.py b/bertopic/backend/_openai.py index 2a8e03a9..19d18268 100644 --- a/bertopic/backend/_openai.py +++ b/bertopic/backend/_openai.py @@ -7,7 +7,7 @@ class OpenAIBackend(BaseEmbedder): - """ OpenAI Embedding Model + """OpenAI Embedding Model. Arguments: client: A `openai.OpenAI` client. @@ -22,7 +22,6 @@ class OpenAIBackend(BaseEmbedder): deployment_ids. 
Examples: - ```python import openai from bertopic.backend import OpenAIBackend @@ -31,12 +30,15 @@ class OpenAIBackend(BaseEmbedder): openai_embedder = OpenAIBackend(client, "text-embedding-ada-002") ``` """ - def __init__(self, - client: openai.OpenAI, - embedding_model: str = "text-embedding-ada-002", - delay_in_seconds: float = None, - batch_size: int = None, - generator_kwargs: Mapping[str, Any] = {}): + + def __init__( + self, + client: openai.OpenAI, + embedding_model: str = "text-embedding-ada-002", + delay_in_seconds: float = None, + batch_size: int = None, + generator_kwargs: Mapping[str, Any] = {}, + ): super().__init__() self.client = client self.embedding_model = embedding_model @@ -49,11 +51,9 @@ def __init__(self, elif not self.generator_kwargs.get("engine"): self.generator_kwargs["model"] = self.embedding_model - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -70,7 +70,9 @@ def embed(self, if self.batch_size is not None: embeddings = [] for batch in tqdm(self._chunks(prepared_documents), disable=not verbose): - response = self.client.embeddings.create(input=batch, **self.generator_kwargs) + response = self.client.embeddings.create( + input=batch, **self.generator_kwargs + ) embeddings.extend([r.embedding for r in response.data]) # Delay subsequent calls @@ -79,10 +81,12 @@ def embed(self, # Extract embeddings all at once else: - response = self.client.embeddings.create(input=prepared_documents, **self.generator_kwargs) + response = self.client.embeddings.create( + input=prepared_documents, **self.generator_kwargs + ) embeddings = [r.embedding for r in response.data] return np.array(embeddings) def _chunks(self, documents): for i in range(0, len(documents), self.batch_size): - yield documents[i:i + self.batch_size] + yield documents[i : i + self.batch_size] diff --git a/bertopic/backend/_sentencetransformers.py b/bertopic/backend/_sentencetransformers.py index dbc0dc7b..a54ad0ec 100644 --- a/bertopic/backend/_sentencetransformers.py +++ b/bertopic/backend/_sentencetransformers.py @@ -6,7 +6,7 @@ class SentenceTransformerBackend(BaseEmbedder): - """ Sentence-transformers embedding model + """Sentence-transformers embedding model. The sentence-transformers embedding model used for generating document and word embeddings. 
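`OpenAIBackend.embed` above, like `CohereBackend`, requests embeddings in batches and can pause between calls via `delay_in_seconds` to respect provider rate limits. A generic sketch of that batch-and-delay pattern; `embed_in_batches`, `embed_batch`, and `fake_embed` are hypothetical names standing in for the actual API call:

```python
import time
from typing import Callable, List


def embed_in_batches(
    texts: List[str],
    embed_batch: Callable[[List[str]], List[List[float]]],  # hypothetical API call
    batch_size: int = 100,
    delay_in_seconds: float = None,
) -> List[List[float]]:
    embeddings = []
    for start in range(0, len(texts), batch_size):
        embeddings.extend(embed_batch(texts[start : start + batch_size]))
        if delay_in_seconds is not None:
            time.sleep(delay_in_seconds)   # stay under the provider's rate limit
    return embeddings


def fake_embed(batch):
    return [[float(len(text))] for text in batch]  # dummy stand-in for testing


print(embed_in_batches(["a", "bb", "ccc"], fake_embed, batch_size=2, delay_in_seconds=0))
```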
@@ -15,7 +15,6 @@ class SentenceTransformerBackend(BaseEmbedder): embedding_model: A sentence-transformers embedding model Examples: - To create a model, you can load in a string pointing to a sentence-transformers model: @@ -34,6 +33,7 @@ class SentenceTransformerBackend(BaseEmbedder): sentence_model = SentenceTransformerBackend(embedding_model) ``` """ + def __init__(self, embedding_model: Union[str, SentenceTransformer]): super().__init__() @@ -44,15 +44,15 @@ def __init__(self, embedding_model: Union[str, SentenceTransformer]): self.embedding_model = SentenceTransformer(embedding_model) self._hf_model = embedding_model else: - raise ValueError("Please select a correct SentenceTransformers model: \n" - "`from sentence_transformers import SentenceTransformer` \n" - "`model = SentenceTransformer('all-MiniLM-L6-v2')`") + raise ValueError( + "Please select a correct SentenceTransformers model: \n" + "`from sentence_transformers import SentenceTransformer` \n" + "`model = SentenceTransformer('all-MiniLM-L6-v2')`" + ) - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -63,4 +63,4 @@ def embed(self, that each have an embeddings size of `m` """ embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose) - return embeddings \ No newline at end of file + return embeddings diff --git a/bertopic/backend/_sklearn.py b/bertopic/backend/_sklearn.py index b341a038..d8150fe6 100644 --- a/bertopic/backend/_sklearn.py +++ b/bertopic/backend/_sklearn.py @@ -3,7 +3,7 @@ class SklearnEmbedder(BaseEmbedder): - """ Scikit-Learn based embedding model + """Scikit-Learn based embedding model. This component allows the usage of scikit-learn pipelines for generating document and word embeddings. @@ -12,15 +12,14 @@ class SklearnEmbedder(BaseEmbedder): pipe: A scikit-learn pipeline that can `.transform()` text. Examples: - - Scikit-Learn is very flexible and it allows for many representations. - A relatively simple pipeline is shown below. + Scikit-Learn is very flexible and it allows for many representations. + A relatively simple pipeline is shown below. ```python from sklearn.pipeline import make_pipeline from sklearn.decomposition import TruncatedSVD from sklearn.feature_extraction.text import TfidfVectorizer - + from bertopic.backend import SklearnEmbedder pipe = make_pipeline( @@ -33,23 +32,24 @@ class SklearnEmbedder(BaseEmbedder): ``` This pipeline first constructs a sparse representation based on TF/idf and then - makes it dense by applying SVD. Alternatively, you might also construct something + makes it dense by applying SVD. Alternatively, you might also construct something more elaborate. As long as you construct a scikit-learn compatible pipeline, you - should be able to pass it to Bertopic. + should be able to pass it to Bertopic. - !!! Warning + !!! Warning One caveat to be aware of is that scikit-learns base `Pipeline` class does not support the `.partial_fit()`-API. If you have a pipeline that theoretically should be able to support online learning then you might want to explore - the [scikit-partial](https://github.com/koaning/scikit-partial) project. + the [scikit-partial](https://github.com/koaning/scikit-partial) project. 
""" + def __init__(self, pipe): super().__init__() self.pipe = pipe def embed(self, documents, verbose=False): - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -59,10 +59,10 @@ def embed(self, documents, verbose=False): Document/words embeddings with shape (n, m) with `n` documents/words that each have an embeddings size of `m` """ - try: + try: check_is_fitted(self.pipe) embeddings = self.pipe.transform(documents) except NotFittedError: embeddings = self.pipe.fit_transform(documents) - return embeddings + return embeddings diff --git a/bertopic/backend/_spacy.py b/bertopic/backend/_spacy.py index 96afb710..f55fd080 100644 --- a/bertopic/backend/_spacy.py +++ b/bertopic/backend/_spacy.py @@ -5,7 +5,7 @@ class SpacyBackend(BaseEmbedder): - """ Spacy embedding model + """Spacy embedding model. The Spacy embedding model used for generating document and word embeddings. @@ -14,7 +14,6 @@ class SpacyBackend(BaseEmbedder): embedding_model: A spacy embedding model Examples: - To create a Spacy backend, you need to create an nlp object and pass it through this backend: @@ -50,20 +49,21 @@ class SpacyBackend(BaseEmbedder): spacy_model = SpacyBackend(nlp) ``` """ + def __init__(self, embedding_model): super().__init__() if "spacy" in str(type(embedding_model)): self.embedding_model = embedding_model else: - raise ValueError("Please select a correct Spacy model by either using a string such as 'en_core_web_md' " - "or create a nlp model using: `nlp = spacy.load('en_core_web_md')") + raise ValueError( + "Please select a correct Spacy model by either using a string such as 'en_core_web_md' " + "or create a nlp model using: `nlp = spacy.load('en_core_web_md')" + ) - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded @@ -86,7 +86,7 @@ def embed(self, else: embedding = embedding._.trf_data.tensors[-1][0] - if not isinstance(embedding, np.ndarray) and hasattr(embedding, 'get'): + if not isinstance(embedding, np.ndarray) and hasattr(embedding, "get"): # Convert cupy array to numpy array embedding = embedding.get() embeddings.append(embedding) diff --git a/bertopic/backend/_use.py b/bertopic/backend/_use.py index 142e06bf..c33c76fc 100644 --- a/bertopic/backend/_use.py +++ b/bertopic/backend/_use.py @@ -6,7 +6,7 @@ class USEBackend(BaseEmbedder): - """ Universal Sentence Encoder + """Universal Sentence Encoder. USE encodes text into high-dimensional vectors that are used for semantic similarity in BERTopic. 
@@ -15,7 +15,6 @@ class USEBackend(BaseEmbedder): embedding_model: An USE embedding model Examples: - ```python import tensorflow_hub from bertopic.backend import USEBackend @@ -24,6 +23,7 @@ class USEBackend(BaseEmbedder): use_embedder = USEBackend(embedding_model) ``` """ + def __init__(self, embedding_model): super().__init__() @@ -31,15 +31,15 @@ def __init__(self, embedding_model): embedding_model(["test sentence"]) self.embedding_model = embedding_model except TypeError: - raise ValueError("Please select a correct USE model: \n" - "`import tensorflow_hub` \n" - "`embedding_model = tensorflow_hub.load(path_to_model)`") - - def embed(self, - documents: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n documents/words into an n-dimensional - matrix of embeddings + raise ValueError( + "Please select a correct USE model: \n" + "`import tensorflow_hub` \n" + "`embedding_model = tensorflow_hub.load(path_to_model)`" + ) + + def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n documents/words into an n-dimensional + matrix of embeddings. Arguments: documents: A list of documents or words to be embedded diff --git a/bertopic/backend/_utils.py b/bertopic/backend/_utils.py index 8386e8b0..7c78d32e 100644 --- a/bertopic/backend/_utils.py +++ b/bertopic/backend/_utils.py @@ -68,10 +68,10 @@ ] -def select_backend(embedding_model, - language: str = None, - verbose: bool = False) -> BaseEmbedder: - """ Select an embedding model based on language or a specific provided model. +def select_backend( + embedding_model, language: str = None, verbose: bool = False +) -> BaseEmbedder: + """Select an embedding model based on language or a specific provided model. When selecting a language, we choose all-MiniLM-L6-v2 for English and paraphrase-multilingual-MiniLM-L12-v2 for all other languages as it support 100+ languages. If sentence-transformers is not installed, in the case of a lightweight installation, @@ -80,7 +80,6 @@ def select_backend(embedding_model, Returns: model: The selected model backend. 
""" - logger.set_level("INFO" if verbose else "WARNING") # BERTopic language backend @@ -94,46 +93,61 @@ def select_backend(embedding_model, # Flair word embeddings if "flair" in str(type(embedding_model)): from bertopic.backend._flair import FlairBackend + return FlairBackend(embedding_model) # Spacy embeddings if "spacy" in str(type(embedding_model)): from bertopic.backend._spacy import SpacyBackend + return SpacyBackend(embedding_model) # Gensim embeddings if "gensim" in str(type(embedding_model)): from bertopic.backend._gensim import GensimBackend + return GensimBackend(embedding_model) # USE embeddings if "tensorflow" and "saved_model" in str(type(embedding_model)): from bertopic.backend._use import USEBackend + return USEBackend(embedding_model) # Sentence Transformer embeddings - if "sentence_transformers" in str(type(embedding_model)) or isinstance(embedding_model, str): + if "sentence_transformers" in str(type(embedding_model)) or isinstance( + embedding_model, str + ): from ._sentencetransformers import SentenceTransformerBackend + return SentenceTransformerBackend(embedding_model) # Hugging Face embeddings if "transformers" and "pipeline" in str(type(embedding_model)): from ._hftransformers import HFTransformerBackend + return HFTransformerBackend(embedding_model) # Select embedding model based on language if language: try: from ._sentencetransformers import SentenceTransformerBackend + if language.lower() in ["English", "english", "en"]: - return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2") + return SentenceTransformerBackend( + "sentence-transformers/all-MiniLM-L6-v2" + ) elif language.lower() in languages or language == "multilingual": - return SentenceTransformerBackend("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2") + return SentenceTransformerBackend( + "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2" + ) else: - raise ValueError(f"{language} is currently not supported. However, you can " - f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n" - "Else, please select a language from the following list:\n" - f"{languages}") + raise ValueError( + f"{language} is currently not supported. However, you can " + f"create any embeddings yourself and pass it through fit_transform(docs, embeddings)\n" + "Else, please select a language from the following list:\n" + f"{languages}" + ) # A ModuleNotFoundError might be a lightweight installation except ModuleNotFoundError as e: @@ -142,9 +156,12 @@ def select_backend(embedding_model, raise e # Whole sentence_transformers module is missing, probably a lightweight install if verbose: - logger.info("Automatically selecting lightweight scikit-learn embedding backend as sentence-transformers appears to not be installed.") + logger.info( + "Automatically selecting lightweight scikit-learn embedding backend as sentence-transformers appears to not be installed." 
+ ) pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100)) return SklearnEmbedder(pipe) from ._sentencetransformers import SentenceTransformerBackend + return SentenceTransformerBackend("sentence-transformers/all-MiniLM-L6-v2") diff --git a/bertopic/backend/_word_doc.py b/bertopic/backend/_word_doc.py index c71cad13..4cb7a201 100644 --- a/bertopic/backend/_word_doc.py +++ b/bertopic/backend/_word_doc.py @@ -5,21 +5,17 @@ class WordDocEmbedder(BaseEmbedder): - """ Combine a document- and word-level embedder - """ - def __init__(self, - embedding_model, - word_embedding_model): + """Combine a document- and word-level embedder.""" + + def __init__(self, embedding_model, word_embedding_model): super().__init__() self.embedding_model = select_backend(embedding_model) self.word_embedding_model = select_backend(word_embedding_model) - def embed_words(self, - words: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n words into an n-dimensional - matrix of embeddings + def embed_words(self, words: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n words into an n-dimensional + matrix of embeddings. Arguments: words: A list of words to be embedded @@ -32,11 +28,9 @@ def embed_words(self, """ return self.word_embedding_model.embed(words, verbose) - def embed_documents(self, - document: List[str], - verbose: bool = False) -> np.ndarray: - """ Embed a list of n words into an n-dimensional - matrix of embeddings + def embed_documents(self, document: List[str], verbose: bool = False) -> np.ndarray: + """Embed a list of n words into an n-dimensional + matrix of embeddings. Arguments: document: A list of documents to be embedded diff --git a/bertopic/cluster/_base.py b/bertopic/cluster/_base.py index dc8412f0..a096d99e 100644 --- a/bertopic/cluster/_base.py +++ b/bertopic/cluster/_base.py @@ -2,15 +2,14 @@ class BaseCluster: - """ The Base Cluster class + """The Base Cluster class. Using this class directly in BERTopic will make it skip - over the cluster step. As a result, topics need to be passed - to BERTopic in the form of its `y` parameter in order to create - topic representations. - - Examples: + over the cluster step. As a result, topics need to be passed + to BERTopic in the form of its `y` parameter in order to create + topic representations. + Examples: This will skip over the cluster step in BERTopic: ```python @@ -22,14 +21,15 @@ class BaseCluster: topic_model = BERTopic(hdbscan_model=empty_cluster_model) ``` - Then, this class can be used to perform manual topic modeling. - That is, topic modeling on a topics that were already generated before + Then, this class can be used to perform manual topic modeling. + That is, topic modeling on a topics that were already generated before without the need to learn them: ```python topic_model.fit(docs, y=y) ``` """ + def fit(self, X, y=None): if y is not None: self.labels_ = y diff --git a/bertopic/cluster/_utils.py b/bertopic/cluster/_utils.py index 4e1805cc..82f243c6 100644 --- a/bertopic/cluster/_utils.py +++ b/bertopic/cluster/_utils.py @@ -3,7 +3,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): - """ Function used to select the HDBSCAN-like model for generating + """Function used to select the HDBSCAN-like model for generating predictions and probabilities. 
Arguments: @@ -15,7 +15,6 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): embeddings: Input embeddings for "approximate_predict" and "membership_vector" """ - # Approximate predict if func == "approximate_predict": if isinstance(model, hdbscan.HDBSCAN): @@ -25,7 +24,10 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: from cuml.cluster import hdbscan as cuml_hdbscan - predictions, probabilities = cuml_hdbscan.approximate_predict(model, embeddings) + + predictions, probabilities = cuml_hdbscan.approximate_predict( + model, embeddings + ) return predictions, probabilities predictions = model.predict(embeddings) @@ -39,10 +41,11 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: from cuml.cluster import hdbscan as cuml_hdbscan + return cuml_hdbscan.all_points_membership_vectors(model) return None - + # membership_vector if func == "membership_vector": if isinstance(model, hdbscan.HDBSCAN): @@ -52,6 +55,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): str_type_model = str(type(model)).lower() if "cuml" in str_type_model and "hdbscan" in str_type_model: from cuml.cluster import hdbscan as cuml_hdbscan + probabilities = cuml_hdbscan.membership_vector(model, embeddings) return probabilities @@ -59,7 +63,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None): def is_supported_hdbscan(model): - """ Check whether the input model is a supported HDBSCAN-like model """ + """Check whether the input model is a supported HDBSCAN-like model.""" if isinstance(model, hdbscan.HDBSCAN): return True diff --git a/bertopic/dimensionality/_base.py b/bertopic/dimensionality/_base.py index 7b39c3b4..922f8df7 100644 --- a/bertopic/dimensionality/_base.py +++ b/bertopic/dimensionality/_base.py @@ -2,12 +2,11 @@ class BaseDimensionalityReduction: - """ The Base Dimensionality Reduction class + """The Base Dimensionality Reduction class. You can use this to skip over the dimensionality reduction step in BERTopic. 
Examples: - This will skip over the reduction step in BERTopic: ```python @@ -19,6 +18,7 @@ class BaseDimensionalityReduction: topic_model = BERTopic(umap_model=empty_reduction_model) ``` """ + def fit(self, X: np.ndarray = None): return self diff --git a/bertopic/plotting/__init__.py b/bertopic/plotting/__init__.py index e6ace977..d4579b56 100644 --- a/bertopic/plotting/__init__.py +++ b/bertopic/plotting/__init__.py @@ -24,5 +24,5 @@ "visualize_topics_over_time", "visualize_topics_per_class", "visualize_hierarchical_documents", - "visualize_approximate_distribution" + "visualize_approximate_distribution", ] diff --git a/bertopic/plotting/_approximate_distribution.py b/bertopic/plotting/_approximate_distribution.py index 1f93424f..a6380273 100644 --- a/bertopic/plotting/_approximate_distribution.py +++ b/bertopic/plotting/_approximate_distribution.py @@ -2,62 +2,64 @@ import pandas as pd try: - from pandas.io.formats.style import Styler + from pandas.io.formats.style import Styler # noqa: F401 + HAS_JINJA = True except (ModuleNotFoundError, ImportError): HAS_JINJA = False -def visualize_approximate_distribution(topic_model, - document: str, - topic_token_distribution: np.ndarray, - normalize: bool = False): - """ Visualize the topic distribution calculated by `.approximate_topic_distribution` - on a token level. Thereby indicating the extend to which a certain word or phrases belong - to a specific topic. The assumption here is that a single word can belong to multiple - similar topics and as such give information about the broader set of topics within - a single document. +def visualize_approximate_distribution( + topic_model, + document: str, + topic_token_distribution: np.ndarray, + normalize: bool = False, +): + """Visualize the topic distribution calculated by `.approximate_topic_distribution` + on a token level. Thereby indicating the extend to which a certain word or phrases belong + to a specific topic. The assumption here is that a single word can belong to multiple + similar topics and as such give information about the broader set of topics within + a single document. - NOTE: - This function will return a stylized pandas dataframe if Jinja2 is installed. If not, + Note: + This function will return a stylized pandas dataframe if Jinja2 is installed. If not, it will only return a pandas dataframe without color highlighting. To install jinja: `pip install jinja2` - + Arguments: topic_model: A fitted BERTopic instance. - document: The document for which you want to visualize + document: The document for which you want to visualize the approximated topic distribution. - topic_token_distribution: The topic-token distribution of the document as + topic_token_distribution: The topic-token distribution of the document as extracted by `.approximate_topic_distribution` - normalize: Whether to normalize, between 0 and 1 (summing to 1), the - topic distribution values. - + normalize: Whether to normalize, between 0 and 1 (summing to 1), the + topic distribution values. + Returns: df: A stylized dataframe indicating the best fitting topics for each token. 
- + Examples: - ```python # Calculate the topic distributions on a token level # Note that we need to have `calculate_token_level=True` topic_distr, topic_token_distr = topic_model.approximate_distribution( docs, calculate_token_level=True ) - + # Visualize the approximated topic distributions df = topic_model.visualize_approximate_distribution(docs[0], topic_token_distr[0]) df ``` - - To revert this stylized dataframe back to a regular dataframe, + + To revert this stylized dataframe back to a regular dataframe, you can run the following: - + ```python df.data.columns = [column.strip() for column in df.data.columns] df = df.data - ``` + ``` """ # Tokenize document analyzer = topic_model.vectorizer_model.build_tokenizer() @@ -65,35 +67,36 @@ def visualize_approximate_distribution(topic_model, if len(tokens) == 0: raise ValueError("Make sure that your document contains at least 1 token.") - + # Prepare dataframe with results if normalize: df = pd.DataFrame(topic_token_distribution / topic_token_distribution.sum()).T else: df = pd.DataFrame(topic_token_distribution).T - + df.columns = [f"{token}_{i}" for i, token in enumerate(tokens)] df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)] - df.index = list(topic_model.topic_labels_.values())[topic_model._outliers:] + df.index = list(topic_model.topic_labels_.values())[topic_model._outliers :] df = df.loc[(df.sum(axis=1) != 0), :] - + # Style the resulting dataframe def text_color(val): - color = 'white' if val == 0 else 'black' - return 'color: %s' % color + color = "white" if val == 0 else "black" + return "color: %s" % color + + def highligh_color(data, color="white"): + attr = "background-color: {}".format(color) + return pd.DataFrame( + np.where(data == 0, attr, ""), index=data.index, columns=data.columns + ) - def highligh_color(data, color='white'): - attr = 'background-color: {}'.format(color) - return pd.DataFrame(np.where(data == 0, attr, ''), index=data.index, columns=data.columns) - if len(df) == 0: return df elif HAS_JINJA: df = ( - df.style - .format("{:.3f}") - .background_gradient(cmap='Blues', axis=None) - .applymap(lambda x: text_color(x)) - .apply(highligh_color, axis=None) + df.style.format("{:.3f}") + .background_gradient(cmap="Blues", axis=None) + .applymap(lambda x: text_color(x)) + .apply(highligh_color, axis=None) ) return df diff --git a/bertopic/plotting/_barchart.py b/bertopic/plotting/_barchart.py index 7ece6855..417e2c0f 100644 --- a/bertopic/plotting/_barchart.py +++ b/bertopic/plotting/_barchart.py @@ -6,23 +6,25 @@ from plotly.subplots import make_subplots -def visualize_barchart(topic_model, - topics: List[int] = None, - top_n_topics: int = 8, - n_words: int = 5, - custom_labels: Union[bool, str] = False, - title: str = "Topic Word Scores", - width: int = 250, - height: int = 250, - autoscale: bool=False) -> go.Figure: - """ Visualize a barchart of selected topics +def visualize_barchart( + topic_model, + topics: List[int] = None, + top_n_topics: int = 8, + n_words: int = 5, + custom_labels: Union[bool, str] = False, + title: str = "Topic Word Scores", + width: int = 250, + height: int = 250, + autoscale: bool = False, +) -> go.Figure: + """Visualize a barchart of selected topics. Arguments: topic_model: A fitted BERTopic instance. topics: A selection of topics to visualize. top_n_topics: Only select the top n most frequent topics. 
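The styling block above relies on the pandas `Styler` API, which needs Jinja2 (and Matplotlib for `background_gradient`). A compact sketch of the same chain of styling calls on a toy token-by-topic table, assuming those optional dependencies are installed:

```python
import pandas as pd

df = pd.DataFrame(
    [[0.0, 0.2], [0.5, 0.0]],
    index=["topic 0", "topic 1"],
    columns=["token a", "token b"],
)

styled = (
    df.style.format("{:.3f}")
    .background_gradient(cmap="Blues", axis=None)   # shade every cell by its value
    .applymap(lambda v: "color: white" if v == 0 else "color: black")
)
html = styled.to_html()  # `styled` also renders directly in a notebook
```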
n_words: Number of words to show in a topic - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -34,7 +36,6 @@ def visualize_barchart(topic_model, fig: A plotly figure Examples: - To visualize the barchart of selected topics simply run: @@ -51,7 +52,9 @@ def visualize_barchart(topic_model, """ - colors = itertools.cycle(["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"]) + colors = itertools.cycle( + ["#D55E00", "#0072B2", "#CC79A7", "#E69F00", "#56B4E9", "#009E73", "#F0E442"] + ) # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -65,44 +68,55 @@ def visualize_barchart(topic_model, # Initialize figure if isinstance(custom_labels, str): - subplot_titles = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics] - subplot_titles = ["_".join([label[0] for label in labels[:4]]) for labels in subplot_titles] - subplot_titles = [label if len(label) < 30 else label[:27] + "..." for label in subplot_titles] + subplot_titles = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in topics + ] + subplot_titles = [ + "_".join([label[0] for label in labels[:4]]) for labels in subplot_titles + ] + subplot_titles = [ + label if len(label) < 30 else label[:27] + "..." for label in subplot_titles + ] elif topic_model.custom_labels_ is not None and custom_labels: - subplot_titles = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topics] + subplot_titles = [ + topic_model.custom_labels_[topic + topic_model._outliers] + for topic in topics + ] else: subplot_titles = [f"Topic {topic}" for topic in topics] columns = 4 rows = int(np.ceil(len(topics) / columns)) - fig = make_subplots(rows=rows, - cols=columns, - shared_xaxes=False, - horizontal_spacing=.1, - vertical_spacing=.4 / rows if rows > 1 else 0, - subplot_titles=subplot_titles) + fig = make_subplots( + rows=rows, + cols=columns, + shared_xaxes=False, + horizontal_spacing=0.1, + vertical_spacing=0.4 / rows if rows > 1 else 0, + subplot_titles=subplot_titles, + ) # Add barchart for each topic row = 1 column = 1 for topic in topics: - words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][::-1] + words = [word + " " for word, _ in topic_model.get_topic(topic)][:n_words][ + ::-1 + ] scores = [score for _, score in topic_model.get_topic(topic)][:n_words][::-1] fig.add_trace( - go.Bar(x=scores, - y=words, - orientation='h', - marker_color=next(colors)), - row=row, col=column) + go.Bar(x=scores, y=words, orientation="h", marker_color=next(colors)), + row=row, + col=column, + ) if autoscale: if len(words) > 12: height = 250 + (len(words) - 12) * 11 if len(words) > 9: - fig.update_yaxes( - tickfont=dict(size=(height - 140) // len(words)) - ) + fig.update_yaxes(tickfont=dict(size=(height - 140) // len(words))) if column == columns: column = 1 @@ -115,21 +129,15 @@ def visualize_barchart(topic_model, template="plotly_white", showlegend=False, title={ - 'text': f"{title}", - 'x': .5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, - width=width*4, - height=height*rows if rows > 1 else height * 
1.3, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + width=width * 4, + height=height * rows if rows > 1 else height * 1.3, + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), ) fig.update_xaxes(showgrid=True) diff --git a/bertopic/plotting/_datamap.py b/bertopic/plotting/_datamap.py index 08577121..a793e4fc 100644 --- a/bertopic/plotting/_datamap.py +++ b/bertopic/plotting/_datamap.py @@ -3,27 +3,32 @@ from typing import List, Union from umap import UMAP from warnings import warn + try: import datamapplot from matplotlib.figure import Figure except ImportError: warn("Data map plotting is unavailable unless datamapplot is installed.") + # Create a dummy figure type for typing - class Figure (object): + class Figure(object): pass -def visualize_document_datamap(topic_model, - docs: List[str], - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - custom_labels: Union[bool, str] = False, - title: str = "Documents and Topics", - sub_title: Union[str, None] = None, - width: int = 1200, - height: int = 1200, - **datamap_kwds) -> Figure: - """ Visualize documents and their topics in 2D as a static plot for publication using + +def visualize_document_datamap( + topic_model, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + custom_labels: Union[bool, str] = False, + title: str = "Documents and Topics", + sub_title: Union[str, None] = None, + width: int = 1200, + height: int = 1200, + **datamap_kwds, +) -> Figure: + """Visualize documents and their topics in 2D as a static plot for publication using DataMapPlot. Arguments: @@ -51,7 +56,6 @@ def visualize_document_datamap(topic_model, figure: A Matplotlib Figure object. Examples: - To visualize the topics simply run: ```python @@ -94,7 +98,6 @@ def visualize_document_datamap(topic_model, DataMapPlot of 20-Newsgroups """ - topic_per_doc = topic_model.topics_ df = pd.DataFrame({"topic": np.array(topic_per_doc)}) @@ -103,13 +106,17 @@ def visualize_document_datamap(topic_model, # Extract embeddings if not already done if embeddings is None and reduced_embeddings is None: - embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + embeddings_to_reduce = topic_model._extract_embeddings( + df.doc.to_list(), method="document" + ) else: embeddings_to_reduce = embeddings # Reduce input embeddings if reduced_embeddings is None: - umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric='cosine').fit(embeddings_to_reduce) + umap_model = UMAP( + n_neighbors=15, n_components=2, min_dist=0.15, metric="cosine" + ).fit(embeddings_to_reduce) embeddings_2d = umap_model.embedding_ else: embeddings_2d = reduced_embeddings @@ -118,15 +125,27 @@ def visualize_document_datamap(topic_model, # Prepare text and names if isinstance(custom_labels, str): - names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics] + names = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in unique_topics + ] names = [" ".join([label[0] for label in labels[:4]]) for labels in names] names = [label if len(label) < 30 else label[:27] + "..." 
for label in names] elif topic_model.custom_labels_ is not None and custom_labels: - names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics] + names = [ + topic_model.custom_labels_[topic + topic_model._outliers] + for topic in unique_topics + ] else: - names = [f"Topic-{topic}: " + " ".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics] - - topic_name_mapping = {topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names)} + names = [ + f"Topic-{topic}: " + + " ".join([word for word, value in topic_model.get_topic(topic)][:3]) + for topic in unique_topics + ] + + topic_name_mapping = { + topic_num: topic_name for topic_num, topic_name in zip(unique_topics, names) + } topic_name_mapping[-1] = "Unlabelled" # If a set of topics is chosen, set everything else to "Unlabelled" @@ -142,7 +161,7 @@ def visualize_document_datamap(topic_model, figure, axes = datamapplot.create_plot( embeddings_2d, named_topic_per_doc, - figsize=(width/100, height/100), + figsize=(width / 100, height / 100), dpi=100, title=title, sub_title=sub_title, diff --git a/bertopic/plotting/_distribution.py b/bertopic/plotting/_distribution.py index d1d2b4b0..d04d140b 100644 --- a/bertopic/plotting/_distribution.py +++ b/bertopic/plotting/_distribution.py @@ -3,21 +3,23 @@ import plotly.graph_objects as go -def visualize_distribution(topic_model, - probabilities: np.ndarray, - min_probability: float = 0.015, - custom_labels: Union[bool, str] = False, - title: str = "Topic Probability Distribution", - width: int = 800, - height: int = 600) -> go.Figure: - """ Visualize the distribution of topic probabilities +def visualize_distribution( + topic_model, + probabilities: np.ndarray, + min_probability: float = 0.015, + custom_labels: Union[bool, str] = False, + title: str = "Topic Probability Distribution", + width: int = 800, + height: int = 600, +) -> go.Figure: + """Visualize the distribution of topic probabilities. Arguments: topic_model: A fitted BERTopic instance. probabilities: An array of probability scores min_probability: The minimum probability score to visualize. All others are ignored. - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -25,7 +27,6 @@ def visualize_distribution(topic_model, height: The height of the figure. Examples: - Make sure to fit the model before and only input the probabilities of a single document: @@ -43,11 +44,15 @@ def visualize_distribution(topic_model, style="width:1000px; height: 500px; border: 0px;""> """ if len(probabilities.shape) != 1: - raise ValueError("This visualization cannot be used if you have set `calculate_probabilities` to False " - "as it uses the topic probabilities of all topics. ") + raise ValueError( + "This visualization cannot be used if you have set `calculate_probabilities` to False " + "as it uses the topic probabilities of all topics. " + ) if len(probabilities[probabilities > min_probability]) == 0: - raise ValueError("There are no values where `min_probability` is higher than the " - "probabilities that were supplied. Lower `min_probability` to prevent this error.") + raise ValueError( + "There are no values where `min_probability` is higher than the " + "probabilities that were supplied. 
Lower `min_probability` to prevent this error." + ) # Get values and indices equal or exceed the minimum probability labels_idx = np.argwhere(probabilities >= min_probability).flatten() @@ -55,11 +60,17 @@ def visualize_distribution(topic_model, # Create labels if isinstance(custom_labels, str): - labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in labels_idx] - labels = ["_".join([label[0] for label in l[:4]]) for l in labels] + labels = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in labels_idx + ] + labels = ["_".join([label[0] for label in l[:4]]) for l in labels] # noqa: E741 labels = [label if len(label) < 30 else label[:27] + "..." for label in labels] elif topic_model.custom_labels_ is not None and custom_labels: - labels = [topic_model.custom_labels_[idx + topic_model._outliers] for idx in labels_idx] + labels = [ + topic_model.custom_labels_[idx + topic_model._outliers] + for idx in labels_idx + ] else: labels = [] for idx in labels_idx: @@ -73,38 +84,32 @@ def visualize_distribution(topic_model, vals.remove(probabilities[idx]) # Create Figure - fig = go.Figure(go.Bar( - x=vals, - y=labels, - marker=dict( - color='#C8D2D7', - line=dict( - color='#6E8484', - width=1), - ), - orientation='h') + fig = go.Figure( + go.Bar( + x=vals, + y=labels, + marker=dict( + color="#C8D2D7", + line=dict(color="#6E8484", width=1), + ), + orientation="h", + ) ) fig.update_layout( xaxis_title="Probability", title={ - 'text': f"{title}", - 'y': .95, - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.95, + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, template="simple_white", width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), ) return fig diff --git a/bertopic/plotting/_documents.py b/bertopic/plotting/_documents.py index c38d95ac..0c5287b4 100644 --- a/bertopic/plotting/_documents.py +++ b/bertopic/plotting/_documents.py @@ -6,19 +6,21 @@ from typing import List, Union -def visualize_documents(topic_model, - docs: List[str], - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - sample: float = None, - hide_annotations: bool = False, - hide_document_hover: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Documents and Topics", - width: int = 1200, - height: int = 750): - """ Visualize documents and their topics in 2D +def visualize_documents( + topic_model, + docs: List[str], + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: float = None, + hide_annotations: bool = False, + hide_document_hover: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Documents and Topics", + width: int = 1200, + height: int = 750, +): + """Visualize documents and their topics in 2D. Arguments: topic_model: A fitted BERTopic instance. @@ -36,7 +38,7 @@ def visualize_documents(topic_model, hide_annotations: Hide the names of the traces on top of each cluster. hide_document_hover: Hide the content of the documents when hovering over specific points. Helps to speed up generation of visualization. 
- custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -44,7 +46,6 @@ def visualize_documents(topic_model, height: The height of the figure. Examples: - To visualize the topics simply run: ```python @@ -108,18 +109,24 @@ def visualize_documents(topic_model, # Extract embeddings if not already done if sample is None: if embeddings is None and reduced_embeddings is None: - embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + embeddings_to_reduce = topic_model._extract_embeddings( + df.doc.to_list(), method="document" + ) else: embeddings_to_reduce = embeddings else: if embeddings is not None: embeddings_to_reduce = embeddings[indices] elif embeddings is None and reduced_embeddings is None: - embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + embeddings_to_reduce = topic_model._extract_embeddings( + df.doc.to_list(), method="document" + ) # Reduce input embeddings if reduced_embeddings is None: - umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce) + umap_model = UMAP( + n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine" + ).fit(embeddings_to_reduce) embeddings_2d = umap_model.embedding_ elif sample is not None and reduced_embeddings is not None: embeddings_2d = reduced_embeddings[indices] @@ -136,13 +143,23 @@ def visualize_documents(topic_model, # Prepare text and names if isinstance(custom_labels, str): - names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in unique_topics] + names = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in unique_topics + ] names = ["_".join([label[0] for label in labels[:4]]) for labels in names] names = [label if len(label) < 30 else label[:27] + "..." 
for label in names] elif topic_model.custom_labels_ is not None and custom_labels: - names = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in unique_topics] + names = [ + topic_model.custom_labels_[topic + topic_model._outliers] + for topic in unique_topics + ] else: - names = [f"{topic}_" + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) for topic in unique_topics] + names = [ + f"{topic}_" + + "_".join([word for word, value in topic_model.get_topic(topic)][:3]) + for topic in unique_topics + ] # Visualize fig = go.Figure() @@ -154,7 +171,13 @@ def visualize_documents(topic_model, selection = df.loc[df.topic.isin(non_selected_topics), :] selection["text"] = "" - selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), "Other documents"] + selection.loc[len(selection), :] = [ + None, + None, + selection.x.mean(), + selection.y.mean(), + "Other documents", + ] fig.add_trace( go.Scattergl( @@ -162,10 +185,10 @@ def visualize_documents(topic_model, y=selection.y, hovertext=selection.doc if not hide_document_hover else None, hoverinfo="text", - mode='markers+text', + mode="markers+text", name="other", showlegend=False, - marker=dict(color='#CFD8DC', size=5, opacity=0.5) + marker=dict(color="#CFD8DC", size=5, opacity=0.5), ) ) @@ -176,7 +199,13 @@ def visualize_documents(topic_model, selection["text"] = "" if not hide_annotations: - selection.loc[len(selection), :] = [None, None, selection.x.mean(), selection.y.mean(), name] + selection.loc[len(selection), :] = [ + None, + None, + selection.x.mean(), + selection.y.mean(), + name, + ] fig.add_trace( go.Scattergl( @@ -185,41 +214,59 @@ def visualize_documents(topic_model, hovertext=selection.doc if not hide_document_hover else None, hoverinfo="text", text=selection.text, - mode='markers+text', + mode="markers+text", name=name, textfont=dict( size=12, ), - marker=dict(size=5, opacity=0.5) + marker=dict(size=5, opacity=0.5), ) ) # Add grid in a 'plus' shape - x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15)) - y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15)) - fig.add_shape(type="line", - x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1], - line=dict(color="#CFD8DC", width=2)) - fig.add_shape(type="line", - x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2, - line=dict(color="#9E9E9E", width=2)) - fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10) - fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10) + x_range = ( + df.x.min() - abs((df.x.min()) * 0.15), + df.x.max() + abs((df.x.max()) * 0.15), + ) + y_range = ( + df.y.min() - abs((df.y.min()) * 0.15), + df.y.max() + abs((df.y.max()) * 0.15), + ) + fig.add_shape( + type="line", + x0=sum(x_range) / 2, + y0=y_range[0], + x1=sum(x_range) / 2, + y1=y_range[1], + line=dict(color="#CFD8DC", width=2), + ) + fig.add_shape( + type="line", + x0=x_range[0], + y0=sum(y_range) / 2, + x1=x_range[1], + y1=sum(y_range) / 2, + line=dict(color="#9E9E9E", width=2), + ) + fig.add_annotation( + x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10 + ) + fig.add_annotation( + y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10 + ) # Stylize layout fig.update_layout( template="simple_white", title={ - 'text': f"{title}", - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - 
color="Black") + "text": f"{title}", + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, width=width, - height=height + height=height, ) fig.update_xaxes(visible=False) diff --git a/bertopic/plotting/_heatmap.py b/bertopic/plotting/_heatmap.py index a1f251e2..ad9f0664 100644 --- a/bertopic/plotting/_heatmap.py +++ b/bertopic/plotting/_heatmap.py @@ -8,16 +8,18 @@ import plotly.graph_objects as go -def visualize_heatmap(topic_model, - topics: List[int] = None, - top_n_topics: int = None, - n_clusters: int = None, - use_ctfidf: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Similarity Matrix", - width: int = 800, - height: int = 800) -> go.Figure: - """ Visualize a heatmap of the topic's similarity matrix +def visualize_heatmap( + topic_model, + topics: List[int] = None, + top_n_topics: int = None, + n_clusters: int = None, + use_ctfidf: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Similarity Matrix", + width: int = 800, + height: int = 800, +) -> go.Figure: + """Visualize a heatmap of the topic's similarity matrix. Based on the cosine similarity matrix between topic embeddings (either c-TF-IDF or the embeddings from the embedding model), a heatmap is created showing the similarity between topics. @@ -30,7 +32,7 @@ def visualize_heatmap(topic_model, matrix by those clusters. use_ctfidf: Whether to calculate distances between topics based on c-TF-IDF embeddings. If False, the embeddings from the embedding model are used. - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -41,7 +43,6 @@ def visualize_heatmap(topic_model, fig: A plotly figure Examples: - To visualize the similarity matrix of topics simply run: @@ -58,10 +59,9 @@ def visualize_heatmap(topic_model, """ - embeddings = select_topic_representation( topic_model.c_tf_idf_, topic_model.topic_embeddings_, use_ctfidf - )[0][topic_model._outliers:] + )[0][topic_model._outliers :] # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -77,12 +77,14 @@ def visualize_heatmap(topic_model, sorted_topics = topics if n_clusters: if n_clusters >= len(set(topics)): - raise ValueError("Make sure to set `n_clusters` lower than " - "the total number of unique topics.") + raise ValueError( + "Make sure to set `n_clusters` lower than " + "the total number of unique topics." + ) distance_matrix = cosine_similarity(embeddings[topics]) - Z = linkage(distance_matrix, 'ward') - clusters = fcluster(Z, t=n_clusters, criterion='maxclust') + Z = linkage(distance_matrix, "ward") + clusters = fcluster(Z, t=n_clusters, criterion="maxclust") # Extract new order of topics mapping = {cluster: [] for cluster in clusters} @@ -98,43 +100,55 @@ def visualize_heatmap(topic_model, # Create labels if isinstance(custom_labels, str): - new_labels = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in sorted_topics] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." 
for label in new_labels] + new_labels = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in sorted_topics + ] + new_labels = [ + "_".join([label[0] for label in labels[:4]]) for labels in new_labels + ] + new_labels = [ + label if len(label) < 30 else label[:27] + "..." for label in new_labels + ] elif topic_model.custom_labels_ is not None and custom_labels: - new_labels = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in sorted_topics] + new_labels = [ + topic_model.custom_labels_[topic + topic_model._outliers] + for topic in sorted_topics + ] else: - new_labels = [[[str(topic), None]] + topic_model.get_topic(topic) for topic in sorted_topics] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] - - fig = px.imshow(distance_matrix, - labels=dict(color="Similarity Score"), - x=new_labels, - y=new_labels, - color_continuous_scale='GnBu' - ) + new_labels = [ + [[str(topic), None]] + topic_model.get_topic(topic) + for topic in sorted_topics + ] + new_labels = [ + "_".join([label[0] for label in labels[:4]]) for labels in new_labels + ] + new_labels = [ + label if len(label) < 30 else label[:27] + "..." for label in new_labels + ] + + fig = px.imshow( + distance_matrix, + labels=dict(color="Similarity Score"), + x=new_labels, + y=new_labels, + color_continuous_scale="GnBu", + ) fig.update_layout( title={ - 'text': f"{title}", - 'y': .95, - 'x': 0.55, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.95, + "x": 0.55, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), ) fig.update_layout(showlegend=True) - fig.update_layout(legend_title_text='Trend') + fig.update_layout(legend_title_text="Trend") return fig diff --git a/bertopic/plotting/_hierarchical_documents.py b/bertopic/plotting/_hierarchical_documents.py index e9e4ca64..5501c8b7 100644 --- a/bertopic/plotting/_hierarchical_documents.py +++ b/bertopic/plotting/_hierarchical_documents.py @@ -1,30 +1,33 @@ import numpy as np import pandas as pd import plotly.graph_objects as go -import math +import math from umap import UMAP from typing import List, Union -def visualize_hierarchical_documents(topic_model, - docs: List[str], - hierarchical_topics: pd.DataFrame, - topics: List[int] = None, - embeddings: np.ndarray = None, - reduced_embeddings: np.ndarray = None, - sample: Union[float, int] = None, - hide_annotations: bool = False, - hide_document_hover: bool = True, - nr_levels: int = 10, - level_scale: str = 'linear', - custom_labels: Union[bool, str] = False, - title: str = "Hierarchical Documents and Topics", - width: int = 1200, - height: int = 750) -> go.Figure: - """ Visualize documents and their topics in 2D at different levels of hierarchy +def visualize_hierarchical_documents( + topic_model, + docs: List[str], + hierarchical_topics: pd.DataFrame, + topics: List[int] = None, + embeddings: np.ndarray = None, + reduced_embeddings: np.ndarray = None, + sample: Union[float, int] = None, + hide_annotations: bool = False, + hide_document_hover: bool = True, + nr_levels: int = 10, + level_scale: str = "linear", + custom_labels: Union[bool, str] = False, + title: str = 
"Hierarchical Documents and Topics", + width: int = 1200, + height: int = 750, +) -> go.Figure: + """Visualize documents and their topics in 2D at different levels of hierarchy. Arguments: + topic_model: A fitted BERTopic instance. docs: The documents you used when calling either `fit` or `fit_transform` hierarchical_topics: A dataframe that contains a hierarchy of topics represented by their parents and their children @@ -42,27 +45,26 @@ def visualize_hierarchical_documents(topic_model, hide_document_hover: Hide the content of the documents when hovering over specific points. Helps to speed up generation of visualizations. nr_levels: The number of levels to be visualized in the hierarchy. First, the distances - in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances. - Then, for each list of distances, the merged topics are selected that have a + in `hierarchical_topics.Distance` are split in `nr_levels` lists of distances. + Then, for each list of distances, the merged topics are selected that have a distance less or equal to the maximum distance of the selected list of distances. NOTE: To get all possible merged steps, make sure that `nr_levels` is equal to the length of `hierarchical_topics`. - level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance - vector. Linear scaling will perform an equal number of merges at each level - while logarithmic scaling will perform more mergers in earlier levels to - provide more resolution at higher levels (this can be used for when the number - of topics is large). - custom_labels: If bool, whether to use custom topic labels that were defined using + level_scale: Whether to apply a linear or logarithmic (log) scale levels of the distance + vector. Linear scaling will perform an equal number of merges at each level + while logarithmic scaling will perform more mergers in earlier levels to + provide more resolution at higher levels (this can be used for when the number + of topics is large). + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". - NOTE: Custom labels are only generated for the original + NOTE: Custom labels are only generated for the original un-merged topics. title: Title of the plot. width: The width of the figure. height: The height of the figure. 
Examples: - To visualize the topics simply run: ```python @@ -104,7 +106,7 @@ def visualize_hierarchical_documents(topic_model, fig.write_html("path/to/file.html") ``` - NOTE: + Note: This visualization was inspired by the scatter plot representation of Doc2Map: https://github.com/louisgeisler/Doc2Map @@ -120,7 +122,7 @@ def visualize_hierarchical_documents(topic_model, indices = [] for topic in set(topic_per_doc): s = np.where(np.array(topic_per_doc) == topic)[0] - size = len(s) if len(s) < 100 else int(len(s)*sample) + size = len(s) if len(s) < 100 else int(len(s) * sample) indices.extend(np.random.choice(s, size=size, replace=False)) indices = np.array(indices) @@ -131,18 +133,24 @@ def visualize_hierarchical_documents(topic_model, # Extract embeddings if not already done if sample is None: if embeddings is None and reduced_embeddings is None: - embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + embeddings_to_reduce = topic_model._extract_embeddings( + df.doc.to_list(), method="document" + ) else: embeddings_to_reduce = embeddings else: if embeddings is not None: embeddings_to_reduce = embeddings[indices] elif embeddings is None and reduced_embeddings is None: - embeddings_to_reduce = topic_model._extract_embeddings(df.doc.to_list(), method="document") + embeddings_to_reduce = topic_model._extract_embeddings( + df.doc.to_list(), method="document" + ) # Reduce input embeddings if reduced_embeddings is None: - umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit(embeddings_to_reduce) + umap_model = UMAP( + n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine" + ).fit(embeddings_to_reduce) embeddings_2d = umap_model.embedding_ elif sample is not None and reduced_embeddings is not None: embeddings_2d = reduced_embeddings[indices] @@ -155,20 +163,34 @@ def visualize_hierarchical_documents(topic_model, # Create topic list for each level, levels are created by calculating the distance distances = hierarchical_topics.Distance.to_list() - if level_scale == 'log' or level_scale == 'logarithmic': - log_indices = np.round(np.logspace(start=math.log(1,10), stop=math.log(len(distances)-1,10), num=nr_levels)).astype(int).tolist() + if level_scale == "log" or level_scale == "logarithmic": + log_indices = ( + np.round( + np.logspace( + start=math.log(1, 10), + stop=math.log(len(distances) - 1, 10), + num=nr_levels, + ) + ) + .astype(int) + .tolist() + ) log_indices.reverse() max_distances = [distances[i] for i in log_indices] - elif level_scale == 'lin' or level_scale == 'linear': - max_distances = [distances[indices[-1]] for indices in np.array_split(range(len(hierarchical_topics)), nr_levels)][::-1] + elif level_scale == "lin" or level_scale == "linear": + max_distances = [ + distances[indices[-1]] + for indices in np.array_split(range(len(hierarchical_topics)), nr_levels) + ][::-1] else: raise ValueError("level_scale needs to be one of 'log' or 'linear'") - - for index, max_distance in enumerate(max_distances): + for index, max_distance in enumerate(max_distances): # Get topics below `max_distance` mapping = {topic: topic for topic in df.topic.unique()} - selection = hierarchical_topics.loc[hierarchical_topics.Distance <= max_distance, :] + selection = hierarchical_topics.loc[ + hierarchical_topics.Distance <= max_distance, : + ] selection.Parent_ID = selection.Parent_ID.astype(int) selection = selection.sort_values("Parent_ID") @@ -196,17 +218,36 @@ def visualize_hierarchical_documents(topic_model, if topic < 
hierarchical_topics.Parent_ID.astype(int).min(): if topic_model.get_topic(topic): if isinstance(custom_labels, str): - trace_name = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3]) + trace_name = f"{topic}_" + "_".join( + list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][ + :3 + ] + ) elif topic_model.custom_labels_ is not None and custom_labels: - trace_name = topic_model.custom_labels_[topic + topic_model._outliers] + trace_name = topic_model.custom_labels_[ + topic + topic_model._outliers + ] else: - trace_name = f"{topic}_" + "_".join([word[:20] for word, _ in topic_model.get_topic(topic)][:3]) - topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": trace_name[:40]} + trace_name = f"{topic}_" + "_".join( + [word[:20] for word, _ in topic_model.get_topic(topic)][:3] + ) + topic_names[topic] = { + "trace_name": trace_name[:40], + "plot_text": trace_name[:40], + } trace_names.append(trace_name) else: - trace_name = f"{topic}_" + hierarchical_topics.loc[hierarchical_topics.Parent_ID == str(topic), "Parent_Name"].values[0] + trace_name = ( + f"{topic}_" + + hierarchical_topics.loc[ + hierarchical_topics.Parent_ID == str(topic), "Parent_Name" + ].values[0] + ) plot_text = "_".join([name[:20] for name in trace_name.split("_")[:3]]) - topic_names[topic] = {"trace_name": trace_name[:40], "plot_text": plot_text[:40]} + topic_names[topic] = { + "trace_name": trace_name[:40], + "plot_text": plot_text[:40], + } trace_names.append(trace_name) # Prepare traces @@ -217,30 +258,37 @@ def visualize_hierarchical_documents(topic_model, # Outliers if topic_model._outliers: traces.append( - go.Scattergl( - x=df.loc[(df[f"level_{level+1}"] == -1), "x"], - y=df.loc[df[f"level_{level+1}"] == -1, "y"], - mode='markers+text', - name="other", - hoverinfo="text", - hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None, - showlegend=False, - marker=dict(color='#CFD8DC', size=5, opacity=0.5) - ) + go.Scattergl( + x=df.loc[(df[f"level_{level+1}"] == -1), "x"], + y=df.loc[df[f"level_{level+1}"] == -1, "y"], + mode="markers+text", + name="other", + hoverinfo="text", + hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] + if not hide_document_hover + else None, + showlegend=False, + marker=dict(color="#CFD8DC", size=5, opacity=0.5), ) + ) # Selected topics if topics: selection = df.loc[(df.topic.isin(topics)), :] - unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()]) + unique_topics = sorted( + [int(topic) for topic in selection[f"level_{level+1}"].unique()] + ) else: - unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()]) + unique_topics = sorted( + [int(topic) for topic in df[f"level_{level+1}"].unique()] + ) for topic in unique_topics: if topic != -1: if topics: - selection = df.loc[(df[f"level_{level+1}"] == topic) & - (df.topic.isin(topics)), :] + selection = df.loc[ + (df[f"level_{level+1}"] == topic) & (df.topic.isin(topics)), : + ] else: selection = df.loc[df[f"level_{level+1}"] == topic, :] @@ -249,7 +297,9 @@ def visualize_hierarchical_documents(topic_model, selection["text"] = "" selection.loc[len(selection) - 1, "x"] = selection.x.mean() selection.loc[len(selection) - 1, "y"] = selection.y.mean() - selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)]["plot_text"] + selection.loc[len(selection) - 1, "text"] = topic_names[int(topic)][ + "plot_text" + ] traces.append( go.Scattergl( @@ -259,8 +309,8 @@ def 
visualize_hierarchical_documents(topic_model, hovertext=selection.doc if not hide_document_hover else None, hoverinfo="text", name=topic_names[int(topic)]["trace_name"], - mode='markers+text', - marker=dict(size=5, opacity=0.5) + mode="markers+text", + marker=dict(size=5, opacity=0.5), ) ) @@ -290,42 +340,56 @@ def visualize_hierarchical_documents(topic_model, step = dict( method="update", label=str(index), - args=[{"visible": [False] * len(fig.data)}] + args=[{"visible": [False] * len(fig.data)}], ) - for index in range(indices[1]-indices[0]): - step["args"][0]["visible"][index+indices[0]] = True + for index in range(indices[1] - indices[0]): + step["args"][0]["visible"][index + indices[0]] = True steps.append(step) - sliders = [dict( - currentvalue={"prefix": "Level: "}, - pad={"t": 20}, - steps=steps - )] + sliders = [dict(currentvalue={"prefix": "Level: "}, pad={"t": 20}, steps=steps)] # Add grid in a 'plus' shape - x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15)) - y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15)) - fig.add_shape(type="line", - x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1], - line=dict(color="#CFD8DC", width=2)) - fig.add_shape(type="line", - x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2, - line=dict(color="#9E9E9E", width=2)) - fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10) - fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10) + x_range = ( + df.x.min() - abs((df.x.min()) * 0.15), + df.x.max() + abs((df.x.max()) * 0.15), + ) + y_range = ( + df.y.min() - abs((df.y.min()) * 0.15), + df.y.max() + abs((df.y.max()) * 0.15), + ) + fig.add_shape( + type="line", + x0=sum(x_range) / 2, + y0=y_range[0], + x1=sum(x_range) / 2, + y1=y_range[1], + line=dict(color="#CFD8DC", width=2), + ) + fig.add_shape( + type="line", + x0=x_range[0], + y0=sum(y_range) / 2, + x1=x_range[1], + y1=sum(y_range) / 2, + line=dict(color="#9E9E9E", width=2), + ) + fig.add_annotation( + x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10 + ) + fig.add_annotation( + y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10 + ) # Stylize layout fig.update_layout( sliders=sliders, template="simple_white", title={ - 'text': f"{title}", - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, width=width, height=height, diff --git a/bertopic/plotting/_hierarchy.py b/bertopic/plotting/_hierarchy.py index 5ee5c841..6faa1bc4 100644 --- a/bertopic/plotting/_hierarchy.py +++ b/bertopic/plotting/_hierarchy.py @@ -3,7 +3,6 @@ from typing import Callable, List, Union from scipy.sparse import csr_matrix from scipy.cluster import hierarchy as sch -from scipy.spatial.distance import squareform from sklearn.metrics.pairwise import cosine_similarity from bertopic._utils import select_topic_representation @@ -13,20 +12,23 @@ from bertopic._utils import validate_distance_matrix -def visualize_hierarchy(topic_model, - orientation: str = "left", - topics: List[int] = None, - top_n_topics: int = None, - use_ctfidf: bool = True, - custom_labels: Union[bool, str] = False, - title: str = "Hierarchical Clustering", - width: int = 1000, - height: int = 600, - hierarchical_topics: pd.DataFrame = None, - linkage_function: 
Callable[[csr_matrix], np.ndarray] = None, - distance_function: Callable[[csr_matrix], csr_matrix] = None, - color_threshold: int = 1) -> go.Figure: - """ Visualize a hierarchical structure of the topics + +def visualize_hierarchy( + topic_model, + orientation: str = "left", + topics: List[int] = None, + top_n_topics: int = None, + use_ctfidf: bool = True, + custom_labels: Union[bool, str] = False, + title: str = "Hierarchical Clustering", + width: int = 1000, + height: int = 600, + hierarchical_topics: pd.DataFrame = None, + linkage_function: Callable[[csr_matrix], np.ndarray] = None, + distance_function: Callable[[csr_matrix], csr_matrix] = None, + color_threshold: int = 1, +) -> go.Figure: + """Visualize a hierarchical structure of the topics. A ward linkage function is used to perform the hierarchical clustering based on the cosine distance @@ -40,10 +42,10 @@ def visualize_hierarchy(topic_model, top_n_topics: Only select the top n most frequent topics use_ctfidf: Whether to calculate distances between topics based on c-TF-IDF embeddings. If False, the embeddings from the embedding model are used. - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". - NOTE: Custom labels are only generated for the original + NOTE: Custom labels are only generated for the original un-merged topics. title: Title of the plot. width: The width of the figure. Only works if orientation is set to 'left' @@ -58,10 +60,10 @@ def visualize_hierarchy(topic_model, in `topic_model.hierarchical_topics`. distance_function: The distance function to use on the c-TF-IDF matrix. Default is: `lambda x: 1 - cosine_similarity(x)`. - You can pass any function that returns either a square matrix of - shape (n_samples, n_samples) with zeros on the diagonal and - non-negative values or condensed distance matrix of shape - (n_samples * (n_samples - 1) / 2,) containing the upper + You can pass any function that returns either a square matrix of + shape (n_samples, n_samples) with zeros on the diagonal and + non-negative values or condensed distance matrix of shape + (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the distance matrix. NOTE: Make sure to use the same `distance_function` as used in `topic_model.hierarchical_topics`. 
@@ -73,7 +75,6 @@ def visualize_hierarchy(topic_model, fig: A plotly figure Examples: - To visualize the hierarchical structure of topics simply run: @@ -105,7 +106,7 @@ def visualize_hierarchy(topic_model, distance_function = lambda x: 1 - cosine_similarity(x) if linkage_function is None: - linkage_function = lambda x: sch.linkage(x, 'ward', optimal_ordering=True) + linkage_function = lambda x: sch.linkage(x, "ward", optimal_ordering=True) # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -128,103 +129,141 @@ def visualize_hierarchy(topic_model, # Annotations if hierarchical_topics is not None and len(topics) == len(freq_df.Topic.to_list()): - annotations = _get_annotations(topic_model=topic_model, - hierarchical_topics=hierarchical_topics, - embeddings=embeddings, - distance_function=distance_function, - linkage_function=linkage_function, - orientation=orientation, - custom_labels=custom_labels) + annotations = _get_annotations( + topic_model=topic_model, + hierarchical_topics=hierarchical_topics, + embeddings=embeddings, + distance_function=distance_function, + linkage_function=linkage_function, + orientation=orientation, + custom_labels=custom_labels, + ) else: annotations = None # wrap distance function to validate input and return a condensed distance matrix distance_function_viz = lambda x: validate_distance_matrix( - distance_function(x), embeddings.shape[0]) + distance_function(x), embeddings.shape[0] + ) # Create dendogram - fig = ff.create_dendrogram(embeddings, - orientation=orientation, - distfun=distance_function_viz, - linkagefun=linkage_function, - hovertext=annotations, - color_threshold=color_threshold) + fig = ff.create_dendrogram( + embeddings, + orientation=orientation, + distfun=distance_function_viz, + linkagefun=linkage_function, + hovertext=annotations, + color_threshold=color_threshold, + ) # Create nicer labels axis = "yaxis" if orientation == "left" else "xaxis" if isinstance(custom_labels, str): - new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] + new_labels = [ + [[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] + for x in fig.layout[axis]["ticktext"] + ] + new_labels = [ + "_".join([label[0] for label in labels[:4]]) for labels in new_labels + ] + new_labels = [ + label if len(label) < 30 else label[:27] + "..." for label in new_labels + ] elif topic_model.custom_labels_ is not None and custom_labels: - new_labels = [topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] for x in fig.layout[axis]["ticktext"]] + new_labels = [ + topic_model.custom_labels_[topics[int(x)] + topic_model._outliers] + for x in fig.layout[axis]["ticktext"] + ] else: - new_labels = [[[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) - for x in fig.layout[axis]["ticktext"]] - new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels] - new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels] + new_labels = [ + [[str(topics[int(x)]), None]] + topic_model.get_topic(topics[int(x)]) + for x in fig.layout[axis]["ticktext"] + ] + new_labels = [ + "_".join([label[0] for label in labels[:4]]) for labels in new_labels + ] + new_labels = [ + label if len(label) < 30 else label[:27] + "..." 
for label in new_labels + ] # Stylize layout fig.update_layout( - plot_bgcolor='#ECEFF1', + plot_bgcolor="#ECEFF1", template="plotly_white", title={ - 'text': f"{title}", - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), ) # Stylize orientation if orientation == "left": - fig.update_layout(height=200 + (15 * len(topics)), - width=width, - yaxis=dict(tickmode="array", - ticktext=new_labels)) + fig.update_layout( + height=200 + (15 * len(topics)), + width=width, + yaxis=dict(tickmode="array", ticktext=new_labels), + ) # Fix empty space on the bottom of the graph - y_max = max([trace['y'].max() + 5 for trace in fig['data']]) - y_min = min([trace['y'].min() - 5 for trace in fig['data']]) + y_max = max([trace["y"].max() + 5 for trace in fig["data"]]) + y_min = min([trace["y"].min() - 5 for trace in fig["data"]]) fig.update_layout(yaxis=dict(range=[y_min, y_max])) else: - fig.update_layout(width=200 + (15 * len(topics)), - height=height, - xaxis=dict(tickmode="array", - ticktext=new_labels)) + fig.update_layout( + width=200 + (15 * len(topics)), + height=height, + xaxis=dict(tickmode="array", ticktext=new_labels), + ) if hierarchical_topics is not None: for index in [0, 3]: axis = "x" if orientation == "left" else "y" - xs = [data["x"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] - ys = [data["y"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] - hovertext = [data["text"][index] for data in fig.data if (data["text"] and data[axis][index] > 0)] - - fig.add_trace(go.Scatter(x=xs, y=ys, marker_color='black', - hovertext=hovertext, hoverinfo="text", - mode='markers', showlegend=False)) + xs = [ + data["x"][index] + for data in fig.data + if (data["text"] and data[axis][index] > 0) + ] + ys = [ + data["y"][index] + for data in fig.data + if (data["text"] and data[axis][index] > 0) + ] + hovertext = [ + data["text"][index] + for data in fig.data + if (data["text"] and data[axis][index] > 0) + ] + + fig.add_trace( + go.Scatter( + x=xs, + y=ys, + marker_color="black", + hovertext=hovertext, + hoverinfo="text", + mode="markers", + showlegend=False, + ) + ) return fig -def _get_annotations(topic_model, - hierarchical_topics: pd.DataFrame, - embeddings: csr_matrix, - linkage_function: Callable[[csr_matrix], np.ndarray], - distance_function: Callable[[csr_matrix], csr_matrix], - orientation: str, - custom_labels: bool = False) -> List[List[str]]: +def _get_annotations( + topic_model, + hierarchical_topics: pd.DataFrame, + embeddings: csr_matrix, + linkage_function: Callable[[csr_matrix], np.ndarray], + distance_function: Callable[[csr_matrix], csr_matrix], + orientation: str, + custom_labels: bool = False, +) -> List[List[str]]: + """Get annotations by replicating linkage function calculation in scipy. - """ Get annotations by replicating linkage function calculation in scipy - - Arguments + Arguments: topic_model: A fitted BERTopic instance. hierarchical_topics: A dataframe that contains a hierarchy of topics represented by their parents and their children. @@ -237,10 +276,10 @@ def _get_annotations(topic_model, in `topic_model.hierarchical_topics`. distance_function: The distance function to use on the c-TF-IDF matrix. 
Default is: `lambda x: 1 - cosine_similarity(x)`. - You can pass any function that returns either a square matrix of - shape (n_samples, n_samples) with zeros on the diagonal and - non-negative values or condensed distance matrix of shape - (n_samples * (n_samples - 1) / 2,) containing the upper + You can pass any function that returns either a square matrix of + shape (n_samples, n_samples) with zeros on the diagonal and + non-negative values or condensed distance matrix of shape + (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the distance matrix. NOTE: Make sure to use the same `distance_function` as used in `topic_model.hierarchical_topics`. @@ -265,8 +304,8 @@ def _get_annotations(topic_model, P = sch.dendrogram(Z, orientation=orientation, no_plot=True) # store topic no.(leaves) corresponding to the x-ticks in dendrogram - x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10) - x_topic = dict(zip(P['leaves'], x_ticks)) + x_ticks = np.arange(5, len(P["leaves"]) * 10 + 5, 10) + x_topic = dict(zip(P["leaves"], x_ticks)) topic_vals = dict() for key, val in x_topic.items(): @@ -276,17 +315,25 @@ def _get_annotations(topic_model, # loop through every trace (scatter plot) in dendrogram text_annotations = [] - for index, trace in enumerate(P['icoord']): + for index, trace in enumerate(P["icoord"]): fst_topic = topic_vals[trace[0]] scnd_topic = topic_vals[trace[2]] if len(fst_topic) == 1: if isinstance(custom_labels, str): - fst_name = f"{fst_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][fst_topic[0]]))[0][:3]) + fst_name = f"{fst_topic[0]}_" + "_".join( + list(zip(*topic_model.topic_aspects_[custom_labels][fst_topic[0]]))[ + 0 + ][:3] + ) elif topic_model.custom_labels_ is not None and custom_labels: - fst_name = topic_model.custom_labels_[fst_topic[0] + topic_model._outliers] + fst_name = topic_model.custom_labels_[ + fst_topic[0] + topic_model._outliers + ] else: - fst_name = "_".join([word for word, _ in topic_model.get_topic(fst_topic[0])][:5]) + fst_name = "_".join( + [word for word, _ in topic_model.get_topic(fst_topic[0])][:5] + ) else: for key, value in parent_topic.items(): if set(value) == set(fst_topic): @@ -294,11 +341,19 @@ def _get_annotations(topic_model, if len(scnd_topic) == 1: if isinstance(custom_labels, str): - scnd_name = f"{scnd_topic[0]}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][scnd_topic[0]]))[0][:3]) + scnd_name = f"{scnd_topic[0]}_" + "_".join( + list( + zip(*topic_model.topic_aspects_[custom_labels][scnd_topic[0]]) + )[0][:3] + ) elif topic_model.custom_labels_ is not None and custom_labels: - scnd_name = topic_model.custom_labels_[scnd_topic[0] + topic_model._outliers] + scnd_name = topic_model.custom_labels_[ + scnd_topic[0] + topic_model._outliers + ] else: - scnd_name = "_".join([word for word, _ in topic_model.get_topic(scnd_topic[0])][:5]) + scnd_name = "_".join( + [word for word, _ in topic_model.get_topic(scnd_topic[0])][:5] + ) else: for key, value in parent_topic.items(): if set(value) == set(scnd_topic): diff --git a/bertopic/plotting/_term_rank.py b/bertopic/plotting/_term_rank.py index a02ab220..5dc98a23 100644 --- a/bertopic/plotting/_term_rank.py +++ b/bertopic/plotting/_term_rank.py @@ -3,14 +3,16 @@ import plotly.graph_objects as go -def visualize_term_rank(topic_model, - topics: List[int] = None, - log_scale: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Term score decline per Topic", - width: int = 800, - height: int = 500) -> go.Figure: - """ 
Visualize the ranks of all terms across all topics +def visualize_term_rank( + topic_model, + topics: List[int] = None, + log_scale: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Term score decline per Topic", + width: int = 800, + height: int = 500, +) -> go.Figure: + """Visualize the ranks of all terms across all topics. Each topic is represented by a set of words. These words, however, do not all equally represent the topic. This visualization shows @@ -22,7 +24,7 @@ def visualize_term_rank(topic_model, topics: A selection of topics to visualize. These will be colored red where all others will be colored black. log_scale: Whether to represent the ranking on a log scale - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -33,7 +35,6 @@ def visualize_term_rank(topic_model, fig: A plotly figure Examples: - To visualize the ranks of all words across all topics simply run: @@ -62,42 +63,49 @@ def visualize_term_rank(topic_model, Reference to that specific analysis can be found [here](https://wzbsocialsciencecenter.github.io/tm_corona/tm_analysis.html). """ - topics = [] if topics is None else topics topic_ids = topic_model.get_topic_info().Topic.unique().tolist() topic_words = [topic_model.get_topic(topic) for topic in topic_ids] values = np.array([[value[1] for value in values] for values in topic_words]) - indices = np.array([[value + 1 for value in range(len(values))] for values in topic_words]) + indices = np.array( + [[value + 1 for value in range(len(values))] for values in topic_words] + ) # Create figure lines = [] for topic, x, y in zip(topic_ids, indices, values): if not any(y > 1.5): - # labels if isinstance(custom_labels, str): - label = f"{topic}_" + "_".join(list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3]) + label = f"{topic}_" + "_".join( + list(zip(*topic_model.topic_aspects_[custom_labels][topic]))[0][:3] + ) elif topic_model.custom_labels_ is not None and custom_labels: label = topic_model.custom_labels_[topic + topic_model._outliers] else: - label = f"Topic {topic}:" + "_".join([word[0] for word in topic_model.get_topic(topic)]) + label = f"Topic {topic}:" + "_".join( + [word[0] for word in topic_model.get_topic(topic)] + ) label = label[:50] # line parameters color = "red" if topic in topics else "black" - opacity = 1 if topic in topics else .1 + opacity = 1 if topic in topics else 0.1 if any(y == 0): y[y == 0] = min(values[values > 0]) y = np.log10(y, out=y, where=y > 0) if log_scale else y - line = go.Scatter(x=x, y=y, - name="", - hovertext=label, - mode="lines+lines", - opacity=opacity, - line=dict(color=color, width=1.5)) + line = go.Scatter( + x=x, + y=y, + name="", + hovertext=label, + mode="lines+lines", + opacity=opacity, + line=dict(color=color, width=1.5), + ) lines.append(line) fig = go.Figure(data=lines) @@ -108,28 +116,22 @@ def visualize_term_rank(topic_model, showlegend=False, template="plotly_white", title={ - 'text': f"{title}", - 'y': .9, - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.9, + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" 
- ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), ) - fig.update_xaxes(title_text='Term Rank') + fig.update_xaxes(title_text="Term Rank") if log_scale: - fig.update_yaxes(title_text='c-TF-IDF score (log scale)') + fig.update_yaxes(title_text="c-TF-IDF score (log scale)") else: - fig.update_yaxes(title_text='c-TF-IDF score') + fig.update_yaxes(title_text="c-TF-IDF score") return fig diff --git a/bertopic/plotting/_topics.py b/bertopic/plotting/_topics.py index ff81a603..8a14a34d 100644 --- a/bertopic/plotting/_topics.py +++ b/bertopic/plotting/_topics.py @@ -8,15 +8,17 @@ import plotly.graph_objects as go -def visualize_topics(topic_model, - topics: List[int] = None, - top_n_topics: int = None, - use_ctfidf: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Intertopic Distance Map", - width: int = 650, - height: int = 650) -> go.Figure: - """ Visualize topics, their sizes, and their corresponding words +def visualize_topics( + topic_model, + topics: List[int] = None, + top_n_topics: int = None, + use_ctfidf: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Intertopic Distance Map", + width: int = 650, + height: int = 650, +) -> go.Figure: + """Visualize topics, their sizes, and their corresponding words. This visualization is highly inspired by LDAvis, a great visualization technique typically reserved for LDA. @@ -26,7 +28,7 @@ def visualize_topics(topic_model, topics: A selection of topics to visualize top_n_topics: Only select the top n most frequent topics use_ctfidf: Whether to use c-TF-IDF representations instead of the embeddings from the embedding model. - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -34,7 +36,6 @@ def visualize_topics(topic_model, height: The height of the figure. Examples: - To visualize the topics simply run: ```python @@ -64,90 +65,131 @@ def visualize_topics(topic_model, topic_list = sorted(topics) frequencies = [topic_model.topic_sizes_[topic] for topic in topic_list] if isinstance(custom_labels, str): - words = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topic_list] + words = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in topic_list + ] words = ["_".join([label[0] for label in labels[:4]]) for labels in words] words = [label if len(label) < 30 else label[:27] + "..." 
for label in words] elif custom_labels and topic_model.custom_labels_ is not None: - words = [topic_model.custom_labels_[topic + topic_model._outliers] for topic in topic_list] + words = [ + topic_model.custom_labels_[topic + topic_model._outliers] + for topic in topic_list + ] else: - words = [" | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) for topic in topic_list] + words = [ + " | ".join([word[0] for word in topic_model.get_topic(topic)[:5]]) + for topic in topic_list + ] # Embed c-TF-IDF into 2D all_topics = sorted(list(topic_model.get_topics().keys())) indices = np.array([all_topics.index(topic) for topic in topics]) embeddings, c_tfidf_used = select_topic_representation( - topic_model.c_tf_idf_, topic_model.topic_embeddings_, use_ctfidf=use_ctfidf, output_ndarray=True, + topic_model.c_tf_idf_, + topic_model.topic_embeddings_, + use_ctfidf=use_ctfidf, + output_ndarray=True, ) embeddings = embeddings[indices] if c_tfidf_used: embeddings = MinMaxScaler().fit_transform(embeddings) - embeddings = UMAP(n_neighbors=2, n_components=2, metric='hellinger', random_state=42).fit_transform(embeddings) + embeddings = UMAP( + n_neighbors=2, n_components=2, metric="hellinger", random_state=42 + ).fit_transform(embeddings) else: - embeddings = UMAP(n_neighbors=2, n_components=2, metric='cosine', random_state=42).fit_transform(embeddings) - + embeddings = UMAP( + n_neighbors=2, n_components=2, metric="cosine", random_state=42 + ).fit_transform(embeddings) # Visualize with plotly - df = pd.DataFrame({"x": embeddings[:, 0], "y": embeddings[:, 1], - "Topic": topic_list, "Words": words, "Size": frequencies}) + df = pd.DataFrame( + { + "x": embeddings[:, 0], + "y": embeddings[:, 1], + "Topic": topic_list, + "Words": words, + "Size": frequencies, + } + ) return _plotly_topic_visualization(df, topic_list, title, width, height) -def _plotly_topic_visualization(df: pd.DataFrame, - topic_list: List[str], - title: str, - width: int, - height: int): - """ Create plotly-based visualization of topics with a slider for topic selection """ +def _plotly_topic_visualization( + df: pd.DataFrame, topic_list: List[str], title: str, width: int, height: int +): + """Create plotly-based visualization of topics with a slider for topic selection.""" def get_color(topic_selected): if topic_selected == -1: marker_color = ["#B0BEC5" for _ in topic_list] else: - marker_color = ["red" if topic == topic_selected else "#B0BEC5" for topic in topic_list] - return [{'marker.color': [marker_color]}] + marker_color = [ + "red" if topic == topic_selected else "#B0BEC5" for topic in topic_list + ] + return [{"marker.color": [marker_color]}] # Prepare figure range - x_range = (df.x.min() - abs((df.x.min()) * .15), df.x.max() + abs((df.x.max()) * .15)) - y_range = (df.y.min() - abs((df.y.min()) * .15), df.y.max() + abs((df.y.max()) * .15)) + x_range = ( + df.x.min() - abs((df.x.min()) * 0.15), + df.x.max() + abs((df.x.max()) * 0.15), + ) + y_range = ( + df.y.min() - abs((df.y.min()) * 0.15), + df.y.max() + abs((df.y.max()) * 0.15), + ) # Plot topics - fig = px.scatter(df, x="x", y="y", size="Size", size_max=40, template="simple_white", labels={"x": "", "y": ""}, - hover_data={"Topic": True, "Words": True, "Size": True, "x": False, "y": False}) - fig.update_traces(marker=dict(color="#B0BEC5", line=dict(width=2, color='DarkSlateGrey'))) + fig = px.scatter( + df, + x="x", + y="y", + size="Size", + size_max=40, + template="simple_white", + labels={"x": "", "y": ""}, + hover_data={"Topic": True, "Words": True, "Size": True, "x": 
False, "y": False}, + ) + fig.update_traces( + marker=dict(color="#B0BEC5", line=dict(width=2, color="DarkSlateGrey")) + ) # Update hover order - fig.update_traces(hovertemplate="
".join(["Topic %{customdata[0]}", - "%{customdata[1]}", - "Size: %{customdata[2]}"])) + fig.update_traces( + hovertemplate="
".join( + [ + "Topic %{customdata[0]}", + "%{customdata[1]}", + "Size: %{customdata[2]}", + ] + ) + ) # Create a slider for topic selection - steps = [dict(label=f"Topic {topic}", method="update", args=get_color(topic)) for topic in topic_list] + steps = [ + dict(label=f"Topic {topic}", method="update", args=get_color(topic)) + for topic in topic_list + ] sliders = [dict(active=0, pad={"t": 50}, steps=steps)] # Stylize layout fig.update_layout( title={ - 'text': f"{title}", - 'y': .95, - 'x': 0.5, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.95, + "x": 0.5, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), xaxis={"visible": False}, yaxis={"visible": False}, - sliders=sliders + sliders=sliders, ) # Update axes ranges @@ -155,14 +197,28 @@ def get_color(topic_selected): fig.update_yaxes(range=y_range) # Add grid in a 'plus' shape - fig.add_shape(type="line", - x0=sum(x_range) / 2, y0=y_range[0], x1=sum(x_range) / 2, y1=y_range[1], - line=dict(color="#CFD8DC", width=2)) - fig.add_shape(type="line", - x0=x_range[0], y0=sum(y_range) / 2, x1=x_range[1], y1=sum(y_range) / 2, - line=dict(color="#9E9E9E", width=2)) - fig.add_annotation(x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10) - fig.add_annotation(y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10) + fig.add_shape( + type="line", + x0=sum(x_range) / 2, + y0=y_range[0], + x1=sum(x_range) / 2, + y1=y_range[1], + line=dict(color="#CFD8DC", width=2), + ) + fig.add_shape( + type="line", + x0=x_range[0], + y0=sum(y_range) / 2, + x1=x_range[1], + y1=sum(y_range) / 2, + line=dict(color="#9E9E9E", width=2), + ) + fig.add_annotation( + x=x_range[0], y=sum(y_range) / 2, text="D1", showarrow=False, yshift=10 + ) + fig.add_annotation( + y=y_range[1], x=sum(x_range) / 2, text="D2", showarrow=False, xshift=10 + ) fig.data = fig.data[::-1] return fig diff --git a/bertopic/plotting/_topics_over_time.py b/bertopic/plotting/_topics_over_time.py index 534486d7..625a8cce 100644 --- a/bertopic/plotting/_topics_over_time.py +++ b/bertopic/plotting/_topics_over_time.py @@ -4,16 +4,18 @@ from sklearn.preprocessing import normalize -def visualize_topics_over_time(topic_model, - topics_over_time: pd.DataFrame, - top_n_topics: int = None, - topics: List[int] = None, - normalize_frequency: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Topics over Time", - width: int = 1250, - height: int = 450) -> go.Figure: - """ Visualize topics over time +def visualize_topics_over_time( + topic_model, + topics_over_time: pd.DataFrame, + top_n_topics: int = None, + topics: List[int] = None, + normalize_frequency: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Topics over Time", + width: int = 1250, + height: int = 450, +) -> go.Figure: + """Visualize topics over time. Arguments: topic_model: A fitted BERTopic instance. 
@@ -22,7 +24,7 @@ def visualize_topics_over_time(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -33,7 +35,6 @@ def visualize_topics_over_time(topic_model, A plotly.graph_objects.Figure including all traces Examples: - To visualize the topics over time, simply run: ```python @@ -50,7 +51,15 @@ def visualize_topics_over_time(topic_model, """ - colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"] + colors = [ + "#E69F00", + "#56B4E9", + "#009E73", + "#F0E442", + "#D55E00", + "#0072B2", + "#CC79A7", + ] # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -64,17 +73,34 @@ def visualize_topics_over_time(topic_model, # Prepare data if isinstance(custom_labels, str): - topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics] - topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names] - topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names] - topic_names = {key: topic_names[index] for index, key in enumerate(topic_model.topic_labels_.keys())} + topic_names = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in topics + ] + topic_names = [ + "_".join([label[0] for label in labels[:4]]) for labels in topic_names + ] + topic_names = [ + label if len(label) < 30 else label[:27] + "..." for label in topic_names + ] + topic_names = { + key: topic_names[index] + for index, key in enumerate(topic_model.topic_labels_.keys()) + } elif topic_model.custom_labels_ is not None and custom_labels: - topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()} + topic_names = { + key: topic_model.custom_labels_[key + topic_model._outliers] + for key, _ in topic_model.topic_labels_.items() + } else: - topic_names = {key: value[:40] + "..." if len(value) > 40 else value - for key, value in topic_model.topic_labels_.items()} + topic_names = { + key: value[:40] + "..." if len(value) > 40 else value + for key, value in topic_model.topic_labels_.items() + } topics_over_time["Name"] = topics_over_time.Topic.map(topic_names) - data = topics_over_time.loc[topics_over_time.Topic.isin(selected_topics), :].sort_values(["Topic", "Timestamp"]) + data = topics_over_time.loc[ + topics_over_time.Topic.isin(selected_topics), : + ].sort_values(["Topic", "Timestamp"]) # Add traces fig = go.Figure() @@ -86,12 +112,17 @@ def visualize_topics_over_time(topic_model, y = normalize(trace_data.Frequency.values.reshape(1, -1))[0] else: y = trace_data.Frequency - fig.add_trace(go.Scatter(x=trace_data.Timestamp, y=y, - mode='lines', - marker_color=colors[index % 7], - hoverinfo="text", - name=topic_name, - hovertext=[f'Topic {topic}
Words: {word}' for word in words])) + fig.add_trace( + go.Scatter( + x=trace_data.Timestamp, + y=y, + mode="lines", + marker_color=colors[index % 7], + hoverinfo="text", + name=topic_name, + hovertext=[f"Topic {topic}
Words: {word}" for word in words], + ) + ) # Styling of the visualization fig.update_xaxes(showgrid=True) @@ -99,25 +130,19 @@ def visualize_topics_over_time(topic_model, fig.update_layout( yaxis_title="Normalized Frequency" if normalize_frequency else "Frequency", title={ - 'text': f"{title}", - 'y': .95, - 'x': 0.40, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.95, + "x": 0.40, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, template="simple_white", width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), legend=dict( title="Global Topic Representation", - ) + ), ) return fig diff --git a/bertopic/plotting/_topics_per_class.py b/bertopic/plotting/_topics_per_class.py index e6a59fd3..5bb8cef4 100644 --- a/bertopic/plotting/_topics_per_class.py +++ b/bertopic/plotting/_topics_per_class.py @@ -4,16 +4,18 @@ from sklearn.preprocessing import normalize -def visualize_topics_per_class(topic_model, - topics_per_class: pd.DataFrame, - top_n_topics: int = 10, - topics: List[int] = None, - normalize_frequency: bool = False, - custom_labels: Union[bool, str] = False, - title: str = "Topics per Class", - width: int = 1250, - height: int = 900) -> go.Figure: - """ Visualize topics per class +def visualize_topics_per_class( + topic_model, + topics_per_class: pd.DataFrame, + top_n_topics: int = 10, + topics: List[int] = None, + normalize_frequency: bool = False, + custom_labels: Union[bool, str] = False, + title: str = "Topics per Class", + width: int = 1250, + height: int = 900, +) -> go.Figure: + """Visualize topics per class. Arguments: topic_model: A fitted BERTopic instance. @@ -22,7 +24,7 @@ def visualize_topics_per_class(topic_model, top_n_topics: To visualize the most frequent topics instead of all topics: Select which topics you would like to be visualized normalize_frequency: Whether to normalize each topic's frequency individually - custom_labels: If bool, whether to use custom topic labels that were defined using + custom_labels: If bool, whether to use custom topic labels that were defined using `topic_model.set_topic_labels`. If `str`, it uses labels from other aspects, e.g., "Aspect1". title: Title of the plot. @@ -33,7 +35,6 @@ def visualize_topics_per_class(topic_model, A plotly.graph_objects.Figure including all traces Examples: - To visualize the topics per class, simply run: ```python @@ -50,7 +51,15 @@ def visualize_topics_per_class(topic_model, """ - colors = ["#E69F00", "#56B4E9", "#009E73", "#F0E442", "#D55E00", "#0072B2", "#CC79A7"] + colors = [ + "#E69F00", + "#56B4E9", + "#009E73", + "#F0E442", + "#D55E00", + "#0072B2", + "#CC79A7", + ] # Select topics based on top_n and topics args freq_df = topic_model.get_topic_freq() @@ -64,15 +73,30 @@ def visualize_topics_per_class(topic_model, # Prepare data if isinstance(custom_labels, str): - topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics] - topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names] - topic_names = [label if len(label) < 30 else label[:27] + "..." 
for label in topic_names] - topic_names = {key: topic_names[index] for index, key in enumerate(topic_model.topic_labels_.keys())} + topic_names = [ + [[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] + for topic in topics + ] + topic_names = [ + "_".join([label[0] for label in labels[:4]]) for labels in topic_names + ] + topic_names = [ + label if len(label) < 30 else label[:27] + "..." for label in topic_names + ] + topic_names = { + key: topic_names[index] + for index, key in enumerate(topic_model.topic_labels_.keys()) + } elif topic_model.custom_labels_ is not None and custom_labels: - topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()} + topic_names = { + key: topic_model.custom_labels_[key + topic_model._outliers] + for key, _ in topic_model.topic_labels_.items() + } else: - topic_names = {key: value[:40] + "..." if len(value) > 40 else value - for key, value in topic_model.topic_labels_.items()} + topic_names = { + key: value[:40] + "..." if len(value) > 40 else value + for key, value in topic_model.topic_labels_.items() + } topics_per_class["Name"] = topics_per_class.Topic.map(topic_names) data = topics_per_class.loc[topics_per_class.Topic.isin(selected_topics), :] @@ -90,14 +114,18 @@ def visualize_topics_per_class(topic_model, x = normalize(trace_data.Frequency.values.reshape(1, -1))[0] else: x = trace_data.Frequency - fig.add_trace(go.Bar(y=trace_data.Class, - x=x, - visible=visible, - marker_color=colors[index % 7], - hoverinfo="text", - name=topic_name, - orientation="h", - hovertext=[f'Topic {topic}
Words: {word}' for word in words])) + fig.add_trace( + go.Bar( + y=trace_data.Class, + x=x, + visible=visible, + marker_color=colors[index % 7], + hoverinfo="text", + name=topic_name, + orientation="h", + hovertext=[f"Topic {topic}
Words: {word}" for word in words], + ) + ) # Styling of the visualization fig.update_xaxes(showgrid=True) @@ -106,25 +134,19 @@ def visualize_topics_per_class(topic_model, xaxis_title="Normalized Frequency" if normalize_frequency else "Frequency", yaxis_title="Class", title={ - 'text': f"{title}", - 'y': .95, - 'x': 0.40, - 'xanchor': 'center', - 'yanchor': 'top', - 'font': dict( - size=22, - color="Black") + "text": f"{title}", + "y": 0.95, + "x": 0.40, + "xanchor": "center", + "yanchor": "top", + "font": dict(size=22, color="Black"), }, template="simple_white", width=width, height=height, - hoverlabel=dict( - bgcolor="white", - font_size=16, - font_family="Rockwell" - ), + hoverlabel=dict(bgcolor="white", font_size=16, font_family="Rockwell"), legend=dict( title="Global Topic Representation", - ) + ), ) return fig diff --git a/bertopic/representation/__init__.py b/bertopic/representation/__init__.py index 3fd5ce63..3c18305f 100644 --- a/bertopic/representation/__init__.py +++ b/bertopic/representation/__init__.py @@ -24,7 +24,9 @@ from bertopic.representation._zeroshot import ZeroShotClassification except ModuleNotFoundError: msg = "`pip install bertopic` without `--no-deps` \n\n" - ZeroShotClassification = NotInstalled("ZeroShotClassification", "transformers", custom_msg=msg) + ZeroShotClassification = NotInstalled( + "ZeroShotClassification", "transformers", custom_msg=msg + ) # OpenAI Generator try: @@ -64,5 +66,5 @@ "OpenAI", "LangChain", "LlamaCPP", - "VisualRepresentation" + "VisualRepresentation", ] diff --git a/bertopic/representation/_base.py b/bertopic/representation/_base.py index cf3dcf75..63feeda9 100644 --- a/bertopic/representation/_base.py +++ b/bertopic/representation/_base.py @@ -5,14 +5,16 @@ class BaseRepresentation(BaseEstimator): - """ The base representation model for fine-tuning topic representations """ - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + """The base representation model for fine-tuning topic representations.""" + + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. Each representation model that inherits this class will have its arguments (topic_model, documents, c_tf_idf, topics) diff --git a/bertopic/representation/_cohere.py b/bertopic/representation/_cohere.py index dd082f0e..64511daf 100644 --- a/bertopic/representation/_cohere.py +++ b/bertopic/representation/_cohere.py @@ -37,7 +37,7 @@ class Cohere(BaseRepresentation): - """ Use the Cohere API to generate topic labels based on their + """Use the Cohere API to generate topic labels based on their generative model. Find more about their models here: @@ -51,19 +51,19 @@ class Cohere(BaseRepresentation): NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt to decide where the keywords and documents need to be inserted. - delay_in_seconds: The delay in seconds between consecutive prompts - in order to prevent RateLimitErrors. + delay_in_seconds: The delay in seconds between consecutive prompts + in order to prevent RateLimitErrors. nr_docs: The number of documents to pass to OpenAI if a prompt with the `["DOCUMENTS"]` tag is used. diversity: The diversity of documents to pass to OpenAI. - Accepts values between 0 and 1. A higher + Accepts values between 0 and 1. 
A higher values results in passing more diverse documents whereas lower values passes more similar documents. doc_length: The maximum length of each document. If a document is longer, it will be truncated. If None, the entire document is passed. tokenizer: The tokenizer used to calculate to split the document into segments - used to count the length of a document. - * If tokenizer is 'char', then the document is split up + used to count the length of a document. + * If tokenizer is 'char', then the document is split up into characters which are counted to adhere to `doc_length` * If tokenizer is 'whitespace', the document is split up into words separated by whitespaces. These words are counted @@ -103,16 +103,18 @@ class Cohere(BaseRepresentation): representation_model = Cohere(co, prompt=prompt) ``` """ - def __init__(self, - client, - model: str = "xlarge", - prompt: str = None, - delay_in_seconds: float = None, - nr_docs: int = 4, - diversity: float = None, - doc_length: int = None, - tokenizer: Union[str, Callable] = None - ): + + def __init__( + self, + client, + model: str = "xlarge", + prompt: str = None, + delay_in_seconds: float = None, + nr_docs: int = 4, + diversity: float = None, + doc_length: int = None, + tokenizer: Union[str, Callable] = None, + ): self.client = client self.model = model self.prompt = prompt if prompt is not None else DEFAULT_PROMPT @@ -124,13 +126,14 @@ def __init__(self, self.tokenizer = tokenizer self.prompts_ = [] - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. 
Arguments: topic_model: Not used @@ -142,12 +145,19 @@ def extract_topics(self, updated_topics: Updated topic representations """ # Extract the top 4 representative documents per topic - repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity) + repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity + ) # Generate using Cohere's Language Model updated_topics = {} - for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] + for topic, docs in tqdm( + repr_docs_mappings.items(), disable=not topic_model.verbose + ): + truncated_docs = [ + truncate_document(topic_model, self.doc_length, self.tokenizer, doc) + for doc in docs + ] prompt = self._create_prompt(truncated_docs, topic, topics) self.prompts_.append(prompt) @@ -155,11 +165,13 @@ def extract_topics(self, if self.delay_in_seconds: time.sleep(self.delay_in_seconds) - request = self.client.generate(model=self.model, - prompt=prompt, - max_tokens=50, - num_generations=1, - stop_sequences=["\n"]) + request = self.client.generate( + model=self.model, + prompt=prompt, + max_tokens=50, + num_generations=1, + stop_sequences=["\n"], + ) label = request.generations[0].text.strip() updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)] diff --git a/bertopic/representation/_keybert.py b/bertopic/representation/_keybert.py index c75501f7..7d9d19e2 100644 --- a/bertopic/representation/_keybert.py +++ b/bertopic/representation/_keybert.py @@ -10,13 +10,15 @@ class KeyBERTInspired(BaseRepresentation): - def __init__(self, - top_n_words: int = 10, - nr_repr_docs: int = 5, - nr_samples: int = 500, - nr_candidate_words: int = 100, - random_state: int = 42): - """ Use a KeyBERT-like model to fine-tune the topic representations + def __init__( + self, + top_n_words: int = 10, + nr_repr_docs: int = 5, + nr_samples: int = 500, + nr_candidate_words: int = 100, + random_state: int = 42, + ): + """Use a KeyBERT-like model to fine-tune the topic representations. The algorithm follows KeyBERT but does some optimization in order to speed up inference. @@ -63,13 +65,14 @@ def __init__(self, self.nr_candidate_words = nr_candidate_words self.random_state = random_state - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. 
Arguments: topic_model: A BERTopic model @@ -81,26 +84,33 @@ def extract_topics(self, updated_topics: Updated topic representations """ # We extract the top n representative documents per class - _, representative_docs, repr_doc_indices, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs) + _, representative_docs, repr_doc_indices, _ = ( + topic_model._extract_representative_docs( + c_tf_idf, documents, topics, self.nr_samples, self.nr_repr_docs + ) + ) # We extract the top n words per class topics = self._extract_candidate_words(topic_model, c_tf_idf, topics) # We calculate the similarity between word and document embeddings and create # topic embeddings from the representative document embeddings - sim_matrix, words = self._extract_embeddings(topic_model, topics, representative_docs, repr_doc_indices) + sim_matrix, words = self._extract_embeddings( + topic_model, topics, representative_docs, repr_doc_indices + ) # Find the best matching words based on the similarity matrix for each topic updated_topics = self._extract_top_words(words, topics, sim_matrix) return updated_topics - def _extract_candidate_words(self, - topic_model, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ For each topic, extract candidate words based on the c-TF-IDF + def _extract_candidate_words( + self, + topic_model, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """For each topic, extract candidate words based on the c-TF-IDF representation. Arguments: @@ -127,23 +137,30 @@ def _extract_candidate_words(self, scores = np.take_along_axis(scores, sorted_indices, axis=1) # Get top 30 words per topic based on c-TF-IDF score - topics = {label: [(words[word_index], score) - if word_index is not None and score > 0 - else ("", 0.00001) - for word_index, score in zip(indices[index][::-1], scores[index][::-1]) - ] - for index, label in enumerate(labels)} - topics = {label: list(zip(*values[:self.nr_candidate_words]))[0] for label, values in topics.items()} + topics = { + label: [ + (words[word_index], score) + if word_index is not None and score > 0 + else ("", 0.00001) + for word_index, score in zip(indices[index][::-1], scores[index][::-1]) + ] + for index, label in enumerate(labels) + } + topics = { + label: list(zip(*values[: self.nr_candidate_words]))[0] + for label, values in topics.items() + } return topics - def _extract_embeddings(self, - topic_model, - topics: Mapping[str, List[Tuple[str, float]]], - representative_docs: List[str], - repr_doc_indices: List[List[int]] - ) -> Union[np.ndarray, List[str]]: - """ Extract the representative document embeddings and create topic embeddings. + def _extract_embeddings( + self, + topic_model, + topics: Mapping[str, List[Tuple[str, float]]], + representative_docs: List[str], + repr_doc_indices: List[List[int]], + ) -> Union[np.ndarray, List[str]]: + """Extract the representative document embeddings and create topic embeddings. Then extract word embeddings and calculate the cosine similarity between topic embeddings and the word embeddings. Topic embeddings are the average of representative document embeddings. 
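The `_extract_embeddings` step described above averages the embeddings of a topic's representative documents into a topic embedding and then ranks candidate words by cosine similarity against it. A minimal standalone sketch of that idea (not the library's internal code; the sentence-transformers model name and the toy documents are assumptions):

```python
# Sketch of the KeyBERT-inspired scoring step: topic embedding = mean of the
# representative-document embeddings, candidate words ranked by cosine similarity.
# Assumes sentence-transformers is installed; model name and texts are examples only.
import numpy as np
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

model = SentenceTransformer("all-MiniLM-L6-v2")

representative_docs = [
    "The GPU drivers crashed after the latest update.",
    "Reinstalling the graphics card drivers fixed the issue.",
]
candidate_words = ["gpu", "drivers", "update", "weather", "graphics"]

doc_embeddings = model.encode(representative_docs)
topic_embedding = doc_embeddings.mean(axis=0, keepdims=True)  # average of repr. docs
word_embeddings = model.encode(candidate_words)

# Cosine similarity between the topic embedding and every candidate word
scores = cosine_similarity(topic_embedding, word_embeddings)[0]
print([candidate_words[i] for i in np.argsort(scores)[::-1]])
```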
@@ -160,22 +177,29 @@ def _extract_embeddings(self, vocab: The complete vocabulary of input documents """ # Calculate representative docs embeddings and create topic embeddings - repr_embeddings = topic_model._extract_embeddings(representative_docs, method="document", verbose=False) - topic_embeddings = [np.mean(repr_embeddings[i[0]:i[-1]+1], axis=0) for i in repr_doc_indices] + repr_embeddings = topic_model._extract_embeddings( + representative_docs, method="document", verbose=False + ) + topic_embeddings = [ + np.mean(repr_embeddings[i[0] : i[-1] + 1], axis=0) for i in repr_doc_indices + ] # Calculate word embeddings and extract best matching with updated topic_embeddings vocab = list(set([word for words in topics.values() for word in words])) - word_embeddings = topic_model._extract_embeddings(vocab, method="document", verbose=False) + word_embeddings = topic_model._extract_embeddings( + vocab, method="document", verbose=False + ) sim = cosine_similarity(topic_embeddings, word_embeddings) return sim, vocab - def _extract_top_words(self, - vocab: List[str], - topics: Mapping[str, List[Tuple[str, float]]], - sim: np.ndarray - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract the top n words per topic based on the + def _extract_top_words( + self, + vocab: List[str], + topics: Mapping[str, List[Tuple[str, float]]], + sim: np.ndarray, + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract the top n words per topic based on the similarity matrix between topics and words. Arguments: @@ -192,7 +216,14 @@ def _extract_top_words(self, for i, topic in enumerate(labels): indices = [vocab.index(word) for word in topics[topic]] values = sim[:, indices][i] - word_indices = [indices[index] for index in np.argsort(values)[-self.top_n_words:]] - updated_topics[topic] = [(vocab[index], val) for val, index in zip(np.sort(values)[-self.top_n_words:], word_indices)][::-1] + word_indices = [ + indices[index] for index in np.argsort(values)[-self.top_n_words :] + ] + updated_topics[topic] = [ + (vocab[index], val) + for val, index in zip( + np.sort(values)[-self.top_n_words :], word_indices + ) + ][::-1] return updated_topics diff --git a/bertopic/representation/_langchain.py b/bertopic/representation/_langchain.py index b0c68439..ad92aef1 100644 --- a/bertopic/representation/_langchain.py +++ b/bertopic/representation/_langchain.py @@ -1,7 +1,7 @@ import pandas as pd from langchain.docstore.document import Document from scipy.sparse import csr_matrix -from typing import Callable, Dict, Mapping, List, Tuple, Union +from typing import Callable, Mapping, List, Tuple, Union from bertopic.representation._base import BaseRepresentation from bertopic.representation._utils import truncate_document @@ -10,7 +10,7 @@ class LangChain(BaseRepresentation): - """ Using chains in langchain to generate topic labels. + """Using chains in langchain to generate topic labels. The classic example uses `langchain.chains.question_answering.load_qa_chain`. This returns a chain that takes a list of documents and a question as input. @@ -32,21 +32,21 @@ class LangChain(BaseRepresentation): formats the representative documents within the prompt. nr_docs: The number of documents to pass to LangChain diversity: The diversity of documents to pass to LangChain. - Accepts values between 0 and 1. A higher + Accepts values between 0 and 1. A higher values results in passing more diverse documents whereas lower values passes more similar documents. doc_length: The maximum length of each document. 
If a document is longer, it will be truncated. If None, the entire document is passed. tokenizer: The tokenizer used to calculate to split the document into segments - used to count the length of a document. - * If tokenizer is 'char', then the document is split up + used to count the length of a document. + * If tokenizer is 'char', then the document is split up into characters which are counted to adhere to `doc_length` * If tokenizer is 'whitespace', the document is split up into words separated by whitespaces. These words are counted and truncated depending on `doc_length` * If tokenizer is 'vectorizer', then the internal CountVectorizer is used to tokenize the document. These tokens are counted - and truncated depending on `doc_length`. They are decoded with + and truncated depending on `doc_length`. They are decoded with whitespaces. * If tokenizer is a callable, then that callable is used to tokenize the document. These tokens are counted and truncated depending @@ -129,15 +129,17 @@ class LangChain(BaseRepresentation): representation_model = LangChain(chain, prompt=representation_prompt) ``` """ - def __init__(self, - chain, - prompt: str = None, - nr_docs: int = 4, - diversity: float = None, - doc_length: int = None, - tokenizer: Union[str, Callable] = None, - chain_config = None, - ): + + def __init__( + self, + chain, + prompt: str = None, + nr_docs: int = 4, + diversity: float = None, + doc_length: int = None, + tokenizer: Union[str, Callable] = None, + chain_config=None, + ): self.chain = chain self.prompt = prompt if prompt is not None else DEFAULT_PROMPT self.default_prompt_ = DEFAULT_PROMPT @@ -147,13 +149,14 @@ def __init__(self, self.doc_length = doc_length self.tokenizer = tokenizer - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, int]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, int]]]: + """Extract topics. Arguments: topic_model: A BERTopic model @@ -171,7 +174,7 @@ def extract_topics(self, topics=topics, nr_samples=500, nr_repr_docs=self.nr_docs, - diversity=self.diversity + diversity=self.diversity, ) # Generate label using langchain's batch functionality @@ -179,10 +182,7 @@ def extract_topics(self, [ Document( page_content=truncate_document( - topic_model, - self.doc_length, - self.tokenizer, - doc + topic_model, self.doc_length, self.tokenizer, doc ) ) for doc in docs @@ -203,7 +203,7 @@ def extract_topics(self, {"input_documents": docs, "question": prompt} for docs, prompt in zip(chain_docs, prompts) ] - + else: inputs = [ {"input_documents": docs, "question": self.prompt} diff --git a/bertopic/representation/_llamacpp.py b/bertopic/representation/_llamacpp.py index 7060df80..fa573463 100644 --- a/bertopic/representation/_llamacpp.py +++ b/bertopic/representation/_llamacpp.py @@ -18,11 +18,11 @@ class LlamaCPP(BaseRepresentation): - """ A llama.cpp implementation to use as a representation model. + """A llama.cpp implementation to use as a representation model. Arguments: - model: Either a string pointing towards a local LLM or a - `llama_cpp.Llama` object. + model: Either a string pointing towards a local LLM or a + `llama_cpp.Llama` object. prompt: The prompt to be used in the model. If no prompt is given, `self.default_prompt_` is used instead. 
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt @@ -88,23 +88,27 @@ class LlamaCPP(BaseRepresentation): topic_model = BERTopic(representation_model=representation_model, verbose=True) ``` """ - def __init__(self, - model: Union[str, Llama], - prompt: str = None, - pipeline_kwargs: Mapping[str, Any] = {}, - nr_docs: int = 4, - diversity: float = None, - doc_length: int = None, - tokenizer: Union[str, Callable] = None - ): + + def __init__( + self, + model: Union[str, Llama], + prompt: str = None, + pipeline_kwargs: Mapping[str, Any] = {}, + nr_docs: int = 4, + diversity: float = None, + doc_length: int = None, + tokenizer: Union[str, Callable] = None, + ): if isinstance(model, str): self.model = Llama(model_path=model, n_gpu_layers=-1, stop="Q:") elif isinstance(model, Llama): self.model = model else: - raise ValueError("Make sure that the model that you" - "pass is either a string referring to a" - "local LLM or a ` llama_cpp.Llama` object.") + raise ValueError( + "Make sure that the model that you" + "pass is either a string referring to a" + "local LLM or a ` llama_cpp.Llama` object." + ) self.prompt = prompt if prompt is not None else DEFAULT_PROMPT self.default_prompt_ = DEFAULT_PROMPT self.pipeline_kwargs = pipeline_kwargs @@ -115,13 +119,14 @@ def __init__(self, self.prompts_ = [] - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topic representations and return a single label + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topic representations and return a single label. 
Arguments: topic_model: A BERTopic model @@ -134,28 +139,32 @@ def extract_topics(self, """ # Extract the top 4 representative documents per topic repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( - c_tf_idf, - documents, - topics, - 500, - self.nr_docs, - self.diversity + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity ) updated_topics = {} - for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - + for topic, docs in tqdm( + repr_docs_mappings.items(), disable=not topic_model.verbose + ): # Prepare prompt - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] + truncated_docs = [ + truncate_document(topic_model, self.doc_length, self.tokenizer, doc) + for doc in docs + ] prompt = self._create_prompt(truncated_docs, topic, topics) self.prompts_.append(prompt) # Extract result from generator and use that as label - topic_description = self.model(prompt, **self.pipeline_kwargs)['choices'] - topic_description = [(description["text"].replace(prompt, ""), 1) for description in topic_description] + topic_description = self.model(prompt, **self.pipeline_kwargs)["choices"] + topic_description = [ + (description["text"].replace(prompt, ""), 1) + for description in topic_description + ] if len(topic_description) < 10: - topic_description += [("", 0) for _ in range(10-len(topic_description))] + topic_description += [ + ("", 0) for _ in range(10 - len(topic_description)) + ] updated_topics[topic] = topic_description diff --git a/bertopic/representation/_mmr.py b/bertopic/representation/_mmr.py index 4d0686a9..07a8dd13 100644 --- a/bertopic/representation/_mmr.py +++ b/bertopic/representation/_mmr.py @@ -8,7 +8,7 @@ class MaximalMarginalRelevance(BaseRepresentation): - """ Calculate Maximal Marginal Relevance (MMR) + """Calculate Maximal Marginal Relevance (MMR) between candidate keywords and the document. MMR considers the similarity of keywords/keyphrases with the @@ -35,17 +35,19 @@ class MaximalMarginalRelevance(BaseRepresentation): topic_model = BERTopic(representation_model=representation_model) ``` """ + def __init__(self, diversity: float = 0.1, top_n_words: int = 10): self.diversity = diversity self.top_n_words = top_n_words - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topic representations + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topic representations. Arguments: topic_model: The BERTopic model @@ -56,41 +58,55 @@ def extract_topics(self, Returns: updated_topics: Updated topic representations """ - if topic_model.embedding_model is None: - warnings.warn("MaximalMarginalRelevance can only be used BERTopic was instantiated" - "with the `embedding_model` parameter.") + warnings.warn( + "MaximalMarginalRelevance can only be used BERTopic was instantiated" + "with the `embedding_model` parameter." 
+ ) return topics updated_topics = {} for topic, topic_words in topics.items(): words = [word[0] for word in topic_words] - word_embeddings = topic_model._extract_embeddings(words, method="word", verbose=False) - topic_embedding = topic_model._extract_embeddings(" ".join(words), method="word", verbose=False).reshape(1, -1) - topic_words = mmr(topic_embedding, word_embeddings, words, self.diversity, self.top_n_words) - updated_topics[topic] = [(word, value) for word, value in topics[topic] if word in topic_words] + word_embeddings = topic_model._extract_embeddings( + words, method="word", verbose=False + ) + topic_embedding = topic_model._extract_embeddings( + " ".join(words), method="word", verbose=False + ).reshape(1, -1) + topic_words = mmr( + topic_embedding, + word_embeddings, + words, + self.diversity, + self.top_n_words, + ) + updated_topics[topic] = [ + (word, value) for word, value in topics[topic] if word in topic_words + ] return updated_topics -def mmr(doc_embedding: np.ndarray, - word_embeddings: np.ndarray, - words: List[str], - diversity: float = 0.1, - top_n: int = 10) -> List[str]: - """ Maximal Marginal Relevance +def mmr( + doc_embedding: np.ndarray, + word_embeddings: np.ndarray, + words: List[str], + diversity: float = 0.1, + top_n: int = 10, +) -> List[str]: + """Maximal Marginal Relevance. Arguments: doc_embedding: The document embeddings word_embeddings: The embeddings of the selected candidate keywords/phrases words: The selected candidate keywords/keyphrases - diversity: The diversity of the selected embeddings. + diversity: The diversity of the selected embeddings. Values between 0 and 1. top_n: The top n items to return Returns: List[str]: The selected keywords/keyphrases """ - # Extract similarity within words, and between words and the document word_doc_similarity = cosine_similarity(word_embeddings, doc_embedding) word_similarity = cosine_similarity(word_embeddings) @@ -103,10 +119,14 @@ def mmr(doc_embedding: np.ndarray, # Extract similarities within candidates and # between candidates and selected keywords/phrases candidate_similarities = word_doc_similarity[candidates_idx, :] - target_similarities = np.max(word_similarity[candidates_idx][:, keywords_idx], axis=1) + target_similarities = np.max( + word_similarity[candidates_idx][:, keywords_idx], axis=1 + ) # Calculate MMR - mmr = (1-diversity) * candidate_similarities - diversity * target_similarities.reshape(-1, 1) + mmr = ( + 1 - diversity + ) * candidate_similarities - diversity * target_similarities.reshape(-1, 1) mmr_idx = candidates_idx[np.argmax(mmr)] # Update keywords & candidates diff --git a/bertopic/representation/_openai.py b/bertopic/representation/_openai.py index 1a0c2b59..35bdf1da 100644 --- a/bertopic/representation/_openai.py +++ b/bertopic/representation/_openai.py @@ -5,7 +5,10 @@ from scipy.sparse import csr_matrix from typing import Mapping, List, Tuple, Any, Union, Callable from bertopic.representation._base import BaseRepresentation -from bertopic.representation._utils import retry_with_exponential_backoff, truncate_document +from bertopic.representation._utils import ( + retry_with_exponential_backoff, + truncate_document, +) DEFAULT_PROMPT = """ @@ -47,7 +50,7 @@ class OpenAI(BaseRepresentation): - """ Using the OpenAI API to generate topic labels based + r"""Using the OpenAI API to generate topic labels based on one of their Completion of ChatCompletion models. The default method is `openai.Completion` if `chat=False`. 
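For the `mmr` helper reformatted just above: each remaining candidate is scored as `(1 - diversity) * similarity_to_document - diversity * max_similarity_to_already_selected`, and the highest-scoring candidate is moved into the keyword set. A small illustrative sketch of that selection loop, using random vectors in place of real embeddings:

```python
# Illustrative sketch of Maximal Marginal Relevance keyword selection.
# Random vectors stand in for real embeddings; the scoring rule mirrors
# mmr = (1 - diversity) * relevance - diversity * max similarity to selected words.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(42)
words = ["gpu", "graphics", "driver", "update", "kernel"]
word_embeddings = rng.normal(size=(len(words), 8))
doc_embedding = rng.normal(size=(1, 8))

word_doc_sim = cosine_similarity(word_embeddings, doc_embedding)  # relevance
word_word_sim = cosine_similarity(word_embeddings)                # redundancy
diversity, top_n = 0.3, 3

selected = [int(np.argmax(word_doc_sim))]
candidates = [i for i in range(len(words)) if i not in selected]
while len(selected) < top_n:
    relevance = word_doc_sim[candidates, 0]
    redundancy = word_word_sim[np.ix_(candidates, selected)].max(axis=1)
    scores = (1 - diversity) * relevance - diversity * redundancy
    best = candidates[int(np.argmax(scores))]
    selected.append(best)
    candidates.remove(best)

print([words[i] for i in selected])
```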
@@ -135,19 +138,21 @@ class OpenAI(BaseRepresentation): representation_model = OpenAI(client, model="gpt-3.5-turbo", delay_in_seconds=10, chat=True) ``` """ - def __init__(self, - client, - model: str = "text-embedding-3-small", - prompt: str = None, - generator_kwargs: Mapping[str, Any] = {}, - delay_in_seconds: float = None, - exponential_backoff: bool = False, - chat: bool = False, - nr_docs: int = 4, - diversity: float = None, - doc_length: int = None, - tokenizer: Union[str, Callable] = None - ): + + def __init__( + self, + client, + model: str = "text-embedding-3-small", + prompt: str = None, + generator_kwargs: Mapping[str, Any] = {}, + delay_in_seconds: float = None, + exponential_backoff: bool = False, + chat: bool = False, + nr_docs: int = 4, + diversity: float = None, + doc_length: int = None, + tokenizer: Union[str, Callable] = None, + ): self.client = client self.model = model @@ -175,13 +180,14 @@ def __init__(self, if not self.generator_kwargs.get("stop") and not chat: self.generator_kwargs["stop"] = "\n" - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. Arguments: topic_model: A BERTopic model @@ -193,12 +199,19 @@ def extract_topics(self, updated_topics: Updated topic representations """ # Extract the top n representative documents per topic - repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity) + repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity + ) # Generate using OpenAI's Language Model updated_topics = {} - for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] + for topic, docs in tqdm( + repr_docs_mappings.items(), disable=not topic_model.verbose + ): + truncated_docs = [ + truncate_document(topic_model, self.doc_length, self.tokenizer, doc) + for doc in docs + ] prompt = self._create_prompt(truncated_docs, topic, topics) self.prompts_.append(prompt) @@ -209,9 +222,13 @@ def extract_topics(self, if self.chat: messages = [ {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} + {"role": "user", "content": prompt}, ] - kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} + kwargs = { + "model": self.model, + "messages": messages, + **self.generator_kwargs, + } if self.exponential_backoff: response = chat_completions_with_backoff(self.client, **kwargs) else: @@ -220,14 +237,25 @@ def extract_topics(self, # Check whether content was actually generated # Addresses #1570 for potential issues with OpenAI's content filter if hasattr(response.choices[0].message, "content"): - label = response.choices[0].message.content.strip().replace("topic: ", "") + label = ( + response.choices[0] + .message.content.strip() + .replace("topic: ", "") + ) else: label = "No label returned" else: if self.exponential_backoff: - response = completions_with_backoff(self.client, model=self.model, prompt=prompt, **self.generator_kwargs) + response = 
completions_with_backoff( + self.client, + model=self.model, + prompt=prompt, + **self.generator_kwargs, + ) else: - response = self.client.completions.create(model=self.model, prompt=prompt, **self.generator_kwargs) + response = self.client.completions.create( + model=self.model, prompt=prompt, **self.generator_kwargs + ) label = response.choices[0].text.strip() updated_topics[topic] = [(label, 1)] @@ -265,16 +293,12 @@ def _replace_documents(prompt, docs): def completions_with_backoff(client, **kwargs): return retry_with_exponential_backoff( client.completions.create, - errors=( - openai.RateLimitError, - ), + errors=(openai.RateLimitError,), )(**kwargs) def chat_completions_with_backoff(client, **kwargs): return retry_with_exponential_backoff( client.chat.completions.create, - errors=( - openai.RateLimitError, - ), + errors=(openai.RateLimitError,), )(**kwargs) diff --git a/bertopic/representation/_pos.py b/bertopic/representation/_pos.py index ed3d1c5f..08139b53 100644 --- a/bertopic/representation/_pos.py +++ b/bertopic/representation/_pos.py @@ -1,4 +1,3 @@ - import numpy as np import pandas as pd @@ -14,7 +13,7 @@ class PartOfSpeech(BaseRepresentation): - """ Extract Topic Keywords based on their Part-of-Speech + """Extract Topic Keywords based on their Part-of-Speech. DEFAULT_PATTERNS = [ [{'POS': 'ADJ'}, {'POS': 'NOUN'}], @@ -63,36 +62,43 @@ class PartOfSpeech(BaseRepresentation): representation_model = PartOfSpeech("en_core_web_sm", pos_patterns=pos_patterns) ``` """ - def __init__(self, - model: Union[str, Language] = "en_core_web_sm", - top_n_words: int = 10, - pos_patterns: List[str] = None): + + def __init__( + self, + model: Union[str, Language] = "en_core_web_sm", + top_n_words: int = 10, + pos_patterns: List[str] = None, + ): if isinstance(model, str): self.model = spacy.load(model) elif isinstance(model, Language): self.model = model else: - raise ValueError("Make sure that the Spacy model that you" - "pass is either a string referring to a" - "Spacy model or a Spacy nlp object.") + raise ValueError( + "Make sure that the Spacy model that you" + "pass is either a string referring to a" + "Spacy model or a Spacy nlp object." + ) self.top_n_words = top_n_words if pos_patterns is None: self.pos_patterns = [ - [{'POS': 'ADJ'}, {'POS': 'NOUN'}], - [{'POS': 'NOUN'}], [{'POS': 'ADJ'}] + [{"POS": "ADJ"}, {"POS": "NOUN"}], + [{"POS": "NOUN"}], + [{"POS": "ADJ"}], ] else: self.pos_patterns = pos_patterns - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. 
Arguments: topic_model: A BERTopic model @@ -114,7 +120,9 @@ def extract_topics(self, candidate_documents = [] for keyword in keywords: selection = documents.loc[documents.Topic == topic, :] - selection = selection.loc[selection.Document.str.contains(keyword), "Document"] + selection = selection.loc[ + selection.Document.str.contains(keyword), "Document" + ] if len(selection) > 0: for document in selection[:2]: candidate_documents.append(document) @@ -141,13 +149,28 @@ def extract_topics(self, updated_topics = {topic: [] for topic in topics.keys()} for topic, candidate_keywords in candidate_topics.items(): - word_indices = np.sort([words_lookup.get(keyword) for keyword in candidate_keywords if keyword in words_lookup]) + word_indices = np.sort( + [ + words_lookup.get(keyword) + for keyword in candidate_keywords + if keyword in words_lookup + ] + ) vals = topic_model.c_tf_idf_[:, word_indices][topic + topic_model._outliers] - indices = np.argsort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words:][::-1] - vals = np.sort(np.array(vals.todense().reshape(1, -1))[0])[-self.top_n_words:][::-1] - topic_words = [(words[word_indices[index]], val) for index, val in zip(indices, vals)] + indices = np.argsort(np.array(vals.todense().reshape(1, -1))[0])[ + -self.top_n_words : + ][::-1] + vals = np.sort(np.array(vals.todense().reshape(1, -1))[0])[ + -self.top_n_words : + ][::-1] + topic_words = [ + (words[word_indices[index]], val) for index, val in zip(indices, vals) + ] updated_topics[topic] = topic_words if len(updated_topics[topic]) < self.top_n_words: - updated_topics[topic] += [("", 0) for _ in range(self.top_n_words-len(updated_topics[topic]))] + updated_topics[topic] += [ + ("", 0) + for _ in range(self.top_n_words - len(updated_topics[topic])) + ] return updated_topics diff --git a/bertopic/representation/_textgeneration.py b/bertopic/representation/_textgeneration.py index 91cdd399..3bc3853a 100644 --- a/bertopic/representation/_textgeneration.py +++ b/bertopic/representation/_textgeneration.py @@ -15,7 +15,7 @@ class TextGeneration(BaseRepresentation): - """ Text2Text or text generation with transformers + """Text2Text or text generation with transformers. Arguments: model: A transformers pipeline that should be initialized as "text-generation" @@ -81,16 +81,18 @@ class TextGeneration(BaseRepresentation): representation_model = TextGeneration(generator) ``` """ - def __init__(self, - model: Union[str, pipeline], - prompt: str = None, - pipeline_kwargs: Mapping[str, Any] = {}, - random_state: int = 42, - nr_docs: int = 4, - diversity: float = None, - doc_length: int = None, - tokenizer: Union[str, Callable] = None - ): + + def __init__( + self, + model: Union[str, pipeline], + prompt: str = None, + pipeline_kwargs: Mapping[str, Any] = {}, + random_state: int = 42, + nr_docs: int = 4, + diversity: float = None, + doc_length: int = None, + tokenizer: Union[str, Callable] = None, + ): self.random_state = random_state set_seed(random_state) if isinstance(model, str): @@ -98,9 +100,11 @@ def __init__(self, elif isinstance(model, Pipeline): self.model = model else: - raise ValueError("Make sure that the HF model that you" - "pass is either a string referring to a" - "HF model or a `transformers.pipeline` object.") + raise ValueError( + "Make sure that the HF model that you" + "pass is either a string referring to a" + "HF model or a `transformers.pipeline` object." 
+ ) self.prompt = prompt if prompt is not None else DEFAULT_PROMPT self.default_prompt_ = DEFAULT_PROMPT self.pipeline_kwargs = pipeline_kwargs @@ -111,13 +115,14 @@ def __init__(self, self.prompts_ = [] - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topic representations and return a single label + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topic representations and return a single label. Arguments: topic_model: A BERTopic model @@ -131,30 +136,38 @@ def extract_topics(self, # Extract the top 4 representative documents per topic if self.prompt != DEFAULT_PROMPT and "[DOCUMENTS]" in self.prompt: repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( - c_tf_idf, - documents, - topics, - 500, - self.nr_docs, - self.diversity + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity ) else: repr_docs_mappings = {topic: None for topic in topics.keys()} updated_topics = {} - for topic, docs in tqdm(repr_docs_mappings.items(), disable=not topic_model.verbose): - + for topic, docs in tqdm( + repr_docs_mappings.items(), disable=not topic_model.verbose + ): # Prepare prompt - truncated_docs = [truncate_document(topic_model, self.doc_length, self.tokenizer, doc) for doc in docs] if docs is not None else docs + truncated_docs = ( + [ + truncate_document(topic_model, self.doc_length, self.tokenizer, doc) + for doc in docs + ] + if docs is not None + else docs + ) prompt = self._create_prompt(truncated_docs, topic, topics) self.prompts_.append(prompt) # Extract result from generator and use that as label topic_description = self.model(prompt, **self.pipeline_kwargs) - topic_description = [(description["generated_text"].replace(prompt, ""), 1) for description in topic_description] + topic_description = [ + (description["generated_text"].replace(prompt, ""), 1) + for description in topic_description + ] if len(topic_description) < 10: - topic_description += [("", 0) for _ in range(10-len(topic_description))] + topic_description += [ + ("", 0) for _ in range(10 - len(topic_description)) + ] updated_topics[topic] = topic_description diff --git a/bertopic/representation/_utils.py b/bertopic/representation/_utils.py index 75599e18..00f157a5 100644 --- a/bertopic/representation/_utils.py +++ b/bertopic/representation/_utils.py @@ -3,7 +3,7 @@ def truncate_document(topic_model, doc_length, tokenizer, document: str): - """ Truncate a document to a certain length + """Truncate a document to a certain length. If you want to add a custom tokenizer, then it will need to have a `decode` and `encode` method. An example would be the following custom tokenizer: @@ -25,15 +25,15 @@ def decode(self, doc_chunks): doc_length: The maximum length of each document. If a document is longer, it will be truncated. If None, the entire document is passed. tokenizer: The tokenizer used to calculate to split the document into segments - used to count the length of a document. - * If tokenizer is 'char', then the document is split up + used to count the length of a document. + * If tokenizer is 'char', then the document is split up into characters which are counted to adhere to `doc_length` * If tokenizer is 'whitespace', the document is split up into words separated by whitespaces. 
These words are counted and truncated depending on `doc_length` * If tokenizer is 'vectorizer', then the internal CountVectorizer is used to tokenize the document. These tokens are counted - and truncated depending on `doc_length`. They are decoded with + and truncated depending on `doc_length`. They are decoded with whitespaces. * If tokenizer is a callable, then that callable is used to tokenize the document. These tokens are counted and truncated depending @@ -51,7 +51,7 @@ def decode(self, doc_chunks): elif tokenizer == "vectorizer": tokenizer = topic_model.vectorizer_model.build_tokenizer() truncated_document = " ".join(tokenizer(document)[:doc_length]) - elif hasattr(tokenizer, 'encode') and hasattr(tokenizer, 'decode'): + elif hasattr(tokenizer, "encode") and hasattr(tokenizer, "decode"): encoded_document = tokenizer.encode(document) truncated_document = tokenizer.decode(encoded_document[:doc_length]) return truncated_document @@ -67,36 +67,36 @@ def retry_with_exponential_backoff( errors: tuple = None, ): """Retry a function with exponential backoff.""" - + def wrapper(*args, **kwargs): # Initialize variables num_retries = 0 delay = initial_delay - + # Loop until a successful response or max_retries is hit or an exception is raised while True: try: return func(*args, **kwargs) - + # Retry on specific errors - except errors as e: + except errors: # Increment retries num_retries += 1 - + # Check if max retries has been reached if num_retries > max_retries: raise Exception( f"Maximum number of retries ({max_retries}) exceeded." ) - + # Increment the delay delay *= exponential_base * (1 + jitter * random.random()) - + # Sleep for the delay time.sleep(delay) - + # Raise exceptions for any errors not specified except Exception as e: raise e - - return wrapper \ No newline at end of file + + return wrapper diff --git a/bertopic/representation/_visual.py b/bertopic/representation/_visual.py index eab90a8a..897d7c9d 100644 --- a/bertopic/representation/_visual.py +++ b/bertopic/representation/_visual.py @@ -12,10 +12,10 @@ class VisualRepresentation(BaseRepresentation): - """ From a collection of representative documents, extract + """From a collection of representative documents, extract images to represent topics. These topics are represented by a - collage of images. - + collage of images. + Arguments: nr_repr_images: Number of representative images to extract nr_samples: The number of candidate documents to extract per cluster. @@ -24,8 +24,8 @@ class VisualRepresentation(BaseRepresentation): to a square. This can be visually more appealing if all input images are all almost squares. image_to_text_model: The model to caption images. - batch_size: The number of images to pass to the - `image_to_text_model`. + batch_size: The number of images to pass to the + `image_to_text_model`. 
Usage: @@ -44,13 +44,16 @@ class VisualRepresentation(BaseRepresentation): topic_model = BERTopic(representation_model=representation_model) ``` """ - def __init__(self, - nr_repr_images: int = 9, - nr_samples: int = 500, - image_height: Tuple[int, int] = 600, - image_squares: bool = False, - image_to_text_model: Union[str, Pipeline] = None, - batch_size: int = 32): + + def __init__( + self, + nr_repr_images: int = 9, + nr_samples: int = 500, + image_height: Tuple[int, int] = 600, + image_squares: bool = False, + image_to_text_model: Union[str, Pipeline] = None, + batch_size: int = 32, + ): self.nr_repr_images = nr_repr_images self.nr_samples = nr_samples self.image_height = image_height @@ -60,21 +63,26 @@ def __init__(self, if isinstance(image_to_text_model, Pipeline): self.image_to_text_model = image_to_text_model elif isinstance(image_to_text_model, str): - self.image_to_text_model = pipeline("image-to-text", model=image_to_text_model) + self.image_to_text_model = pipeline( + "image-to-text", model=image_to_text_model + ) elif image_to_text_model is None: self.image_to_text_model = None else: - raise ValueError("Please select a correct transformers pipeline. For example:" - "pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')") + raise ValueError( + "Please select a correct transformers pipeline. For example:" + "pipeline('image-to-text', model='nlpconnect/vit-gpt2-image-captioning')" + ) self.batch_size = batch_size - def extract_topics(self, - topic_model, - documents: pd.DataFrame, - c_tf_idf: csr_matrix, - topics: Mapping[str, List[Tuple[str, float]]] - ) -> Mapping[str, List[Tuple[str, float]]]: - """ Extract topics + def extract_topics( + self, + topic_model, + documents: pd.DataFrame, + c_tf_idf: csr_matrix, + topics: Mapping[str, List[Tuple[str, float]]], + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. 
Arguments: topic_model: A BERTopic model @@ -87,32 +95,37 @@ def extract_topics(self, """ # Extract image ids of most representative documents images = documents["Image"].values.tolist() - (_, _, _, - repr_docs_ids) = topic_model._extract_representative_docs(c_tf_idf, - documents, - topics, - nr_samples=self.nr_samples, - nr_repr_docs=self.nr_repr_images) + (_, _, _, repr_docs_ids) = topic_model._extract_representative_docs( + c_tf_idf, + documents, + topics, + nr_samples=self.nr_samples, + nr_repr_docs=self.nr_repr_images, + ) unique_topics = sorted(list(topics.keys())) # Combine representative images into a single representation representative_images = {} for topic in tqdm(unique_topics): - # Get and order represetnative images - sliced_examplars = repr_docs_ids[topic+topic_model._outliers] - sliced_examplars = [sliced_examplars[i:i + 3] for i in - range(0, len(sliced_examplars), 3)] + sliced_examplars = repr_docs_ids[topic + topic_model._outliers] + sliced_examplars = [ + sliced_examplars[i : i + 3] for i in range(0, len(sliced_examplars), 3) + ] images_to_combine = [ - [Image.open(images[index]) if isinstance(images[index], str) - else images[index] for index in sub_indices] + [ + Image.open(images[index]) + if isinstance(images[index], str) + else images[index] + for index in sub_indices + ] for sub_indices in sliced_examplars ] # Concatenate representative images - representative_image = get_concat_tile_resize(images_to_combine, - self.image_height, - self.image_squares) + representative_image = get_concat_tile_resize( + images_to_combine, self.image_height, self.image_squares + ) representative_images[topic] = representative_image # Make sure to properly close images @@ -120,13 +133,13 @@ def extract_topics(self, for image_list in images_to_combine: for image in image_list: image.close() - + return representative_images - - def _convert_image_to_text(self, - images: List[str], - verbose: bool = False) -> List[str]: - """ Convert a list of images to captions. + + def _convert_image_to_text( + self, images: List[str], verbose: bool = False + ) -> List[str]: + """Convert a list of images to captions. Arguments: images: A list of images or words to be converted to text. 
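Regarding the captioning step above: `_convert_image_to_text` feeds batches of images through a Hugging Face `image-to-text` pipeline and keeps each result's `generated_text`. A rough usage sketch along the same lines (the model name is taken from the error message earlier in this file; the image path is a placeholder):

```python
# Rough sketch of captioning images the way VisualRepresentation does internally.
# The model name is only an example and "topic_image.jpg" is a placeholder path.
from transformers import pipeline

captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

outputs = captioner(["topic_image.jpg"])  # accepts file paths or PIL images
captions = [output[0]["generated_text"] for output in outputs]
print(captions)
```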
@@ -149,35 +162,53 @@ def _convert_image_to_text(self, documents = [output[0]["generated_text"] for output in outputs] return documents - - def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.DataFrame: - """ Convert images to text """ + + def image_to_text( + self, documents: pd.DataFrame, embeddings: np.ndarray + ) -> pd.DataFrame: + """Convert images to text.""" # Create image topic embeddings topics = documents.Topic.values.tolist() images = documents.Image.values.tolist() df = pd.DataFrame(np.hstack([np.array(topics).reshape(-1, 1), embeddings])) image_topic_embeddings = df.groupby(0).mean().values - + # Extract image centroids image_centroids = {} unique_topics = sorted(list(set(topics))) for topic, topic_embedding in zip(unique_topics, image_topic_embeddings): indices = np.array([index for index, t in enumerate(topics) if t == topic]) top_n = min([self.nr_repr_images, len(indices)]) - indices = mmr(topic_embedding.reshape(1, -1), embeddings[indices], indices, top_n=top_n, diversity=0.1) + indices = mmr( + topic_embedding.reshape(1, -1), + embeddings[indices], + indices, + top_n=top_n, + diversity=0.1, + ) image_centroids[topic] = indices - + # Extract documents documents = pd.DataFrame(columns=["Document", "ID", "Topic", "Image"]) current_id = 0 for topic, image_ids in tqdm(image_centroids.items()): - selected_images = [Image.open(images[index]) if isinstance(images[index], str) else images[index] for index in image_ids] + selected_images = [ + Image.open(images[index]) + if isinstance(images[index], str) + else images[index] + for index in image_ids + ] text = self._convert_image_to_text(selected_images) - + for doc, image_id in zip(text, image_ids): - documents.loc[len(documents), :] = [doc, current_id, topic, images[image_id]] + documents.loc[len(documents), :] = [ + doc, + current_id, + topic, + images[image_id], + ] current_id += 1 - + # Properly close images if isinstance(images[image_ids[0]], str): for image in selected_images: @@ -185,15 +216,13 @@ def image_to_text(self, documents: pd.DataFrame, embeddings: np.ndarray) -> pd.D return documents - def _chunks(self, images): + def _chunks(self, images): for i in range(0, len(images), self.batch_size): - yield images[i:i + self.batch_size] + yield images[i : i + self.batch_size] def get_concat_h_multi_resize(im_list): - """ - Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/ - """ + """Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.""" min_height = min(im.height for im in im_list) min_height = max(im.height for im in im_list) im_list_resize = [] @@ -202,7 +231,7 @@ def get_concat_h_multi_resize(im_list): im_list_resize.append(im) total_width = sum(im.width for im in im_list_resize) - dst = Image.new('RGB', (total_width, min_height), (255, 255, 255)) + dst = Image.new("RGB", (total_width, min_height), (255, 255, 255)) pos_x = 0 for im in im_list_resize: dst.paste(im, (pos_x, 0)) @@ -211,15 +240,15 @@ def get_concat_h_multi_resize(im_list): def get_concat_v_multi_resize(im_list): - """ - Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/ - """ + """Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.""" min_width = min(im.width for im in im_list) min_width = max(im.width for im in im_list) - im_list_resize = [im.resize((min_width, int(im.height * min_width / im.width)), resample=0) - for im in im_list] + im_list_resize = [ + im.resize((min_width, int(im.height * min_width / im.width)), resample=0) + for im in im_list 
+ ] total_height = sum(im.height for im in im_list_resize) - dst = Image.new('RGB', (min_width, total_height), (255, 255, 255)) + dst = Image.new("RGB", (min_width, total_height), (255, 255, 255)) pos_y = 0 for im in im_list_resize: dst.paste(im, (0, pos_y)) @@ -228,17 +257,17 @@ def get_concat_v_multi_resize(im_list): def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False): - """ - Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/ - """ + """Code adapted from: https://note.nkmk.me/en/python-pillow-concat-images/.""" images = [[image.copy() for image in images] for images in im_list_2d] - - # Create + + # Create if image_squares: - width = int(image_height / 3) + width = int(image_height / 3) height = int(image_height / 3) - images = [[image.resize((width, height)) for image in images] for images in im_list_2d] - + images = [ + [image.resize((width, height)) for image in images] for images in im_list_2d + ] + # Resize images based on minimum size else: min_width = min([min([img.width for img in imgs]) for imgs in im_list_2d]) @@ -246,17 +275,22 @@ def get_concat_tile_resize(im_list_2d, image_height=600, image_squares=False): for i, imgs in enumerate(images): for j, img in enumerate(imgs): if img.height > img.width: - images[i][j] = img.resize((int(img.width * min_height / img.height), min_height), resample=0) + images[i][j] = img.resize( + (int(img.width * min_height / img.height), min_height), + resample=0, + ) elif img.width > img.height: - images[i][j] = img.resize((min_width, int(img.height * min_width / img.width)), resample=0) + images[i][j] = img.resize( + (min_width, int(img.height * min_width / img.width)), resample=0 + ) else: images[i][j] = img.resize((min_width, min_width)) # Resize grid image images = [get_concat_h_multi_resize(im_list_h) for im_list_h in images] img = get_concat_v_multi_resize(images) - height_percentage = (image_height/float(img.size[1])) - adjusted_width = int((float(img.size[0])*float(height_percentage))) + height_percentage = image_height / float(img.size[1]) + adjusted_width = int((float(img.size[0]) * float(height_percentage))) img = img.resize((adjusted_width, image_height), Image.Resampling.LANCZOS) - + return img diff --git a/bertopic/representation/_zeroshot.py b/bertopic/representation/_zeroshot.py index eddd3fd8..7dff499b 100644 --- a/bertopic/representation/_zeroshot.py +++ b/bertopic/representation/_zeroshot.py @@ -7,7 +7,7 @@ class ZeroShotClassification(BaseRepresentation): - """ Zero-shot Classification on topic keywords with candidate labels + """Zero-shot Classification on topic keywords with candidate labels. 

     Arguments:
         candidate_topics: A list of labels to assign to the topics if they
@@ -34,31 +34,36 @@ class ZeroShotClassification(BaseRepresentation):
         topic_model = BERTopic(representation_model=representation_model)
         ```
     """
-    def __init__(self,
-                 candidate_topics: List[str],
-                 model: str = "facebook/bart-large-mnli",
-                 pipeline_kwargs: Mapping[str, Any] = {},
-                 min_prob: float = 0.8
-                 ):
+
+    def __init__(
+        self,
+        candidate_topics: List[str],
+        model: str = "facebook/bart-large-mnli",
+        pipeline_kwargs: Mapping[str, Any] = {},
+        min_prob: float = 0.8,
+    ):
         self.candidate_topics = candidate_topics
         if isinstance(model, str):
             self.model = pipeline("zero-shot-classification", model=model)
         elif isinstance(model, Pipeline):
             self.model = model
         else:
-            raise ValueError("Make sure that the HF model that you"
-                             "pass is either a string referring to a"
-                             "HF model or a `transformers.pipeline` object.")
+            raise ValueError(
+                "Make sure that the HF model that you"
+                "pass is either a string referring to a"
+                "HF model or a `transformers.pipeline` object."
+            )
         self.pipeline_kwargs = pipeline_kwargs
         self.min_prob = min_prob

-    def extract_topics(self,
-                       topic_model,
-                       documents: pd.DataFrame,
-                       c_tf_idf: csr_matrix,
-                       topics: Mapping[str, List[Tuple[str, float]]]
-                       ) -> Mapping[str, List[Tuple[str, float]]]:
-        """ Extract topics
+    def extract_topics(
+        self,
+        topic_model,
+        documents: pd.DataFrame,
+        c_tf_idf: csr_matrix,
+        topics: Mapping[str, List[Tuple[str, float]]],
+    ) -> Mapping[str, List[Tuple[str, float]]]:
+        """Extract topics.

         Arguments:
             topic_model: Not used
@@ -70,8 +75,12 @@ def extract_topics(self,
             updated_topics: Updated topic representations
         """
         # Classify topics
-        topic_descriptions = [" ".join(list(zip(*topics[topic]))[0]) for topic in topics.keys()]
-        classifications = self.model(topic_descriptions, self.candidate_topics, **self.pipeline_kwargs)
+        topic_descriptions = [
+            " ".join(list(zip(*topics[topic]))[0]) for topic in topics.keys()
+        ]
+        classifications = self.model(
+            topic_descriptions, self.candidate_topics, **self.pipeline_kwargs
+        )

         # Extract labels
         updated_topics = {}
@@ -81,19 +90,25 @@ def extract_topics(self,
             # Multi-label assignment
             if self.pipeline_kwargs.get("multi_label"):
                 topic_description = []
-                for label, score in zip(classification["labels"], classification["scores"]):
+                for label, score in zip(
+                    classification["labels"], classification["scores"]
+                ):
                     if score > self.min_prob:
                         topic_description.append((label, score))

             # Single label assignment
             elif classification["scores"][0] > self.min_prob:
-                topic_description = [(classification["labels"][0], classification["scores"][0])]
+                topic_description = [
+                    (classification["labels"][0], classification["scores"][0])
+                ]

             # Make sure that 10 items are returned
             if len(topic_description) == 0:
                 topic_description = topics[topic]
             elif len(topic_description) < 10:
-                topic_description += [("", 0) for _ in range(10-len(topic_description))]
+                topic_description += [
+                    ("", 0) for _ in range(10 - len(topic_description))
+                ]
             updated_topics[topic] = topic_description

         return updated_topics
diff --git a/bertopic/vectorizers/__init__.py b/bertopic/vectorizers/__init__.py
index a813a197..af566558 100644
--- a/bertopic/vectorizers/__init__.py
+++ b/bertopic/vectorizers/__init__.py
@@ -1,7 +1,4 @@
 from ._ctfidf import ClassTfidfTransformer
 from ._online_cv import OnlineCountVectorizer

-__all__ = [
-    "ClassTfidfTransformer",
-    "OnlineCountVectorizer"
-]
+__all__ = ["ClassTfidfTransformer", "OnlineCountVectorizer"]
diff --git a/bertopic/vectorizers/_ctfidf.py b/bertopic/vectorizers/_ctfidf.py
index f84d5e33..d2d0e3c6 100644
--- a/bertopic/vectorizers/_ctfidf.py
+++ b/bertopic/vectorizers/_ctfidf.py
@@ -7,8 +7,7 @@


 class ClassTfidfTransformer(TfidfTransformer):
-    """
-    A Class-based TF-IDF procedure using scikit-learns TfidfTransformer as a base.
+    """A Class-based TF-IDF procedure using scikit-learns TfidfTransformer as a base.

     ![](../algorithm/c-TF-IDF.svg)

@@ -27,24 +26,25 @@ class ClassTfidfTransformer(TfidfTransformer):
                         `log(1+((avg_nr_samples - df + 0.5) / (df+0.5)))`
         reduce_frequent_words: Takes the square root of the bag-of-words after normalizing the matrix.
                                Helps to reduce the impact of words that appear too frequently.
-        seed_words: Specific words that will have their idf value increased by 
-                    the value of `seed_multiplier`. 
+        seed_words: Specific words that will have their idf value increased by
+                    the value of `seed_multiplier`.
                     NOTE: This will only increase the value of words that have an exact match.
         seed_multiplier: The value with which the idf values of the words in `seed_words`
                          are multiplied.

     Examples:
-
     ```python
     transformer = ClassTfidfTransformer()
     ```
     """
-    def __init__(self,
-                 bm25_weighting: bool = False,
-                 reduce_frequent_words: bool = False,
-                 seed_words: List[str] = None,
-                 seed_multiplier: float = 2
-                 ):
+
+    def __init__(
+        self,
+        bm25_weighting: bool = False,
+        reduce_frequent_words: bool = False,
+        seed_words: List[str] = None,
+        seed_multiplier: float = 2,
+    ):
         self.bm25_weighting = bm25_weighting
         self.reduce_frequent_words = reduce_frequent_words
         self.seed_words = seed_words
@@ -58,7 +58,7 @@ def fit(self, X: sp.csr_matrix, multiplier: np.ndarray = None):
             X: A matrix of term/token counts.
             multiplier: A multiplier for increasing/decreasing certain IDF scores
         """
-        X = check_array(X, accept_sparse=('csr', 'csc'))
+        X = check_array(X, accept_sparse=("csr", "csc"))
         if not sp.issparse(X):
             X = sp.csr_matrix(X)
         dtype = np.float64
@@ -74,26 +74,29 @@ def fit(self, X: sp.csr_matrix, multiplier: np.ndarray = None):

         # BM25-inspired weighting procedure
         if self.bm25_weighting:
-            idf = np.log(1+((avg_nr_samples - df + 0.5) / (df+0.5)))
+            idf = np.log(1 + ((avg_nr_samples - df + 0.5) / (df + 0.5)))

         # Divide the average number of samples by the word frequency
         # +1 is added to force values to be positive
         else:
-            idf = np.log((avg_nr_samples / df)+1)
+            idf = np.log((avg_nr_samples / df) + 1)

         # Multiplier to increase/decrease certain idf scores
         if multiplier is not None:
             idf = idf * multiplier

-        self._idf_diag = sp.diags(idf, offsets=0,
-                                  shape=(n_features, n_features),
-                                  format='csr',
-                                  dtype=dtype)
+        self._idf_diag = sp.diags(
+            idf,
+            offsets=0,
+            shape=(n_features, n_features),
+            format="csr",
+            dtype=dtype,
+        )

         return self

     def transform(self, X: sp.csr_matrix):
-        """Transform a count-based matrix to c-TF-IDF
+        """Transform a count-based matrix to c-TF-IDF.

         Arguments:
             X (sparse matrix): A matrix of term/token counts.
@@ -102,7 +105,7 @@ def transform(self, X: sp.csr_matrix):
             X (sparse matrix): A c-TF-IDF matrix
         """
         if self.use_idf:
-            X = normalize(X, axis=1, norm='l1', copy=False)
+            X = normalize(X, axis=1, norm="l1", copy=False)

             if self.reduce_frequent_words:
                 X.data = np.sqrt(X.data)
diff --git a/bertopic/vectorizers/_online_cv.py b/bertopic/vectorizers/_online_cv.py
index 5e8de94f..fedb363c 100644
--- a/bertopic/vectorizers/_online_cv.py
+++ b/bertopic/vectorizers/_online_cv.py
@@ -9,7 +9,7 @@


 class OnlineCountVectorizer(CountVectorizer):
-    """ An online variant of the CountVectorizer with updating vocabulary.
+    """An online variant of the CountVectorizer with updating vocabulary.

     At each `.partial_fit`, its vocabulary is updated based on any OOV words
     it might find. Then, `.update_bow` can be used to track and update
@@ -42,7 +42,6 @@ class OnlineCountVectorizer(CountVectorizer):
         X_ (scipy.sparse.csr_matrix) : The Bag-of-Words representation

     Examples:
-
     ```python
     from bertopic.vectorizers import OnlineCountVectorizer
     vectorizer = OnlineCountVectorizer(stop_words="english")
@@ -68,21 +67,19 @@ class OnlineCountVectorizer(CountVectorizer):
     References:
         Adapted from: https://github.com/idoshlomo/online_vectorizers
     """
-    def __init__(self,
-                 decay: float = None,
-                 delete_min_df: float = None,
-                 **kwargs):
+
+    def __init__(self, decay: float = None, delete_min_df: float = None, **kwargs):
         self.decay = decay
         self.delete_min_df = delete_min_df
         super(OnlineCountVectorizer, self).__init__(**kwargs)

     def partial_fit(self, raw_documents: List[str]) -> None:
-        """ Perform a partial fit and update vocabulary with OOV tokens
+        """Perform a partial fit and update vocabulary with OOV tokens.

         Arguments:
             raw_documents: A list of documents
         """
-        if not hasattr(self, 'vocabulary_'):
+        if not hasattr(self, "vocabulary_"):
             return self.fit(raw_documents)

         analyzer = self.build_analyzer()
@@ -92,13 +89,18 @@ def partial_fit(self, raw_documents: List[str]) -> None:

         if oov_tokens:
             max_index = max(self.vocabulary_.values())
-            oov_vocabulary = dict(zip(oov_tokens, list(range(max_index + 1, max_index + 1 + len(oov_tokens), 1))))
+            oov_vocabulary = dict(
+                zip(
+                    oov_tokens,
+                    list(range(max_index + 1, max_index + 1 + len(oov_tokens), 1)),
+                )
+            )
             self.vocabulary_.update(oov_vocabulary)

         return self

     def update_bow(self, raw_documents: List[str]) -> csr_matrix:
-        """ Create or update the bag-of-words matrix
+        """Create or update the bag-of-words matrix.

         Update the bag-of-words matrix by adding the newly transformed
         documents. This may add empty columns if new words are found and/or
@@ -119,11 +121,15 @@ def update_bow(self, raw_documents: List[str]) -> csr_matrix:
             X = self.transform(raw_documents)

             # Add empty columns if new words are found
-            columns = csr_matrix((self.X_.shape[0], X.shape[1] - self.X_.shape[1]), dtype=int)
+            columns = csr_matrix(
+                (self.X_.shape[0], X.shape[1] - self.X_.shape[1]), dtype=int
+            )
             self.X_ = sparse.hstack([self.X_, columns])

             # Add empty rows if new topics are found
-            rows = csr_matrix((X.shape[0] - self.X_.shape[0], self.X_.shape[1]), dtype=int)
+            rows = csr_matrix(
+                (X.shape[0] - self.X_.shape[0], self.X_.shape[1]), dtype=int
+            )
             self.X_ = sparse.vstack([self.X_, rows])

             # Decay of BoW matrix
@@ -140,7 +146,7 @@ def update_bow(self, raw_documents: List[str]) -> csr_matrix:
         return self.X_

     def _clean_bow(self) -> None:
-        """ Remove words that do not exceed `self.delete_min_df` """
+        """Remove words that do not exceed `self.delete_min_df`."""
         # Only keep words with a minimum frequency
         indices = np.where(self.X_.sum(0) >= self.delete_min_df)[1]
         indices_dict = {index: index for index in indices}
diff --git a/pyproject.toml b/pyproject.toml
index 79877153..d0c1abfe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,7 @@ spacy = [
 test = [
     "pytest>=5.4.3",
     "pytest-cov>=2.6.1",
+    "ruff~=0.4.7",
 ]
 use = [
     "tensorflow",
@@ -95,3 +96,28 @@ Repository = "https://github.com/MaartenGr/BERTopic.git"
 [tool.setuptools.packages.find]
 include = ["bertopic*"]
 exclude = ["tests"]
+
+[tool.ruff]
+target-version = "py38"
+
+[tool.ruff.lint]
+select = [
+    "E4",  # Ruff Defaults
+    "E7",
+    "E9",
+    "F",  # End Ruff Defaults,
+    "D"
+]
+
+ignore = [
+    "D100", # Missing docstring in public module
+    "D104", # Missing docstring in public package
+    "D205", # 1 blank line required between summary line and description
+    "E731", # Do not assign a lambda expression, use a def
+]
+
+[tool.ruff.lint.per-file-ignores]
+"**/tests/*" = ["D"] # Ignore all docstring errors in tests
+
+[tool.ruff.lint.pydocstyle]
+convention = "google"
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index 0b7cbb78..610c8690 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,10 +6,9 @@
 from sklearn.datasets import fetch_20newsgroups
 from sentence_transformers import SentenceTransformer
 from sklearn.cluster import KMeans, MiniBatchKMeans
-from sklearn.decomposition import PCA, IncrementalPCA
+from sklearn.decomposition import PCA
 from bertopic.vectorizers import OnlineCountVectorizer
 from bertopic.representation import KeyBERTInspired, MaximalMarginalRelevance
-from bertopic.cluster import BaseCluster
 from bertopic.dimensionality import BaseDimensionalityReduction
 from sklearn.linear_model import LogisticRegression

@@ -28,20 +27,24 @@ def document_embeddings(documents, embedding_model):

 @pytest.fixture(scope="session")
 def reduced_embeddings(document_embeddings):
-    reduced_embeddings = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric='cosine').fit_transform(document_embeddings)
+    reduced_embeddings = UMAP(
+        n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine"
+    ).fit_transform(document_embeddings)
     return reduced_embeddings


 @pytest.fixture(scope="session")
 def documents():
-    newsgroup_docs = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))['data'][:1000]
+    newsgroup_docs = fetch_20newsgroups(
+        subset="all", remove=("headers", "footers", "quotes")
+    )["data"][:1000]
     return newsgroup_docs


 @pytest.fixture(scope="session")
 def targets():
-    data = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
-    y = data['target'][:1000]
+    data = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))
+    y = data["target"][:1000]
     return y


@@ -57,7 +60,12 @@ def base_topic_model(documents, document_embeddings, embedding_model):
 @pytest.fixture(scope="session")
 def zeroshot_topic_model(documents, document_embeddings, embedding_model):
     zeroshot_topic_list = ["religion", "cars", "electronics"]
-    model = BERTopic(embedding_model=embedding_model, calculate_probabilities=True, zeroshot_topic_list=zeroshot_topic_list, zeroshot_min_similarity=0.5)
+    model = BERTopic(
+        embedding_model=embedding_model,
+        calculate_probabilities=True,
+        zeroshot_topic_list=zeroshot_topic_list,
+        zeroshot_min_similarity=0.5,
+    )
     model.umap_model.random_state = 42
     model.hdbscan_model.min_cluster_size = 2
     model.fit(documents, document_embeddings)
@@ -66,20 +74,49 @@

 @pytest.fixture(scope="session")
 def custom_topic_model(documents, document_embeddings, embedding_model):
-    umap_model = UMAP(n_neighbors=15, n_components=6, min_dist=0.0, metric='cosine', random_state=42)
-    hdbscan_model = HDBSCAN(min_cluster_size=3, metric='euclidean', cluster_selection_method='eom', prediction_data=True)
-    model = BERTopic(umap_model=umap_model, hdbscan_model=hdbscan_model, embedding_model=embedding_model, calculate_probabilities=True).fit(documents, document_embeddings)
+    umap_model = UMAP(
+        n_neighbors=15, n_components=6, min_dist=0.0, metric="cosine", random_state=42
+    )
+    hdbscan_model = HDBSCAN(
+        min_cluster_size=3,
+        metric="euclidean",
+        cluster_selection_method="eom",
+        prediction_data=True,
+    )
+    model = BERTopic(
+        umap_model=umap_model,
+        hdbscan_model=hdbscan_model,
+        embedding_model=embedding_model,
+        calculate_probabilities=True,
+    ).fit(documents, document_embeddings)
     return model

+
 @pytest.fixture(scope="session")
 def representation_topic_model(documents, document_embeddings, embedding_model):
-    umap_model = UMAP(n_neighbors=15, n_components=6, min_dist=0.0, metric='cosine', random_state=42)
-    hdbscan_model = HDBSCAN(min_cluster_size=3, metric='euclidean', cluster_selection_method='eom', prediction_data=True)
-    representation_model = {"Main": KeyBERTInspired(), "MMR": [KeyBERTInspired(top_n_words=30), MaximalMarginalRelevance()]}
-    model = BERTopic(umap_model=umap_model, hdbscan_model=hdbscan_model, embedding_model=embedding_model, representation_model=representation_model,
-                     calculate_probabilities=True).fit(documents, document_embeddings)
+    umap_model = UMAP(
+        n_neighbors=15, n_components=6, min_dist=0.0, metric="cosine", random_state=42
+    )
+    hdbscan_model = HDBSCAN(
+        min_cluster_size=3,
+        metric="euclidean",
+        cluster_selection_method="eom",
+        prediction_data=True,
+    )
+    representation_model = {
+        "Main": KeyBERTInspired(),
+        "MMR": [KeyBERTInspired(top_n_words=30), MaximalMarginalRelevance()],
+    }
+    model = BERTopic(
+        umap_model=umap_model,
+        hdbscan_model=hdbscan_model,
+        embedding_model=embedding_model,
+        representation_model=representation_model,
+        calculate_probabilities=True,
+    ).fit(documents, document_embeddings)
     return model

+
 @pytest.fixture(scope="session")
 def reduced_topic_model(custom_topic_model, documents):
     model = copy.deepcopy(custom_topic_model)
@@ -92,8 +129,7 @@ def merged_topic_model(custom_topic_model, documents):
     model = copy.deepcopy(custom_topic_model)

     # Merge once
-    topics_to_merge = [[1, 2],
[3, 4]] model.merge_topics(documents, topics_to_merge) # Merge second time @@ -106,7 +142,11 @@ def merged_topic_model(custom_topic_model, documents): def kmeans_pca_topic_model(documents, document_embeddings): hdbscan_model = KMeans(n_clusters=15, random_state=42) dim_model = PCA(n_components=5) - model = BERTopic(hdbscan_model=hdbscan_model, umap_model=dim_model, embedding_model=embedding_model).fit(documents, document_embeddings) + model = BERTopic( + hdbscan_model=hdbscan_model, + umap_model=dim_model, + embedding_model=embedding_model, + ).fit(documents, document_embeddings) return model @@ -116,9 +156,9 @@ def supervised_topic_model(documents, document_embeddings, embedding_model, targ clf = LogisticRegression() model = BERTopic( - embedding_model=embedding_model, - umap_model=empty_dimensionality_model, - hdbscan_model=clf, + embedding_model=embedding_model, + umap_model=empty_dimensionality_model, + hdbscan_model=clf, ).fit(documents, embeddings=document_embeddings, y=targets) return model @@ -127,16 +167,24 @@ def supervised_topic_model(documents, document_embeddings, embedding_model, targ def online_topic_model(documents, document_embeddings, embedding_model): umap_model = PCA(n_components=5) cluster_model = MiniBatchKMeans(n_clusters=50, random_state=0) - vectorizer_model = OnlineCountVectorizer(stop_words="english", decay=.01) - model = BERTopic(umap_model=umap_model, hdbscan_model=cluster_model, vectorizer_model=vectorizer_model, embedding_model=embedding_model) + vectorizer_model = OnlineCountVectorizer(stop_words="english", decay=0.01) + model = BERTopic( + umap_model=umap_model, + hdbscan_model=cluster_model, + vectorizer_model=vectorizer_model, + embedding_model=embedding_model, + ) topics = [] for index in range(0, len(documents), 50): - model.partial_fit(documents[index: index+50], document_embeddings[index: index+50]) + model.partial_fit( + documents[index : index + 50], document_embeddings[index : index + 50] + ) topics.extend(model.topics_) model.topics_ = topics return model + @pytest.fixture(scope="session") def cuml_base_topic_model(documents, document_embeddings, embedding_model): from cuml.cluster import HDBSCAN as cuml_hdbscan diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index 5d4bfac8..a899680b 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -1,39 +1,48 @@ import copy import pytest from bertopic import BERTopic +import importlib.util + def cuml_available(): try: - import cuml - return True + return importlib.util.find_spec("cuml") is not None except ImportError: return False + @pytest.mark.parametrize( - 'model', + "model", [ ("base_topic_model"), - ('kmeans_pca_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model'), - ('supervised_topic_model'), - ('representation_topic_model'), - ('zeroshot_topic_model'), + ("kmeans_pca_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ("supervised_topic_model"), + ("representation_topic_model"), + ("zeroshot_topic_model"), pytest.param( - "cuml_base_topic_model", marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available") + "cuml_base_topic_model", + marks=pytest.mark.skipif(not cuml_available(), reason="cuML not available"), ), - ]) + ], +) def test_full_model(model, documents, request): - """ Tests the entire pipeline in one go. This serves as a sanity check to see if the default + """Tests the entire pipeline in one go. 
This serves as a sanity check to see if the default settings result in a good separation of topics. NOTE: This does not cover all cases but merely combines it all together """ topic_model = copy.deepcopy(request.getfixturevalue(model)) if model == "base_topic_model": - topic_model.save("model_dir", serialization="pytorch", save_ctfidf=True, save_embedding_model="sentence-transformers/all-MiniLM-L6-v2") + topic_model.save( + "model_dir", + serialization="pytorch", + save_ctfidf=True, + save_embedding_model="sentence-transformers/all-MiniLM-L6-v2", + ) topic_model = BERTopic.load("model_dir") if model == "cuml_base_topic_model": @@ -110,7 +119,9 @@ def test_full_model(model, documents, request): assert topic != original_topic # Test updating topic labels - topic_labels = topic_model.generate_topic_labels(nr_words=3, topic_prefix=False, word_length=10, separator=", ") + topic_labels = topic_model.generate_topic_labels( + nr_words=3, topic_prefix=False, word_length=10, separator=", " + ) assert len(topic_labels) == len(set(topic_model.topics_)) # Test setting topic labels @@ -126,7 +137,9 @@ def test_full_model(model, documents, request): # Test reduction of outliers if -1 in topics: new_topics = topic_model.reduce_outliers(documents, topics, threshold=0.0) - nr_outliers_topic_model = sum([1 for topic in topic_model.topics_ if topic == -1]) + nr_outliers_topic_model = sum( + [1 for topic in topic_model.topics_ if topic == -1] + ) nr_outliers_new_topics = sum([1 for topic in new_topics if topic == -1]) if topic_model._outliers == 1: diff --git a/tests/test_other.py b/tests/test_other.py index 2868ed1e..cdd5fc68 100644 --- a/tests/test_other.py +++ b/tests/test_other.py @@ -19,4 +19,4 @@ def test_get_params(): assert not params["nr_topics"] assert params["n_gram_range"] == (1, 1) assert params["min_topic_size"] == 10 - assert params["language"] == 'english' + assert params["language"] == "english" diff --git a/tests/test_plotting/test_approximate.py b/tests/test_plotting/test_approximate.py index c9a1e6b3..2de86848 100644 --- a/tests/test_plotting/test_approximate.py +++ b/tests/test_plotting/test_approximate.py @@ -1,28 +1,45 @@ import copy import pytest + @pytest.mark.parametrize("batch_size", [50, None]) @pytest.mark.parametrize("padding", [True, False]) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ], +) def test_approximate_distribution(batch_size, padding, model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) - + # Calculate only on a document-level based on tokensets - topic_distr, _ = topic_model.approximate_distribution(documents, padding=padding, batch_size=batch_size) - assert topic_distr.shape[1] == len(topic_model.topic_labels_) - topic_model._outliers + topic_distr, _ = topic_model.approximate_distribution( + documents, padding=padding, batch_size=batch_size + ) + assert ( + topic_distr.shape[1] == len(topic_model.topic_labels_) - topic_model._outliers + ) # Use the distribution visualization for i in range(3): topic_model.visualize_distribution(topic_distr[i]) # Calculate distribution on a token-level - topic_distr, topic_token_distr = topic_model.approximate_distribution(documents[:100], calculate_tokens=True) - assert topic_distr.shape[1] == 
len(topic_model.topic_labels_) - topic_model._outliers + topic_distr, topic_token_distr = topic_model.approximate_distribution( + documents[:100], calculate_tokens=True + ) + assert ( + topic_distr.shape[1] == len(topic_model.topic_labels_) - topic_model._outliers + ) assert len(topic_token_distr) == len(documents[:100]) for token_distr in topic_token_distr: - assert token_distr.shape[1] == len(topic_model.topic_labels_) - topic_model._outliers + assert ( + token_distr.shape[1] + == len(topic_model.topic_labels_) - topic_model._outliers + ) diff --git a/tests/test_plotting/test_bar.py b/tests/test_plotting/test_bar.py index f61524c9..0ca3056c 100644 --- a/tests/test_plotting/test_bar.py +++ b/tests/test_plotting/test_bar.py @@ -2,12 +2,17 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_barchart(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) fig = topic_model.visualize_barchart() @@ -23,12 +28,17 @@ def test_barchart(model, request): assert int(annotation["text"].split(" ")[-1]) != -1 -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_barchart_outlier(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topic_model.topic_sizes_[-1] = 4 diff --git a/tests/test_plotting/test_documents.py b/tests/test_plotting/test_documents.py index 7e1c9f1c..81acbe4c 100644 --- a/tests/test_plotting/test_documents.py +++ b/tests/test_plotting/test_documents.py @@ -2,16 +2,23 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ], +) def test_documents(model, reduced_embeddings, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topics = set(topic_model.topics_) if -1 in topics: topics.remove(-1) - fig = topic_model.visualize_documents(documents, embeddings=reduced_embeddings, hide_document_hover=True) + fig = topic_model.visualize_documents( + documents, embeddings=reduced_embeddings, hide_document_hover=True + ) fig_topics = [int(data["name"].split("_")[0]) for data in fig.to_dict()["data"][1:]] assert set(fig_topics) == topics diff --git a/tests/test_plotting/test_dynamic.py b/tests/test_plotting/test_dynamic.py index 05372935..361702b1 100644 --- a/tests/test_plotting/test_dynamic.py +++ b/tests/test_plotting/test_dynamic.py @@ -2,16 +2,24 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + 
("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_dynamic(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) timestamps = [i % 10 for i in range(len(documents))] topics_over_time = topic_model.topics_over_time(documents, timestamps) fig = topic_model.visualize_topics_over_time(topics_over_time) - assert len(fig.to_dict()["data"]) == len(set(topic_model.topics_)) - topic_model._outliers + assert ( + len(fig.to_dict()["data"]) + == len(set(topic_model.topics_)) - topic_model._outliers + ) diff --git a/tests/test_plotting/test_heatmap.py b/tests/test_plotting/test_heatmap.py index 730b3b1b..f92a4b80 100644 --- a/tests/test_plotting/test_heatmap.py +++ b/tests/test_plotting/test_heatmap.py @@ -2,11 +2,16 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ], +) def test_heatmap(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topics = set(topic_model.topics_) diff --git a/tests/test_plotting/test_term_rank.py b/tests/test_plotting/test_term_rank.py index b60ca0ab..318d7d3c 100644 --- a/tests/test_plotting/test_term_rank.py +++ b/tests/test_plotting/test_term_rank.py @@ -2,9 +2,9 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model')]) +@pytest.mark.parametrize( + "model", [("kmeans_pca_topic_model"), ("base_topic_model"), ("custom_topic_model")] +) def test_term_rank(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) - fig = topic_model.visualize_term_rank() + topic_model.visualize_term_rank() diff --git a/tests/test_plotting/test_topics.py b/tests/test_plotting/test_topics.py index 9048b74b..b438e8dc 100644 --- a/tests/test_plotting/test_topics.py +++ b/tests/test_plotting/test_topics.py @@ -2,12 +2,17 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_topics(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) fig = topic_model.visualize_topics() @@ -20,12 +25,18 @@ def test_topics(model, request): for step in slider["steps"]: assert int(step["label"].split(" ")[-1]) != -1 -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) + +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_topics_outlier(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topic_model.topic_sizes_[-1] = 4 diff --git a/tests/test_reduction/test_merge.py b/tests/test_reduction/test_merge.py index 97c5a3b8..b69ee3cd 
100644 --- a/tests/test_reduction/test_merge.py +++ b/tests/test_reduction/test_merge.py @@ -2,19 +2,26 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_merge(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) nr_topics = len(set(topic_model.topics_)) topics_to_merge = [1, 2] topic_model.merge_topics(documents, topics_to_merge) - mappings = topic_model.topic_mapper_.get_mappings(list(topic_model.hdbscan_model.labels_)) + mappings = topic_model.topic_mapper_.get_mappings( + list(topic_model.hdbscan_model.labels_) + ) mapped_labels = [mappings[label] for label in topic_model.hdbscan_model.labels_] assert nr_topics == len(set(topic_model.topics_)) + 1 @@ -26,7 +33,9 @@ def test_merge(model, documents, request): topics_to_merge = [1, 2] topic_model.merge_topics(documents, topics_to_merge) - mappings = topic_model.topic_mapper_.get_mappings(list(topic_model.hdbscan_model.labels_)) + mappings = topic_model.topic_mapper_.get_mappings( + list(topic_model.hdbscan_model.labels_) + ) mapped_labels = [mappings[label] for label in topic_model.hdbscan_model.labels_] assert nr_topics == len(set(topic_model.topics_)) + 2 diff --git a/tests/test_representation/test_get.py b/tests/test_representation/test_get.py index 429d217e..2fe48574 100644 --- a/tests/test_representation/test_get.py +++ b/tests/test_representation/test_get.py @@ -4,12 +4,17 @@ import pandas as pd -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_get_topic(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topics = [topic_model.get_topic(topic) for topic in set(topic_model.topics_)] @@ -21,12 +26,18 @@ def test_get_topic(model, request): assert len(topics) == len(topic_model.get_topic_info()) assert not unknown_topic -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) + +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_get_topics(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topics = topic_model.get_topics() @@ -35,12 +46,17 @@ def test_get_topics(model, request): assert len(topics.keys()) == len(set(topic_model.topics_)) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], 
+) def test_get_topic_freq(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) for topic in set(topic_model.topics_): @@ -48,7 +64,7 @@ def test_get_topic_freq(model, request): topic_freq = topic_model.get_topic_freq() unique_topics = set(topic_model.topics_) - topics_in_mapper = set(np.array(topic_model.topic_mapper_.mappings_)[: ,-1]) + topics_in_mapper = set(np.array(topic_model.topic_mapper_.mappings_)[:, -1]) assert isinstance(topic_freq, pd.DataFrame) @@ -57,10 +73,15 @@ def test_get_topic_freq(model, request): assert len(unique_topics.difference(topics_in_mapper)) == 0 -@pytest.mark.parametrize('model', [('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ], +) def test_get_representative_docs(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) all_docs = topic_model.get_representative_docs() @@ -79,12 +100,17 @@ def test_get_representative_docs(model, request): assert len(topics.difference(topics_in_mapper)) == 0 -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_get_topic_info(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) info = topic_model.get_topic_info() diff --git a/tests/test_representation/test_labels.py b/tests/test_representation/test_labels.py index d762fc93..1dd74595 100644 --- a/tests/test_representation/test_labels.py +++ b/tests/test_representation/test_labels.py @@ -2,12 +2,17 @@ import pytest -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_generate_topic_labels(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) labels = topic_model.generate_topic_labels(topic_prefix=False) @@ -21,12 +26,17 @@ def test_generate_topic_labels(model, request): assert all([True if len(label) < 15 else False for label in labels]) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_set_labels(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) @@ -37,20 +47,26 @@ def test_set_labels(model, request): if model != "online_topic_model": labels = {1: "My label", 2: "Another label"} topic_model.set_topic_labels(labels) - assert topic_model.custom_labels_[1+topic_model._outliers] == "My label" - assert topic_model.custom_labels_[2+topic_model._outliers] == "Another label" + assert topic_model.custom_labels_[1 
+ topic_model._outliers] == "My label" + assert topic_model.custom_labels_[2 + topic_model._outliers] == "Another label" labels = {1: "Change label", 3: "New label"} topic_model.set_topic_labels(labels) - assert topic_model.custom_labels_[1+topic_model._outliers] == "Change label" - assert topic_model.custom_labels_[3+topic_model._outliers] == "New label" + assert topic_model.custom_labels_[1 + topic_model._outliers] == "Change label" + assert topic_model.custom_labels_[3 + topic_model._outliers] == "New label" else: - labels = {sorted(set(topic_model.topics_))[0]: "My label", sorted(set(topic_model.topics_))[1]: "Another label"} + labels = { + sorted(set(topic_model.topics_))[0]: "My label", + sorted(set(topic_model.topics_))[1]: "Another label", + } topic_model.set_topic_labels(labels) assert topic_model.custom_labels_[0] == "My label" assert topic_model.custom_labels_[1] == "Another label" - labels = {sorted(set(topic_model.topics_))[0]: "Change label", sorted(set(topic_model.topics_))[2]: "New label"} + labels = { + sorted(set(topic_model.topics_))[0]: "Change label", + sorted(set(topic_model.topics_))[2]: "New label", + } topic_model.set_topic_labels(labels) - assert topic_model.custom_labels_[0+topic_model._outliers] == "Change label" - assert topic_model.custom_labels_[2+topic_model._outliers] == "New label" + assert topic_model.custom_labels_[0 + topic_model._outliers] == "Change label" + assert topic_model.custom_labels_[2 + topic_model._outliers] == "New label" diff --git a/tests/test_representation/test_representations.py b/tests/test_representation/test_representations.py index 5a4e99f0..98b8f4dd 100644 --- a/tests/test_representation/test_representations.py +++ b/tests/test_representation/test_representations.py @@ -5,11 +5,16 @@ from sklearn.feature_extraction.text import CountVectorizer -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ], +) def test_update_topics(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) old_ctfidf = topic_model.c_tf_idf_ @@ -32,18 +37,27 @@ def test_update_topics(model, documents, request): assert len(set(old_topics)) - 1 == len(set(topic_model.topics_)) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_extract_topics(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) nr_topics = 5 - documents = pd.DataFrame({"Document": documents, - "ID": range(len(documents)), - "Topic": np.random.randint(-1, nr_topics-1, len(documents))}) + documents = pd.DataFrame( + { + "Document": documents, + "ID": range(len(documents)), + "Topic": np.random.randint(-1, nr_topics - 1, len(documents)), + } + ) topic_model._update_topic_size(documents) topic_model._extract_topics(documents) freq = topic_model.get_topic_freq() @@ -56,18 +70,27 @@ def test_extract_topics(model, documents, request): assert len(freq.Topic.unique()) == len(freq) 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_extract_topics_custom_cv(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) nr_topics = 5 - documents = pd.DataFrame({"Document": documents, - "ID": range(len(documents)), - "Topic": np.random.randint(-1, nr_topics-1, len(documents))}) + documents = pd.DataFrame( + { + "Document": documents, + "ID": range(len(documents)), + "Topic": np.random.randint(-1, nr_topics - 1, len(documents)), + } + ) cv = CountVectorizer(ngram_range=(1, 2)) topic_model.vectorizer_model = cv @@ -83,12 +106,17 @@ def test_extract_topics_custom_cv(model, documents, request): assert len(freq.Topic.unique()) == len(freq) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) @pytest.mark.parametrize("reduced_topics", [2, 4, 10]) def test_topic_reduction(model, reduced_topics, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) @@ -107,20 +135,25 @@ def test_topic_reduction(model, reduced_topics, documents, request): assert topic_model.topics_ != old_topics -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_topic_reduction_edge_cases(model, documents, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) topic_model.nr_topics = 100 nr_topics = 5 topics = np.random.randint(-1, nr_topics - 1, len(documents)) - old_documents = pd.DataFrame({"Document": documents, - "ID": range(len(documents)), - "Topic": topics}) + old_documents = pd.DataFrame( + {"Document": documents, "ID": range(len(documents)), "Topic": topics} + ) topic_model._update_topic_size(old_documents) topic_model._extract_topics(old_documents) old_freq = topic_model.get_topic_freq() @@ -133,12 +166,17 @@ def test_topic_reduction_edge_cases(model, documents, request): pd.testing.assert_frame_equal(old_freq, new_freq) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_find_topics(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) similar_topics, similarity = topic_model.find_topics("car") diff --git a/tests/test_sub_models/test_cluster.py b/tests/test_sub_models/test_cluster.py index ec0612e9..6115d08e 
100644 --- a/tests/test_sub_models/test_cluster.py +++ b/tests/test_sub_models/test_cluster.py @@ -1,4 +1,3 @@ - import pytest import pandas as pd @@ -10,51 +9,81 @@ @pytest.mark.parametrize("cluster_model", ["hdbscan", "kmeans"]) -@pytest.mark.parametrize("samples,features,centers", - [(200, 500, 1), - (500, 200, 1), - (200, 500, 2), - (500, 200, 2), - (200, 500, 4), - (500, 200, 4)]) +@pytest.mark.parametrize( + "samples,features,centers", + [ + (200, 500, 1), + (500, 200, 1), + (200, 500, 2), + (500, 200, 2), + (200, 500, 4), + (500, 200, 4), + ], +) def test_hdbscan_cluster_embeddings(cluster_model, samples, features, centers): - embeddings, _ = make_blobs(n_samples=samples, centers=centers, n_features=features, random_state=42) + embeddings, _ = make_blobs( + n_samples=samples, centers=centers, n_features=features, random_state=42 + ) documents = [str(i + 1) for i in range(embeddings.shape[0])] - old_df = pd.DataFrame({"Document": documents, "ID": range(len(documents)), "Topic": None}) + old_df = pd.DataFrame( + {"Document": documents, "ID": range(len(documents)), "Topic": None} + ) if cluster_model == "kmeans": cluster_model = KMeans(n_clusters=centers) else: - cluster_model = HDBSCAN(min_cluster_size=10, metric="euclidean", cluster_selection_method="eom", prediction_data=True) + cluster_model = HDBSCAN( + min_cluster_size=10, + metric="euclidean", + cluster_selection_method="eom", + prediction_data=True, + ) model = BERTopic(hdbscan_model=cluster_model) new_df, _ = model._cluster_embeddings(embeddings, old_df) assert len(new_df.Topic.unique()) == centers assert "Topic" in new_df.columns - pd.testing.assert_frame_equal(old_df.drop("Topic", axis=1), new_df.drop("Topic", axis=1)) + pd.testing.assert_frame_equal( + old_df.drop("Topic", axis=1), new_df.drop("Topic", axis=1) + ) @pytest.mark.parametrize("cluster_model", ["hdbscan", "kmeans"]) -@pytest.mark.parametrize("samples,features,centers", - [(200, 500, 1), - (500, 200, 1), - (200, 500, 2), - (500, 200, 2), - (200, 500, 4), - (500, 200, 4)]) +@pytest.mark.parametrize( + "samples,features,centers", + [ + (200, 500, 1), + (500, 200, 1), + (200, 500, 2), + (500, 200, 2), + (200, 500, 4), + (500, 200, 4), + ], +) def test_custom_hdbscan_cluster_embeddings(cluster_model, samples, features, centers): - embeddings, _ = make_blobs(n_samples=samples, centers=centers, n_features=features, random_state=42) + embeddings, _ = make_blobs( + n_samples=samples, centers=centers, n_features=features, random_state=42 + ) documents = [str(i + 1) for i in range(embeddings.shape[0])] - old_df = pd.DataFrame({"Document": documents, "ID": range(len(documents)), "Topic": None}) + old_df = pd.DataFrame( + {"Document": documents, "ID": range(len(documents)), "Topic": None} + ) if cluster_model == "kmeans": cluster_model = KMeans(n_clusters=centers) else: - cluster_model = HDBSCAN(min_cluster_size=10, metric="euclidean", cluster_selection_method="eom", prediction_data=True) + cluster_model = HDBSCAN( + min_cluster_size=10, + metric="euclidean", + cluster_selection_method="eom", + prediction_data=True, + ) model = BERTopic(hdbscan_model=cluster_model) new_df, _ = model._cluster_embeddings(embeddings, old_df) assert len(new_df.Topic.unique()) == centers assert "Topic" in new_df.columns - pd.testing.assert_frame_equal(old_df.drop("Topic", axis=1), new_df.drop("Topic", axis=1)) + pd.testing.assert_frame_equal( + old_df.drop("Topic", axis=1), new_df.drop("Topic", axis=1) + ) diff --git a/tests/test_sub_models/test_dim_reduction.py 
b/tests/test_sub_models/test_dim_reduction.py index 551b8822..a77020c4 100644 --- a/tests/test_sub_models/test_dim_reduction.py +++ b/tests/test_sub_models/test_dim_reduction.py @@ -1,4 +1,3 @@ - import copy import pytest import numpy as np @@ -9,21 +8,31 @@ @pytest.mark.parametrize("dim_model", [UMAP, PCA]) -@pytest.mark.parametrize("embeddings,shape,n_components", [(np.random.rand(100, 128), 100, 5), - (np.random.rand(10, 256), 10, 5), - (np.random.rand(50, 15), 50, 10)]) +@pytest.mark.parametrize( + "embeddings,shape,n_components", + [ + (np.random.rand(100, 128), 100, 5), + (np.random.rand(10, 256), 10, 5), + (np.random.rand(50, 15), 50, 10), + ], +) def test_reduce_dimensionality(dim_model, embeddings, shape, n_components): model = BERTopic(umap_model=dim_model(n_components=n_components)) umap_embeddings = model._reduce_dimensionality(embeddings) assert umap_embeddings.shape == (shape, n_components) -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_custom_reduce_dimensionality(model, request): embeddings = np.random.rand(500, 128) topic_model = copy.deepcopy(request.getfixturevalue(model)) diff --git a/tests/test_sub_models/test_embeddings.py b/tests/test_sub_models/test_embeddings.py index 017b3399..22f53539 100644 --- a/tests/test_sub_models/test_embeddings.py +++ b/tests/test_sub_models/test_embeddings.py @@ -5,16 +5,23 @@ from sklearn.metrics.pairwise import cosine_similarity -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_extract_embeddings(model, request): topic_model = copy.deepcopy(request.getfixturevalue(model)) single_embedding = topic_model._extract_embeddings("a document") - multiple_embeddings = topic_model._extract_embeddings(["something different", "another document"]) + multiple_embeddings = topic_model._extract_embeddings( + ["something different", "another document"] + ) sim_matrix = cosine_similarity(single_embedding, multiple_embeddings)[0] assert single_embedding.shape[0] == 1 @@ -31,12 +38,17 @@ def test_extract_embeddings(model, request): assert sim_matrix[1] > 0.5 -@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), - ('base_topic_model'), - ('custom_topic_model'), - ('merged_topic_model'), - ('reduced_topic_model'), - ('online_topic_model')]) +@pytest.mark.parametrize( + "model", + [ + ("kmeans_pca_topic_model"), + ("base_topic_model"), + ("custom_topic_model"), + ("merged_topic_model"), + ("reduced_topic_model"), + ("online_topic_model"), + ], +) def test_extract_embeddings_compare(model, embedding_model, request): docs = ["some document"] topic_model = copy.deepcopy(request.getfixturevalue(model)) diff --git a/tests/test_utils.py b/tests/test_utils.py index 5827c017..2974b1b6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,13 @@ import logging import numpy as np from typing import List -from bertopic._utils import 
check_documents_type, check_embeddings_shape, MyLogger, select_topic_representation, get_unique_distances
+from bertopic._utils import (
+    check_documents_type,
+    check_embeddings_shape,
+    MyLogger,
+    select_topic_representation,
+    get_unique_distances,
+)
 from scipy.sparse import csr_matrix
@@ -20,11 +26,7 @@ def test_logger():
 
 @pytest.mark.parametrize(
     "docs",
-    [
-        "A document not in an iterable",
-        [None],
-        5
-    ],
+    ["A document not in an iterable", [None], 5],
 )
 def test_check_documents_type(docs):
     with pytest.raises(TypeError):
@@ -33,16 +35,21 @@ def test_check_documents_type(docs):
 
 def test_check_embeddings_shape():
     docs = ["doc_one", "doc_two"]
-    embeddings = np.array([[1, 2, 3],
-                           [2, 3, 4]])
+    embeddings = np.array([[1, 2, 3], [2, 3, 4]])
     check_embeddings_shape(embeddings, docs)
-
+
 
 def test_make_unique_distances():
     def check_dists(dists: List[float], noise_max: float):
-        unique_dists = get_unique_distances(np.array(dists, dtype=float), noise_max=noise_max)
-        assert len(unique_dists) == len(dists), "The number of elements must be the same"
-        assert len(dists) == len(np.unique(unique_dists)), "The distances must be unique"
+        unique_dists = get_unique_distances(
+            np.array(dists, dtype=float), noise_max=noise_max
+        )
+        assert len(unique_dists) == len(
+            dists
+        ), "The number of elements must be the same"
+        assert len(dists) == len(
+            np.unique(unique_dists)
+        ), "The distances must be unique"
 
     check_dists([0, 0, 0.5, 0.75, 1, 1], noise_max=1e-7)
 
@@ -52,45 +59,54 @@ def check_dists(dists: List[float], noise_max: float):
     # test whether the distances are sorted in ascending order when the distances are all the same
     check_dists([0, 0, 0, 0, 0, 0, 0], noise_max=1e-7)
-
+
 
 def test_select_topic_representation():
     ctfidf_embeddings = np.array([[1, 1, 1]])
     ctfidf_embeddings_sparse = csr_matrix(
         (ctfidf_embeddings.reshape(-1).tolist(), ([0, 0, 0], [0, 1, 2])),
-        shape=ctfidf_embeddings.shape
+        shape=ctfidf_embeddings.shape,
     )
 
     topic_embeddings = np.array([[2, 2, 2]])
 
     # Use topic embeddings
-    repr_, ctfidf_used = select_topic_representation(ctfidf_embeddings, topic_embeddings, use_ctfidf=False)
+    repr_, ctfidf_used = select_topic_representation(
+        ctfidf_embeddings, topic_embeddings, use_ctfidf=False
+    )
     np.testing.assert_array_equal(topic_embeddings, repr_)
     assert not ctfidf_used
 
     # Fallback to c-TF-IDF
-    repr_, ctfidf_used = select_topic_representation(ctfidf_embeddings, None, use_ctfidf=False)
+    repr_, ctfidf_used = select_topic_representation(
+        ctfidf_embeddings, None, use_ctfidf=False
+    )
     np.testing.assert_array_equal(ctfidf_embeddings, repr_)
     assert ctfidf_used
 
     # Use c-TF-IDF
-    repr_, ctfidf_used = select_topic_representation(ctfidf_embeddings, topic_embeddings, use_ctfidf=True)
-    np.testing.assert_array_equal(
-        ctfidf_embeddings,
-        repr_
+    repr_, ctfidf_used = select_topic_representation(
+        ctfidf_embeddings, topic_embeddings, use_ctfidf=True
     )
+    np.testing.assert_array_equal(ctfidf_embeddings, repr_)
     assert ctfidf_used
 
     # Fallback to topic embeddings
-    repr_, ctfidf_used = select_topic_representation(None, topic_embeddings, use_ctfidf=True)
+    repr_, ctfidf_used = select_topic_representation(
+        None, topic_embeddings, use_ctfidf=True
+    )
     np.testing.assert_array_equal(topic_embeddings, repr_)
     assert not ctfidf_used
 
     # `scipy.sparse.csr_matrix` can be used as c-TF-IDF embeddings
     np.testing.assert_array_equal(
         ctfidf_embeddings,
-        select_topic_representation(ctfidf_embeddings_sparse, None, use_ctfidf=True, output_ndarray=True)[0]
+        select_topic_representation(
+            ctfidf_embeddings_sparse, None, use_ctfidf=True, output_ndarray=True
+        )[0],
     )
 
     # check that `csr_matrix` is not casted to `np.ndarray` when `ctfidf_as_ndarray` is False
-    repr_ = select_topic_representation(ctfidf_embeddings_sparse, None, output_ndarray=False)[0]
+    repr_ = select_topic_representation(
+        ctfidf_embeddings_sparse, None, output_ndarray=False
+    )[0]
     assert isinstance(repr_, csr_matrix)
diff --git a/tests/test_variations/test_class.py b/tests/test_variations/test_class.py
index 910837ad..a94c108d 100644
--- a/tests/test_variations/test_class.py
+++ b/tests/test_variations/test_class.py
@@ -2,15 +2,28 @@
 import pytest
 from sklearn.datasets import fetch_20newsgroups
 
-data = fetch_20newsgroups(subset="all", remove=('headers', 'footers', 'quotes'))
+data = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))
 classes = [data["target_names"][i] for i in data["target"]][:1000]
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_class(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
-    topics_per_class_global = topic_model.topics_per_class(documents, classes=classes, global_tuning=True)
-    topics_per_class_local = topic_model.topics_per_class(documents, classes=classes, global_tuning=False)
+    topics_per_class_global = topic_model.topics_per_class(
+        documents, classes=classes, global_tuning=True
+    )
+    topics_per_class_local = topic_model.topics_per_class(
+        documents, classes=classes, global_tuning=False
+    )
     assert topics_per_class_global.Frequency.sum() == len(documents)
     assert topics_per_class_local.Frequency.sum() == len(documents)
diff --git a/tests/test_variations/test_dynamic.py b/tests/test_variations/test_dynamic.py
index 759eacbf..5af38f3e 100644
--- a/tests/test_variations/test_dynamic.py
+++ b/tests/test_variations/test_dynamic.py
@@ -2,7 +2,16 @@
 import pytest
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_dynamic(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
     timestamps = [i % 10 for i in range(len(documents))]
diff --git a/tests/test_variations/test_hierarchy.py b/tests/test_variations/test_hierarchy.py
index e8625c36..cdfdaf8d 100644
--- a/tests/test_variations/test_hierarchy.py
+++ b/tests/test_variations/test_hierarchy.py
@@ -3,7 +3,16 @@
 from scipy.cluster import hierarchy as sch
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_hierarchy(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
     hierarchical_topics = topic_model.hierarchical_topics(documents)
@@ -14,11 +23,22 @@ def test_hierarchy(model, documents, request):
     assert merged_topics == set(topic_model.topics_).difference({-1})
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_linkage(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
-    linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
-    hierarchical_topics = topic_model.hierarchical_topics(documents, linkage_function=linkage_function)
+    linkage_function = lambda x: sch.linkage(x, "single", optimal_ordering=True)
+    hierarchical_topics = topic_model.hierarchical_topics(
+        documents, linkage_function=linkage_function
+    )
     merged_topics = set([v for vals in hierarchical_topics.Topics.values for v in vals])
 
     tree = topic_model.get_topic_tree(hierarchical_topics)
@@ -28,11 +48,22 @@ def test_linkage(model, documents, request):
     assert merged_topics == set(topic_model.topics_).difference({-1})
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_tree(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
-    linkage_function = lambda x: sch.linkage(x, 'single', optimal_ordering=True)
-    hierarchical_topics = topic_model.hierarchical_topics(documents, linkage_function=linkage_function)
+    linkage_function = lambda x: sch.linkage(x, "single", optimal_ordering=True)
+    hierarchical_topics = topic_model.hierarchical_topics(
+        documents, linkage_function=linkage_function
+    )
     merged_topics = set([v for vals in hierarchical_topics.Topics.values for v in vals])
 
     tree = topic_model.get_topic_tree(hierarchical_topics)
diff --git a/tests/test_vectorizers/test_ctfidf.py b/tests/test_vectorizers/test_ctfidf.py
index 703067ff..a6cedccd 100644
--- a/tests/test_vectorizers/test_ctfidf.py
+++ b/tests/test_vectorizers/test_ctfidf.py
@@ -1,4 +1,3 @@
-
 import copy
 import pytest
 import numpy as np
@@ -10,19 +9,26 @@
 from bertopic.vectorizers import ClassTfidfTransformer
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'),
-                                   ('base_topic_model'),
-                                   ('custom_topic_model'),
-                                   ('merged_topic_model'),
-                                   ('reduced_topic_model'),
-                                   ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("base_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_ctfidf(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
     topics = topic_model.topics_
-    documents = pd.DataFrame({"Document": documents,
-                              "ID": range(len(documents)),
-                              "Topic": topics})
-    documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
+    documents = pd.DataFrame(
+        {"Document": documents, "ID": range(len(documents)), "Topic": topics}
+    )
+    documents_per_topic = documents.groupby(["Topic"], as_index=False).agg(
+        {"Document": " ".join}
+    )
     documents = topic_model._preprocess_text(documents_per_topic.Document.values)
 
     count = topic_model.vectorizer_model.fit(documents)
@@ -52,21 +58,28 @@ def test_ctfidf(model, documents, request):
     assert np.min(X) == 0
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'),
-                                   ('base_topic_model'),
-                                   ('custom_topic_model'),
-                                   ('merged_topic_model'),
-                                   ('reduced_topic_model'),
-                                   ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("base_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_ctfidf_custom_cv(model, documents, request):
     cv = CountVectorizer(ngram_range=(1, 3), stop_words="english")
     topic_model = copy.deepcopy(request.getfixturevalue(model))
     topic_model.vectorizer_model = cv
     topics = topic_model.topics_
-    documents = pd.DataFrame({"Document": documents,
-                              "ID": range(len(documents)),
-                              "Topic": topics})
-    documents_per_topic = documents.groupby(['Topic'], as_index=False).agg({'Document': ' '.join})
+    documents = pd.DataFrame(
+        {"Document": documents, "ID": range(len(documents)), "Topic": topics}
+    )
+    documents_per_topic = documents.groupby(["Topic"], as_index=False).agg(
+        {"Document": " ".join}
+    )
     documents = topic_model._preprocess_text(documents_per_topic.Document.values)
 
     count = topic_model.vectorizer_model.fit(documents)
diff --git a/tests/test_vectorizers/test_online_cv.py b/tests/test_vectorizers/test_online_cv.py
index d7ab677e..3ed21813 100644
--- a/tests/test_vectorizers/test_online_cv.py
+++ b/tests/test_vectorizers/test_online_cv.py
@@ -1,10 +1,18 @@
-
 import copy
 import pytest
 from bertopic.vectorizers import OnlineCountVectorizer
 
 
-@pytest.mark.parametrize('model', [('kmeans_pca_topic_model'), ('custom_topic_model'), ('merged_topic_model'), ('reduced_topic_model'), ('online_topic_model')])
+@pytest.mark.parametrize(
+    "model",
+    [
+        ("kmeans_pca_topic_model"),
+        ("custom_topic_model"),
+        ("merged_topic_model"),
+        ("reduced_topic_model"),
+        ("online_topic_model"),
+    ],
+)
 def test_online_cv(model, documents, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))
     vectorizer_model = OnlineCountVectorizer(stop_words="english", ngram_range=(2, 2))
@@ -18,7 +26,7 @@ def test_online_cv(model, documents, request):
     assert old_topic != new_topic
 
 
-@pytest.mark.parametrize('model', [('online_topic_model')])
+@pytest.mark.parametrize("model", [("online_topic_model")])
 def test_clean_bow(model, request):
     topic_model = copy.deepcopy(request.getfixturevalue(model))