Linting
Maarten Grootendorst committed Feb 24, 2025
1 parent 231fb6c commit 851e03d
Showing 10 changed files with 37 additions and 41 deletions.
28 changes: 10 additions & 18 deletions bertopic/_bertopic.py
@@ -39,6 +39,7 @@
# Models
try:
from hdbscan import HDBSCAN

HAS_HDBSCAN = True
except (ImportError, ModuleNotFoundError):
HAS_HDBSCAN = False
@@ -150,7 +151,7 @@ def __init__(
zeroshot_min_similarity: float = 0.7,
embedding_model=None,
umap_model=None,
- hdbscan_model = None,
+ hdbscan_model=None,
vectorizer_model: CountVectorizer = None,
ctfidf_model: TfidfTransformer = None,
representation_model: BaseRepresentation = None,
@@ -258,6 +259,7 @@ def __init__(
else:
try:
from umap import UMAP

self.umap_model = UMAP(
n_neighbors=15,
n_components=5,
@@ -282,12 +284,9 @@
)
else:
self.hdbscan_model = SK_HDBSCAN(
- min_cluster_size=self.min_topic_size,
- metric="euclidean",
- cluster_selection_method="eom",
- n_jobs=-1
+ min_cluster_size=self.min_topic_size, metric="euclidean", cluster_selection_method="eom", n_jobs=-1
)

# Public attributes
self.topics_ = None
self.probabilities_ = None
@@ -708,9 +707,7 @@ def partial_fit(
# Checks
check_embeddings_shape(embeddings, documents)
if not hasattr(self.hdbscan_model, "partial_fit"):
- raise ValueError(
- "In order to use `.partial_fit`, the cluster model should have " "a `.partial_fit` function."
- )
+ raise ValueError("In order to use `.partial_fit`, the cluster model should have a `.partial_fit` function.")

# Prepare documents
if isinstance(documents, str):
@@ -1548,7 +1545,7 @@ def update_topics(

if top_n_words > 100:
logger.warning(
"Note that extracting more than 100 words from a sparse " "can slow down computation quite a bit."
"Note that extracting more than 100 words from a sparse can slow down computation quite a bit."
)
self.top_n_words = top_n_words
self.vectorizer_model = vectorizer_model or CountVectorizer(ngram_range=n_gram_range)
@@ -2031,7 +2028,7 @@ def set_topic_labels(self, topic_labels: Union[List[str], Mapping[int, str]]) ->
custom_labels = topic_labels
else:
raise ValueError(
"Make sure that `topic_labels` contains the same number " "of labels as there are topics."
"Make sure that `topic_labels` contains the same number of labels as there are topics."
)

self.custom_labels_ = custom_labels
@@ -2148,9 +2145,7 @@ def merge_topics(
for topic in topic_group:
mapping[topic] = topic_group[0]
else:
- raise ValueError(
- "Make sure that `topics_to_merge` is either" "a list of topics or a list of list of topics."
- )
+ raise ValueError("Make sure that `topics_to_merge` is eithera list of topics or a list of list of topics.")

# Track mappings and sizes of topics for merging topic embeddings
mappings = defaultdict(list)
@@ -4507,10 +4502,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
).fit_predict(norm_data[self._outliers :])
else:
predictions = SK_HDBSCAN(
- min_cluster_size=2,
- metric="euclidean",
- cluster_selection_method="eom",
- n_jobs=-1
+ min_cluster_size=2, metric="euclidean", cluster_selection_method="eom", n_jobs=-1
).fit_predict(norm_data[self._outliers :])

# Map similar topics
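For context, the hunks above adjust BERTopic's optional HDBSCAN handling. A minimal sketch of that import-and-fallback pattern (not from this commit; it assumes scikit-learn >= 1.3 provides the fallback clusterer and uses an illustrative min_topic_size):

# Prefer the hdbscan package when installed; otherwise fall back to scikit-learn's HDBSCAN.
try:
    from hdbscan import HDBSCAN

    HAS_HDBSCAN = True
except (ImportError, ModuleNotFoundError):
    HAS_HDBSCAN = False
    from sklearn.cluster import HDBSCAN as SK_HDBSCAN

min_topic_size = 10  # illustrative value, not taken from the commit

if HAS_HDBSCAN:
    cluster_model = HDBSCAN(
        min_cluster_size=min_topic_size,
        metric="euclidean",
        cluster_selection_method="eom",
        prediction_data=True,
    )
else:
    cluster_model = SK_HDBSCAN(
        min_cluster_size=min_topic_size, metric="euclidean", cluster_selection_method="eom", n_jobs=-1
    )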
6 changes: 3 additions & 3 deletions bertopic/_save_utils.py
@@ -471,12 +471,12 @@ def get_package_versions():
hdbscan_version = version("hdbscan")
except (ImportError, ModuleNotFoundError):
hdbscan_version = None

try:
from umap import __version__ as umap_version
except (ImportError, ModuleNotFoundError):
umap_version = None

try:
from sentence_transformers import __version__ as sbert_version
except (ImportError, ModuleNotFoundError):
@@ -486,7 +486,7 @@ def get_package_versions():
from numba import __version__ as numba_version
except (ImportError, ModuleNotFoundError):
numba_version = None

try:
from transformers import __version__ as transformers_version
except (ImportError, ModuleNotFoundError):
9 changes: 3 additions & 6 deletions bertopic/_utils.py
@@ -74,10 +74,7 @@ def check_is_fitted(topic_model):
Raises:
ValueError: If the matches were not found.
"""
- msg = (
- "This %(name)s instance is not fitted yet. Call 'fit' with "
- "appropriate arguments before using this estimator."
- )
+ msg = "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator."

if topic_model.topics_ is None:
raise ValueError(msg % {"name": type(topic_model).__name__})
@@ -131,11 +128,11 @@ def validate_distance_matrix(X, n_samples):
# check it has correct size
n = s[0]
if n != (n_samples * (n_samples - 1) / 2):
raise ValueError("The condensed distance matrix must have " "shape (n*(n-1)/2,).")
raise ValueError("The condensed distance matrix must have shape (n*(n-1)/2,).")
elif len(s) == 2:
# check it has correct size
if (s[0] != n_samples) or (s[1] != n_samples):
raise ValueError("The distance matrix must be of shape " "(n, n) where n is the number of samples.")
raise ValueError("The distance matrix must be of shape (n, n) where n is the number of samples.")
# force zero diagonal and convert to condensed
np.fill_diagonal(X, 0)
X = squareform(X)
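The shape check above follows scipy's condensed-distance convention: a symmetric (n, n) matrix with a zero diagonal flattens to n*(n-1)/2 pairwise distances. A small sketch of that rule (illustrative, not from this commit):

import numpy as np
from scipy.spatial.distance import squareform

# Illustrative only: a symmetric (n, n) distance matrix with a zero diagonal
# condenses to a vector of n * (n - 1) / 2 pairwise distances.
n_samples = 5
rng = np.random.default_rng(42)
X = rng.random((n_samples, n_samples))
X = (X + X.T) / 2          # symmetrize
np.fill_diagonal(X, 0)     # zero diagonal, as validate_distance_matrix enforces

condensed = squareform(X)  # shape: (n*(n-1)/2,) == (10,)
assert condensed.shape[0] == n_samples * (n_samples - 1) // 2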
4 changes: 2 additions & 2 deletions bertopic/cluster/_utils.py
@@ -17,7 +17,7 @@ def hdbscan_delegator(model, func: str, embeddings: np.ndarray = None):
try:
import hdbscan
except (ImportError, ModuleNotFoundError):
- hdbscan = type('hdbscan', (), {'HDBSCAN': None})()
+ hdbscan = type("hdbscan", (), {"HDBSCAN": None})()

# Approximate predict
if func == "approximate_predict":
@@ -69,7 +69,7 @@ def is_supported_hdbscan(model):
try:
import hdbscan
except (ImportError, ModuleNotFoundError):
- hdbscan = type('hdbscan', (), {'HDBSCAN': None})()
+ hdbscan = type("hdbscan", (), {"HDBSCAN": None})()

if isinstance(model, hdbscan.HDBSCAN):
return True
2 changes: 1 addition & 1 deletion bertopic/plotting/_approximate_distribution.py
@@ -75,7 +75,7 @@ def visualize_approximate_distribution(
df = pd.DataFrame(topic_token_distribution).T

df.columns = [f"{token}_{i}" for i, token in enumerate(tokens)]
df.columns = [f"{token}{' '*i}" for i, token in enumerate(tokens)]
df.columns = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]
df.index = list(topic_model.topic_labels_.values())[topic_model._outliers :]
df = df.loc[(df.sum(axis=1) != 0), :]

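The reformatted line above appends a growing number of spaces to each token so that repeated tokens still produce unique pandas column labels while rendering as the same word. A hypothetical illustration (not from this commit):

import pandas as pd

# Illustrative only: repeated tokens ("the") become distinct column labels
# once a different number of trailing spaces is appended to each.
tokens = ["the", "cat", "sat", "on", "the", "mat"]
columns = [f"{token}{' ' * i}" for i, token in enumerate(tokens)]

df = pd.DataFrame([[0.1, 0.8, 0.7, 0.0, 0.2, 0.9]], columns=columns)
print(df.columns.is_unique)  # True, even though "the" occurs twice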
1 change: 1 addition & 0 deletions bertopic/plotting/_datamap.py
@@ -123,6 +123,7 @@ def visualize_document_datamap(
if reduced_embeddings is None:
try:
from umap import UMAP

umap_model = UMAP(n_neighbors=15, n_components=2, min_dist=0.15, metric="cosine").fit(embeddings_to_reduce)
embeddings_2d = umap_model.embedding_
except (ImportError, ModuleNotFoundError):
1 change: 1 addition & 0 deletions bertopic/plotting/_documents.py
@@ -121,6 +121,7 @@ def visualize_documents(
if reduced_embeddings is None:
try:
from umap import UMAP

umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine").fit(embeddings_to_reduce)
embeddings_2d = umap_model.embedding_
except (ImportError, ModuleNotFoundError):
2 changes: 1 addition & 1 deletion bertopic/plotting/_heatmap.py
@@ -77,7 +77,7 @@ def visualize_heatmap(
sorted_topics = topics
if n_clusters:
if n_clusters >= len(set(topics)):
raise ValueError("Make sure to set `n_clusters` lower than " "the total number of unique topics.")
raise ValueError("Make sure to set `n_clusters` lower than the total number of unique topics.")

distance_matrix = cosine_similarity(embeddings[topics])
Z = linkage(distance_matrix, "ward")
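The guard above exists because topics can only be grouped into fewer clusters than there are unique topics. A rough sketch of such a clustering step (assumed for illustration, not taken from this commit):

import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from sklearn.metrics.pairwise import cosine_similarity

# Illustrative only: cluster 6 topic embeddings into 3 groups; n_clusters must
# stay below the number of unique topics for the cut to be meaningful.
topic_embeddings = np.random.default_rng(1).random((6, 16))
n_clusters = 3

distance_matrix = cosine_similarity(topic_embeddings)
Z = linkage(distance_matrix, "ward")
labels = fcluster(Z, t=n_clusters, criterion="maxclust")
print(labels)  # one cluster id per topic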
19 changes: 10 additions & 9 deletions bertopic/plotting/_hierarchical_documents.py
@@ -145,6 +145,7 @@ def visualize_hierarchical_documents(
if reduced_embeddings is None:
try:
from umap import UMAP

umap_model = UMAP(n_neighbors=10, n_components=2, min_dist=0.0, metric="cosine").fit(embeddings_to_reduce)
embeddings_2d = umap_model.embedding_
except (ImportError, ModuleNotFoundError):
@@ -204,8 +205,8 @@ def visualize_hierarchical_documents(
mappings[i] = False

# Create new column
df[f"level_{index+1}"] = df.topic.map(mapping)
df[f"level_{index+1}"] = df[f"level_{index+1}"].astype(int)
df[f"level_{index + 1}"] = df.topic.map(mapping)
df[f"level_{index + 1}"] = df[f"level_{index + 1}"].astype(int)

# Prepare topic names of original and merged topics
trace_names = []
@@ -247,12 +248,12 @@ def visualize_hierarchical_documents(
if topic_model._outliers:
traces.append(
go.Scattergl(
x=df.loc[(df[f"level_{level+1}"] == -1), "x"],
y=df.loc[df[f"level_{level+1}"] == -1, "y"],
x=df.loc[(df[f"level_{level + 1}"] == -1), "x"],
y=df.loc[df[f"level_{level + 1}"] == -1, "y"],
mode="markers+text",
name="other",
hoverinfo="text",
hovertext=df.loc[(df[f"level_{level+1}"] == -1), "doc"] if not hide_document_hover else None,
hovertext=df.loc[(df[f"level_{level + 1}"] == -1), "doc"] if not hide_document_hover else None,
showlegend=False,
marker=dict(color="#CFD8DC", size=5, opacity=0.5),
)
# Selected topics
if topics:
selection = df.loc[(df.topic.isin(topics)), :]
- unique_topics = sorted([int(topic) for topic in selection[f"level_{level+1}"].unique()])
+ unique_topics = sorted([int(topic) for topic in selection[f"level_{level + 1}"].unique()])
else:
- unique_topics = sorted([int(topic) for topic in df[f"level_{level+1}"].unique()])
+ unique_topics = sorted([int(topic) for topic in df[f"level_{level + 1}"].unique()])

for topic in unique_topics:
if topic != -1:
if topics:
selection = df.loc[(df[f"level_{level+1}"] == topic) & (df.topic.isin(topics)), :]
selection = df.loc[(df[f"level_{level + 1}"] == topic) & (df.topic.isin(topics)), :]
else:
selection = df.loc[df[f"level_{level+1}"] == topic, :]
selection = df.loc[df[f"level_{level + 1}"] == topic, :]

if not hide_annotations:
selection.loc[len(selection), :] = None
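The level_{index + 1} columns being reformatted above hold, for each hierarchy level, the id of the topic every original topic was merged into. A hypothetical miniature (not from this commit):

import pandas as pd

# Illustrative only: at level 1, topics 1 and 3 are merged into topics 0 and 2.
df = pd.DataFrame({"topic": [0, 1, 2, 3]})
mapping = {0: 0, 1: 0, 2: 2, 3: 2}
index = 0

df[f"level_{index + 1}"] = df.topic.map(mapping)
df[f"level_{index + 1}"] = df[f"level_{index + 1}"].astype(int)
print(df)
#    topic  level_1
# 0      0        0
# 1      1        0
# 2      2        2
# 3      3        2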
6 changes: 5 additions & 1 deletion bertopic/plotting/_topics.py
@@ -1,7 +1,9 @@
import numpy as np
import pandas as pd

try:
from umap import UMAP

HAS_UMAP = True
except (ImportError, ModuleNotFoundError):
HAS_UMAP = False
@@ -93,7 +95,9 @@ def visualize_topics(
if HAS_UMAP:
if c_tfidf_used:
embeddings = MinMaxScaler().fit_transform(embeddings)
embeddings = UMAP(n_neighbors=2, n_components=2, metric="hellinger", random_state=42).fit_transform(embeddings)
embeddings = UMAP(n_neighbors=2, n_components=2, metric="hellinger", random_state=42).fit_transform(
embeddings
)
else:
embeddings = UMAP(n_neighbors=2, n_components=2, metric="cosine", random_state=42).fit_transform(embeddings)
else:
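On the reflowed call above: the Hellinger metric is defined only for non-negative inputs, which is presumably why the c-TF-IDF embeddings are min-max scaled first. A minimal sketch (assuming umap-learn is installed; not from this commit):

import numpy as np
from sklearn.preprocessing import MinMaxScaler
from umap import UMAP

# Illustrative only: scale embeddings into [0, 1] so the Hellinger metric applies,
# then reduce them to two dimensions for plotting.
embeddings = np.random.default_rng(0).normal(size=(50, 100))
embeddings = MinMaxScaler().fit_transform(embeddings)

coords = UMAP(n_neighbors=2, n_components=2, metric="hellinger", random_state=42).fit_transform(embeddings)
print(coords.shape)  # (50, 2)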
