Sklearn Pipeline Embedder #791

Merged · 7 commits · Nov 1, 2022
68 changes: 68 additions & 0 deletions bertopic/backend/_sklearn.py
@@ -0,0 +1,68 @@
from bertopic.backend import BaseEmbedder
from sklearn.utils.validation import check_is_fitted, NotFittedError


class SklearnEmbedder(BaseEmbedder):
    """ Scikit-Learn based embedding model

    This component allows the usage of scikit-learn pipelines for generating document and
    word embeddings.

    Arguments:
        pipe: A scikit-learn pipeline that can `.transform()` text.

    Examples:

    Scikit-Learn is very flexible and allows for many representations.
    A relatively simple pipeline is shown below.

    ```python
    from sklearn.pipeline import make_pipeline
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer

    from bertopic import BERTopic
    from bertopic.backend import SklearnEmbedder

    pipe = make_pipeline(
        TfidfVectorizer(),
        TruncatedSVD(100)
    )

    sklearn_embedder = SklearnEmbedder(pipe)
    topic_model = BERTopic(embedding_model=sklearn_embedder)
    ```

    This pipeline first constructs a sparse representation based on TF-IDF and then
    makes it dense by applying SVD. Alternatively, you might also construct something
    more elaborate. As long as you construct a scikit-learn compatible pipeline, you
    should be able to pass it to BERTopic.

    !!! Warning
        One caveat to be aware of is that scikit-learn's base `Pipeline` class does not
        support the `.partial_fit()` API. If you have a pipeline that theoretically should
        be able to support online learning, you might want to explore
        the [scikit-partial](https://github.com/koaning/scikit-partial) project.
    """
    def __init__(self, pipe):
        super().__init__()
        self.pipe = pipe

    def embed(self, documents, verbose=False):
        """ Embed a list of `n` documents/words into an `(n, m)`
        matrix of embeddings

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: No-op variable kept to keep the API consistent. If you want
                feedback on training times, use the scikit-learn API directly.

        Returns:
            Document/word embeddings with shape `(n, m)`: `n` documents/words
            that each have an embedding of size `m`
        """
        # Reuse the pipeline if it was already fitted; otherwise fit it first.
        try:
            check_is_fitted(self.pipe)
            embeddings = self.pipe.transform(documents)
        except NotFittedError:
            embeddings = self.pipe.fit_transform(documents)

        return embeddings
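
The `embed` method fits the pipeline on the first call and only transforms on later calls. A minimal sketch of that behavior (the 20-newsgroups fetch is just an illustrative data source):

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

from bertopic.backend import SklearnEmbedder

docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes")).data

embedder = SklearnEmbedder(make_pipeline(TfidfVectorizer(), TruncatedSVD(100)))

# First call: the pipeline is not fitted yet, so embed() falls back to fit_transform()
embeddings = embedder.embed(docs)
print(embeddings.shape)  # (len(docs), 100)

# Later calls: check_is_fitted() passes, so only transform() runs
new_embeddings = embedder.embed(["a new document to embed"])
print(new_embeddings.shape)  # (1, 100)
```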
5 changes: 5 additions & 0 deletions bertopic/backend/_utils.py
@@ -1,7 +1,9 @@
from ._base import BaseEmbedder
from ._sentencetransformers import SentenceTransformerBackend
from ._hftransformers import HFTransformerBackend
from ._sklearn import SklearnEmbedder
from transformers.pipelines import Pipeline
from sklearn.pipeline import Pipeline as ScikitPipeline

languages = ['afrikaans', 'albanian', 'amharic', 'arabic', 'armenian', 'assamese',
'azerbaijani', 'basque', 'belarusian', 'bengali', 'bengali romanize',
@@ -33,6 +35,9 @@ def select_backend(embedding_model,
    if isinstance(embedding_model, BaseEmbedder):
        return embedding_model

    if isinstance(embedding_model, ScikitPipeline):
        return SklearnEmbedder(embedding_model)

    # Flair word embeddings
    if "flair" in str(type(embedding_model)):
        from bertopic.backend._flair import FlairBackend
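
With this dispatch in place, a plain scikit-learn `Pipeline` handed to BERTopic is wrapped automatically. A small sketch of the intended behavior, assuming `select_backend` can be called with just the model as in the hunk above:

```python
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

from bertopic.backend import SklearnEmbedder
from bertopic.backend._utils import select_backend

pipe = make_pipeline(TfidfVectorizer(), TruncatedSVD(100))

# The isinstance(embedding_model, ScikitPipeline) branch wraps the raw pipeline
backend = select_backend(pipe)
assert isinstance(backend, SklearnEmbedder)
```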
41 changes: 40 additions & 1 deletion docs/getting_started/embeddings/embeddings.md
@@ -230,4 +230,43 @@ topics, probs = topic_model.fit_transform(docs, embeddings)

Here, you will probably notice that creating the embeddings is quite fast whereas `fit_transform` is quite slow.
This is to be expected as reducing the dimensionality of a large sparse matrix takes some time. The inverse of using
transformer embeddings is true: creating the embeddings is slow whereas `fit_transform` is quite fast.

#### **Scikit-Learn Embeddings**
Scikit-Learn is a framework for more than just machine learning.
It offers many preprocessing tools, some of which can be used to create representations
for text. Many of these tools are relatively lightweight and don't require a GPU.
While these representations may be less expressive than those of many BERT models, the
fact that they run much faster can make them a relevant candidate to consider.

If you have a scikit-learn compatible pipeline that you'd like to use to embed
text, you can pass it directly to BERTopic.

```python
from bertopic import BERTopic
from sklearn.pipeline import make_pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import TfidfVectorizer

pipe = make_pipeline(
    TfidfVectorizer(),
    TruncatedSVD(100)
)

topic_model = BERTopic(embedding_model=pipe)
```

Internally, this uses the `SklearnEmbedder` backend to ensure the scikit-learn
pipeline is compatible. You can also wrap the pipeline in it explicitly:

```python
from bertopic.backend import SklearnEmbedder

sklearn_embedder = SklearnEmbedder(pipe)
topic_model = BERTopic(embedding_model=sklearn_embedder)
```

!!! Warning
    One caveat to be aware of is that scikit-learn's base `Pipeline` class does not
    support the `.partial_fit()` API. If you have a pipeline that theoretically should
    be able to support online learning, you might want to explore
    the [scikit-partial](https://github.com/koaning/scikit-partial) project.
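
To make the caveat concrete, here is a quick sketch: even when every step could learn incrementally, the stock `Pipeline` wrapper never exposes `partial_fit` itself:

```python
from sklearn.decomposition import IncrementalPCA
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.pipeline import make_pipeline

# HashingVectorizer is stateless and IncrementalPCA supports partial_fit(),
# but the Pipeline class does not forward partial_fit() to its steps
pipe = make_pipeline(HashingVectorizer(), IncrementalPCA(n_components=100))
print(hasattr(pipe, "partial_fit"))  # False
```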
6 changes: 3 additions & 3 deletions setup.py
@@ -6,9 +6,9 @@
]

docs_packages = [
-    "mkdocs>=1.1",
-    "mkdocs-material>=4.6.3",
-    "mkdocstrings>=0.8.0",
+    "mkdocs==1.1",
+    "mkdocs-material==4.6.3",
+    "mkdocstrings==0.8.0",
]

base_packages = [