Skip to content

Commit

Permalink
v0.14.1 - ChatGPT support and improved Prompting (#1057)
Browse files: browse the repository at this point in the history
  • Loading branch information
MaartenGr authored Mar 2, 2023
1 parent 5e63dac commit d665d3f
Show file tree
Hide file tree
Showing 7 changed files with 308 additions and 67 deletions.
2 changes: 1 addition & 1 deletion bertopic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from bertopic._bertopic import BERTopic

__version__ = "0.14.0"
__version__ = "0.14.1"

__all__ = [
"BERTopic",
Expand Down
35 changes: 33 additions & 2 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,8 @@ def reduce_outliers(self,
def visualize_topics(self,
topics: List[int] = None,
top_n_topics: int = None,
custom_labels: bool = False,
title: str = "<b>Intertopic Distance Map</b>",
width: int = 650,
height: int = 650) -> go.Figure:
""" Visualize topics, their sizes, and their corresponding words
Expand All @@ -2015,6 +2017,9 @@ def visualize_topics(self,
Arguments:
topics: A selection of topics to visualize
top_n_topics: Only select the top n most frequent topics
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand All @@ -2037,6 +2042,8 @@ def visualize_topics(self,
return plotting.visualize_topics(self,
topics=topics,
top_n_topics=top_n_topics,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand All @@ -2049,6 +2056,7 @@ def visualize_documents(self,
hide_annotations: bool = False,
hide_document_hover: bool = False,
custom_labels: bool = False,
title: str = "<b>Documents and Topics</b>",
width: int = 1200,
height: int = 750) -> go.Figure:
""" Visualize documents and their topics in 2D
Expand All @@ -2071,6 +2079,7 @@ def visualize_documents(self,
specific points. Helps to speed up generation of visualization.
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2129,6 +2138,7 @@ def visualize_documents(self,
hide_annotations=hide_annotations,
hide_document_hover=hide_document_hover,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand All @@ -2143,6 +2153,7 @@ def visualize_hierarchical_documents(self,
hide_document_hover: bool = True,
nr_levels: int = 10,
custom_labels: bool = False,
title: str = "<b>Hierarchical Documents and Topics</b>",
width: int = 1200,
height: int = 750) -> go.Figure:
""" Visualize documents and their topics in 2D at different levels of hierarchy
Expand Down Expand Up @@ -2174,6 +2185,7 @@ def visualize_hierarchical_documents(self,
`topic_model.set_topic_labels`.
NOTE: Custom labels are only generated for the original
un-merged topics.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2235,13 +2247,15 @@ def visualize_hierarchical_documents(self,
hide_document_hover=hide_document_hover,
nr_levels=nr_levels,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

def visualize_term_rank(self,
topics: List[int] = None,
log_scale: bool = False,
custom_labels: bool = False,
title: str = "<b>Term score decline per Topic</b>",
width: int = 800,
height: int = 500) -> go.Figure:
""" Visualize the ranks of all terms across all topics
Expand All @@ -2257,6 +2271,7 @@ def visualize_term_rank(self,
log_scale: Whether to represent the ranking on a log scale
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2292,6 +2307,7 @@ def visualize_term_rank(self,
topics=topics,
log_scale=log_scale,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand All @@ -2301,6 +2317,7 @@ def visualize_topics_over_time(self,
topics: List[int] = None,
normalize_frequency: bool = False,
custom_labels: bool = False,
title: str = "<b>Topics over Time</b>",
width: int = 1250,
height: int = 450) -> go.Figure:
""" Visualize topics over time
Expand All @@ -2313,6 +2330,7 @@ def visualize_topics_over_time(self,
normalize_frequency: Whether to normalize each topic's frequency individually
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2342,6 +2360,7 @@ def visualize_topics_over_time(self,
topics=topics,
normalize_frequency=normalize_frequency,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand All @@ -2351,6 +2370,7 @@ def visualize_topics_per_class(self,
topics: List[int] = None,
normalize_frequency: bool = False,
custom_labels: bool = False,
title: str = "<b>Topics per Class</b>",
width: int = 1250,
height: int = 900) -> go.Figure:
""" Visualize topics per class
Expand All @@ -2363,6 +2383,7 @@ def visualize_topics_per_class(self,
normalize_frequency: Whether to normalize each topic's frequency individually
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2392,13 +2413,15 @@ def visualize_topics_per_class(self,
topics=topics,
normalize_frequency=normalize_frequency,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

def visualize_distribution(self,
probabilities: np.ndarray,
min_probability: float = 0.015,
custom_labels: bool = False,
title: str = "<b>Topic Probability Distribution</b>",
width: int = 800,
height: int = 600) -> go.Figure:
""" Visualize the distribution of topic probabilities
Expand All @@ -2409,6 +2432,7 @@ def visualize_distribution(self,
All others are ignored.
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand All @@ -2433,6 +2457,7 @@ def visualize_distribution(self,
probabilities=probabilities,
min_probability=min_probability,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand Down Expand Up @@ -2492,6 +2517,7 @@ def visualize_hierarchy(self,
topics: List[int] = None,
top_n_topics: int = None,
custom_labels: bool = False,
title: str = "<b>Hierarchical Clustering</b>",
width: int = 1000,
height: int = 600,
hierarchical_topics: pd.DataFrame = None,
Expand All @@ -2514,6 +2540,7 @@ def visualize_hierarchy(self,
`topic_model.set_topic_labels`.
NOTE: Custom labels are only generated for the original
un-merged topics.
title: Title of the plot.
width: The width of the figure. Only works if orientation is set to 'left'
height: The height of the figure. Only works if orientation is set to 'bottom'
hierarchical_topics: A dataframe that contains a hierarchy of topics
Expand Down Expand Up @@ -2570,6 +2597,7 @@ def visualize_hierarchy(self,
topics=topics,
top_n_topics=top_n_topics,
custom_labels=custom_labels,
title=title,
width=width,
height=height,
hierarchical_topics=hierarchical_topics,
Expand All @@ -2583,6 +2611,7 @@ def visualize_heatmap(self,
top_n_topics: int = None,
n_clusters: int = None,
custom_labels: bool = False,
title: str = "<b>Similarity Matrix</b>",
width: int = 800,
height: int = 800) -> go.Figure:
""" Visualize a heatmap of the topic's similarity matrix
Expand All @@ -2597,6 +2626,7 @@ def visualize_heatmap(self,
matrix by those clusters.
custom_labels: Whether to use custom topic labels that were defined using
`topic_model.set_topic_labels`.
title: Title of the plot.
width: The width of the figure.
height: The height of the figure.
Expand Down Expand Up @@ -2625,6 +2655,7 @@ def visualize_heatmap(self,
top_n_topics=top_n_topics,
n_clusters=n_clusters,
custom_labels=custom_labels,
title=title,
width=width,
height=height)

Expand Down Expand Up @@ -3333,9 +3364,9 @@ def _map_probabilities(self,

# Map array of probabilities (probability for assigned topic per document)
if probabilities is not None:
if len(probabilities.shape) == 2 and self.get_topic(-1):
if len(probabilities.shape) == 2:
mapped_probabilities = np.zeros((probabilities.shape[0],
len(set(mappings.values())) - 1))
len(set(mappings.values())) - self._outliers))
for from_topic, to_topic in mappings.items():
if to_topic != -1 and from_topic != -1:
mapped_probabilities[:, to_topic] += probabilities[:, from_topic]
Expand Down
62 changes: 39 additions & 23 deletions bertopic/representation/_cohere.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import numpy as np
import time
import pandas as pd
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union
from sklearn.metrics.pairwise import cosine_similarity
from typing import Mapping, List, Tuple
from bertopic.representation._base import BaseRepresentation


Expand All @@ -28,7 +27,11 @@
Keywords: deliver weeks product shipping long delivery received arrived arrive week
Topic name: Shipping and delivery issues
---
"""
Topic:
Sample texts from this topic:
[DOCUMENTS]
Keywords: [KEYWORDS]
Topic name:"""


class Cohere(BaseRepresentation):
Expand All @@ -46,6 +49,8 @@ class Cohere(BaseRepresentation):
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
Usage:
Expand Down Expand Up @@ -79,11 +84,13 @@ def __init__(self,
client,
model: str = "xlarge",
prompt: str = None,
delay_in_seconds: float = None,
):
self.client = client
self.model = model
self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.delay_in_seconds = delay_in_seconds

def extract_topics(self,
topic_model,
Expand All @@ -109,6 +116,11 @@ def extract_topics(self,
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)

# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)

request = self.client.generate(model=self.model,
prompt=prompt,
max_tokens=50,
Expand All @@ -118,26 +130,30 @@ def extract_topics(self,
updated_topics[topic] = [(label, 1)] + [("", 0) for _ in range(9)]

return updated_topics

def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]

# Use a prompt that leverages either keywords or documents in
# a custom location
prompt = ""
if "[KEYWORDS]" in self.prompt:
prompt += self.prompt.replace("[KEYWORDS]", keywords)
if "[DOCUMENTS]" in self.prompt:
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
prompt += self.prompt.replace("[DOCUMENTS]", to_replace)

# Use the default prompt
if "[KEYWORDS]" and "[DOCUMENTS]" not in self.prompt:
prompt = self.prompt + 'Topic:\nSample texts from this topic:\n'
for doc in docs:
prompt += f"- {doc[:255]}\n"
prompt += "Keywords: " + " ".join(keywords)
prompt += "\nTopic name:"
# Use the Default Chat Prompt
if self.prompt == self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

return prompt

@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt
Loading

0 comments on commit d665d3f

Please sign in to comment.