Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix hierarchy viz and handle any form of distance matrix #1173

Merged
merged 3 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,4 @@ venv.bak/
.idea
.idea/
.vscode
.DS_Store
31 changes: 18 additions & 13 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@
from bertopic.backend._utils import select_backend
from bertopic.representation import BaseRepresentation
from bertopic.cluster._utils import hdbscan_delegator, is_supported_hdbscan
from bertopic._utils import MyLogger, check_documents_type, check_embeddings_shape, check_is_fitted
from bertopic._utils import (
MyLogger, check_documents_type, check_embeddings_shape,
check_is_fitted, validate_distance_matrix
)
from bertopic.representation._mmr import mmr

# Visualization
Expand Down Expand Up @@ -829,7 +832,12 @@ def hierarchical_topics(self,
linkage_function: The linkage function to use. Default is:
`lambda x: sch.linkage(x, 'ward', optimal_ordering=True)`
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
`lambda x: 1 - cosine_similarity(x)`
`lambda x: 1 - cosine_similarity(x)`.
You can pass any function that returns either a square matrix of
shape (n_samples, n_samples) with zeros on the diagonal and
non-negative values or condensed distance matrix of shape
(n_samples * (n_samples - 1) / 2,) containing the upper
triangular of the distance matrix.

Returns:
hierarchical_topics: A dataframe that contains a hierarchy of topics
Expand Down Expand Up @@ -866,10 +874,7 @@ def hierarchical_topics(self,
# Calculate distance
embeddings = self.c_tf_idf_[self._outliers:]
X = distance_function(embeddings)

# Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
np.fill_diagonal(X, 0)
X = squareform(X)
X = validate_distance_matrix(X, embeddings.shape[0])

# Use the 1-D condensed distance matrix as an input instead of the raw distance matrix
Z = linkage_function(X)
Expand Down Expand Up @@ -2153,7 +2158,7 @@ def visualize_hierarchical_documents(self,
hide_annotations: bool = False,
hide_document_hover: bool = True,
nr_levels: int = 10,
level_scale: str = 'linear',
level_scale: str = 'linear',
custom_labels: bool = False,
title: str = "<b>Hierarchical Documents and Topics</b>",
width: int = 1200,
Expand Down Expand Up @@ -2253,7 +2258,7 @@ def visualize_hierarchical_documents(self,
hide_annotations=hide_annotations,
hide_document_hover=hide_document_hover,
nr_levels=nr_levels,
level_scale=level_scale,
level_scale=level_scale,
custom_labels=custom_labels,
title=title,
width=width,
Expand Down Expand Up @@ -3001,10 +3006,10 @@ def _save_representative_docs(self, documents: pd.DataFrame):
Updates:
self.representative_docs_: Populate each topic with 3 representative docs
"""
repr_docs, _, _= self._extract_representative_docs(self.c_tf_idf_,
documents,
self.topic_representations_,
nr_samples=500,
repr_docs, _, _= self._extract_representative_docs(self.c_tf_idf_,
documents,
self.topic_representations_,
nr_samples=500,
nr_repr_docs=3)
self.representative_docs_ = repr_docs

Expand Down Expand Up @@ -3062,7 +3067,7 @@ def _extract_representative_docs(self,
docs = mmr(c_tf_idf[index], ctfidf, selected_docs, nr_docs, diversity=diversity)
repr_docs.extend(docs)

# Extract top n most representative documents
# Extract top n most representative documents
else:
indices = np.argpartition(sim_matrix.reshape(1, -1)[0],
-nr_docs)[-nr_docs:]
Expand Down
52 changes: 52 additions & 0 deletions bertopic/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from collections.abc import Iterable
from scipy.sparse import csr_matrix
from scipy.spatial.distance import squareform


class MyLogger:
Expand Down Expand Up @@ -92,3 +93,54 @@ def __getattr__(self, *args, **kwargs):

def __call__(self, *args, **kwargs):
raise ModuleNotFoundError(self.msg)

def validate_distance_matrix(X, n_samples):
""" Validate the distance matrix and convert it to a condensed distance matrix
if necessary.

A valid distance matrix is either a square matrix of shape (n_samples, n_samples)
with zeros on the diagonal and non-negative values or condensed distance matrix
of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangular of the
distance matrix.

Arguments:
X: Distance matrix to validate.
n_samples: Number of samples in the dataset.

Returns:
X: Validated distance matrix.

Raises:
ValueError: If the distance matrix is not valid.
"""
# Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
s = X.shape
if len(s) == 1:
# check it has correct size
n = s[0]
if n != (n_samples * (n_samples -1) / 2):
raise ValueError("The condensed distance matrix must have "
"shape (n*(n-1)/2,).")
elashrry marked this conversation as resolved.
Show resolved Hide resolved
elif len(s) == 2:
# check it is square
if s[0] != s[1]:
raise ValueError("The distance matrix must be square.")
# check it has correct size
if s[0] != n_samples:
raise ValueError("The distance matrix must be of shape "
"(n, n) where n is the number of documents.")
elashrry marked this conversation as resolved.
Show resolved Hide resolved
# force zero diagonal and convert to condensed
np.fill_diagonal(X, 0)
X = squareform(X)
else:
raise ValueError("The distance matrix must be either a 1-D condensed "
"distance matrix of shape (n*(n-1)/2,) or a "
"2-D square distance matrix of shape (n, n)."
"where n is the number of documents."
"Got a distance matrix of shape %s" % str(s))
elashrry marked this conversation as resolved.
Show resolved Hide resolved

# Make sure its entries are non-negative
if np.any(X < 0):
raise ValueError("Distance matrix cannot contain negative values.")
Comment on lines +139 to +141
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether this is necessary as this issue is handled and raised within Scipy already. Perhaps remove?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the squareform from SciPy doesn't raise error if the entries are non-negative.

X = np.array([1, -2, 3,])
squareform(X)

Output:

array([[ 0,  1, -2],
       [ 1,  0,  3],
       [-2,  3,  0]])

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's true but the linkage function will give the error but it's alright to keep it here.


return X
25 changes: 18 additions & 7 deletions bertopic/plotting/_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import plotly.graph_objects as go
import plotly.figure_factory as ff

from .._utils import validate_distance_matrix
elashrry marked this conversation as resolved.
Show resolved Hide resolved

def visualize_hierarchy(topic_model,
orientation: str = "left",
Expand Down Expand Up @@ -50,7 +51,12 @@ def visualize_hierarchy(topic_model,
NOTE: Make sure to use the same `linkage_function` as used
in `topic_model.hierarchical_topics`.
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
`lambda x: 1 - cosine_similarity(x)`
`lambda x: 1 - cosine_similarity(x)`.
You can pass any function that returns either a square matrix of
shape (n_samples, n_samples) with zeros on the diagonal and
non-negative values or condensed distance matrix of shape
(n_samples * (n_samples - 1) / 2,) containing the upper
triangular of the distance matrix.
NOTE: Make sure to use the same `distance_function` as used
in `topic_model.hierarchical_topics`.
color_threshold: Value at which the separation of clusters will be made which
Expand Down Expand Up @@ -122,10 +128,13 @@ def visualize_hierarchy(topic_model,
else:
annotations = None

# wrap distance function to validate input and return a condensed distance matrix
distance_function_viz = lambda x: validate_distance_matrix(
distance_function(x), embeddings.shape[0])
# Create dendogram
fig = ff.create_dendrogram(embeddings,
orientation=orientation,
distfun=distance_function,
distfun=distance_function_viz,
linkagefun=linkage_function,
hovertext=annotations,
color_threshold=color_threshold)
Expand Down Expand Up @@ -213,7 +222,12 @@ def _get_annotations(topic_model,
NOTE: Make sure to use the same `linkage_function` as used
in `topic_model.hierarchical_topics`.
distance_function: The distance function to use on the c-TF-IDF matrix. Default is:
`lambda x: 1 - cosine_similarity(x)`
`lambda x: 1 - cosine_similarity(x)`.
You can pass any function that returns either a square matrix of
shape (n_samples, n_samples) with zeros on the diagonal and
non-negative values or condensed distance matrix of shape
(n_samples * (n_samples - 1) / 2,) containing the upper
triangular of the distance matrix.
NOTE: Make sure to use the same `distance_function` as used
in `topic_model.hierarchical_topics`.
orientation: The orientation of the figure.
Expand All @@ -230,10 +244,7 @@ def _get_annotations(topic_model,

# Calculate distance
X = distance_function(embeddings)

# Make sure it is the 1-D condensed distance matrix with zeros on the diagonal
np.fill_diagonal(X, 0)
X = squareform(X)
X = validate_distance_matrix(X, embeddings.shape[0])

# Calculate linkage and generate dendrogram
Z = linkage_function(X)
Expand Down
Loading