Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TextClassification, UMAP, DBSCAN and TextClustering tasks #948

Merged
merged 27 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
1cfb3e8
Redirect import of task
plaguss Sep 5, 2024
0edea70
Add icon for text classification
plaguss Sep 5, 2024
edc015a
Add text classification task
plaguss Sep 5, 2024
7d53ee4
Add tests for text classification
plaguss Sep 5, 2024
526f789
Continue with this problematic thing until we merge it in one of the PRs
plaguss Sep 6, 2024
ce9f5df
Port itertools.batched function for python<3.12
plaguss Sep 6, 2024
28280d1
Make more generic the template for text classification
plaguss Sep 6, 2024
d93a338
Add tests for the extra flexibility in the template
plaguss Sep 6, 2024
4828e9f
Merge and solve conflicts
plaguss Sep 6, 2024
85c700e
Fix condition to determine the backend for the structured output
plaguss Sep 7, 2024
52df044
Simplify condition for json schema in structured output
plaguss Sep 7, 2024
2684b6a
Add folder for clustering related steps
plaguss Sep 9, 2024
7fd75c3
Fix default structured output for inference endpoints
plaguss Sep 9, 2024
b562b42
Added examples to the docstrings
plaguss Sep 9, 2024
399a083
Add icon for clustering steps/tasks
plaguss Sep 9, 2024
997eb21
Add umap step
plaguss Sep 9, 2024
f5a7ad7
Add dbscan step
plaguss Sep 9, 2024
c77b6c9
Redirect import of steps
plaguss Sep 9, 2024
e499556
Add text clustering task
plaguss Sep 9, 2024
a4c0332
Merge branch 'develop' of https://github.com/argilla-io/distilabel in…
plaguss Sep 11, 2024
702970e
Set default value for repo_id to avoid potential errors when loading …
plaguss Sep 11, 2024
1c4597a
Change example dataset in docstrings as that has more information
plaguss Sep 11, 2024
c070130
Add unit tests for clustering steps
plaguss Sep 11, 2024
4a55288
Remove extra unnecessary log message
plaguss Sep 11, 2024
c39e13a
Add tests for text clustering process
plaguss Sep 11, 2024
555a8dc
Update pyproject with dependencies of text_clustering
plaguss Sep 11, 2024
c0cbe15
Set internal variables to None on unload to clean up
plaguss Sep 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ vllm = [
sentence-transformers = ["sentence-transformers >= 3.0.0"]
faiss-cpu = ["faiss-cpu >= 1.8.0"]
faiss-gpu = ["faiss-gpu >= 1.7.2"]
text-clustering = [
"umap-learn >= 0.5.6",
"scikit-learn >= 1.4.1",
"matplotlib >= 3.8.3" # For the figure (even though it's optional)
]

# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]
Expand Down
2 changes: 1 addition & 1 deletion scripts/install_dependencies.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])")

python -m pip install uv

uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash]"
uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]"

if [ "${python_version}" != "(3, 12)" ]; then
uv pip install --system -e .[ray]
Expand Down
6 changes: 6 additions & 0 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
StepInput,
StepResources,
)
from distilabel.steps.clustering.dbscan import DBSCAN
from distilabel.steps.clustering.text_clustering import TextClustering
from distilabel.steps.clustering.umap import UMAP
from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import CombineColumns, GroupColumns
Expand Down Expand Up @@ -67,6 +70,9 @@
"GroupColumns",
"KeepColumns",
"MergeColumns",
"DBSCAN",
"UMAP",
"TextClustering",
"step",
"DeitaFiltering",
"EmbeddingGeneration",
Expand Down
14 changes: 14 additions & 0 deletions src/distilabel/steps/clustering/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

177 changes: 177 additions & 0 deletions src/distilabel/steps/clustering/dbscan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
from typing import TYPE_CHECKING, Any, List, Optional

import numpy as np
from pydantic import Field, PrivateAttr

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps import (
GlobalStep,
StepInput,
)

if TYPE_CHECKING:
from sklearn.cluster import DBSCAN as _DBSCAN

from distilabel.steps.typing import StepOutput


class DBSCAN(GlobalStep):
    r"""DBSCAN (Density-Based Spatial Clustering of Applications with Noise) finds core
    samples in regions of high density and expands clusters from them. This algorithm
    is good for data which contains clusters of similar density.

    This is a `GlobalStep` that clusters the embeddings using the DBSCAN algorithm
    from `sklearn`. Visit `TextClustering` step for an example of use.
    The trained model is saved as an artifact when creating a distiset
    and pushing it to the Hugging Face Hub.

    Input columns:
        - projection (`List[float]`): Vector representation of the text to cluster,
            normally the output from the `UMAP` step.

    Output columns:
        - cluster_label (`int`): Integer representing the label of a given cluster. -1
            means it wasn't clustered.

    Categories:
        - clustering
        - text-classification

    References:
        - [`DBSCAN demo of sklearn`](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#demo-of-dbscan-clustering-algorithm)
        - [`sklearn dbscan`](https://scikit-learn.org/stable/modules/clustering.html#dbscan)

    Attributes:
        - eps: The maximum distance between two samples for one to be considered as in the
            neighborhood of the other. This is not a maximum bound on the distances of
            points within a cluster. This is the most important DBSCAN parameter to
            choose appropriately for your data set and distance function.
        - min_samples: The number of samples (or total weight) in a neighborhood for a point
            to be considered as a core point. This includes the point itself. If `min_samples`
            is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
            to a lower value, the found clusters will be more sparse.
        - metric: The metric to use when calculating distance between instances in a feature
            array. If metric is a string or callable, it must be one of the options allowed
            by `sklearn.metrics.pairwise_distances` for its metric parameter.
        - n_jobs: The number of parallel jobs to run.

    Runtime parameters:
        - `eps`: The maximum distance between two samples for one to be considered as in the
            neighborhood of the other. This is not a maximum bound on the distances of
            points within a cluster. This is the most important DBSCAN parameter to
            choose appropriately for your data set and distance function.
        - `min_samples`: The number of samples (or total weight) in a neighborhood for a point
            to be considered as a core point. This includes the point itself. If `min_samples`
            is set to a higher value, DBSCAN will find denser clusters, whereas if it is set
            to a lower value, the found clusters will be more sparse.
        - `metric`: The metric to use when calculating distance between instances in a feature
            array. If metric is a string or callable, it must be one of the options allowed
            by `sklearn.metrics.pairwise_distances` for its metric parameter.
        - `n_jobs`: The number of parallel jobs to run.
    """

    eps: Optional[RuntimeParameter[float]] = Field(
        default=0.3,
        description=(
            "The maximum distance between two samples for one to be considered "
            "as in the neighborhood of the other. This is not a maximum bound "
            "on the distances of points within a cluster. This is the most "
            "important DBSCAN parameter to choose appropriately for your data set "
            "and distance function."
        ),
    )
    min_samples: Optional[RuntimeParameter[int]] = Field(
        default=30,
        description=(
            "The number of samples (or total weight) in a neighborhood for a point to "
            "be considered as a core point. This includes the point itself. If "
            "`min_samples` is set to a higher value, DBSCAN will find denser clusters, "
            "whereas if it is set to a lower value, the found clusters will be more "
            "sparse."
        ),
    )
    metric: Optional[RuntimeParameter[str]] = Field(
        default="euclidean",
        description=(
            "The metric to use when calculating distance between instances in a "
            "feature array. If metric is a string or callable, it must be one of "
            "the options allowed by `sklearn.metrics.pairwise_distances` for "
            "its metric parameter."
        ),
    )
    n_jobs: Optional[RuntimeParameter[int]] = Field(
        default=8, description="The number of parallel jobs to run."
    )

    # Fitted sklearn estimator, created in `load()` and released in `unload()`.
    _clusterer: Optional["_DBSCAN"] = PrivateAttr(None)

    def load(self) -> None:
        """Checks `scikit-learn` is installed and instantiates the DBSCAN clusterer.

        Raises:
            ImportError: If `scikit-learn` is not installed.
        """
        super().load()
        if importlib.util.find_spec("sklearn") is None:
            raise ImportError(
                "`sklearn` package is not installed. Please install it using `pip install scikit-learn`."
            )
        from sklearn.cluster import DBSCAN as _DBSCAN

        self._clusterer = _DBSCAN(
            eps=self.eps,
            min_samples=self.min_samples,
            metric=self.metric,
            n_jobs=self.n_jobs,
        )

    def unload(self) -> None:
        # Drop the fitted model so its memory can be reclaimed once the step finishes.
        self._clusterer = None

    @property
    def inputs(self) -> List[str]:
        return ["projection"]

    @property
    def outputs(self) -> List[str]:
        return ["cluster_label"]

    def _save_model(self, model: Any) -> None:
        """Saves the fitted clusterer with `joblib` as an artifact of the distiset."""
        import joblib

        def save_model(path):
            with open(str(path / "DBSCAN.joblib"), "wb") as f:
                joblib.dump(model, f)

        self.save_artifact(
            name="DBSCAN_model",
            # `save_model` already has the `(path) -> None` signature expected by
            # `save_artifact`, no lambda wrapper needed.
            write_function=save_model,
            metadata={
                "eps": self.eps,
                "min_samples": self.min_samples,
                "metric": self.metric,
            },
        )

    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        """Fits DBSCAN on the `projection` vectors and labels every input row.

        Args:
            inputs: Rows containing a `projection` column with the vectors to cluster.

        Yields:
            The same rows with an added `cluster_label` column (-1 means noise,
            i.e. the point wasn't assigned to any cluster).
        """
        projections = np.array([row["projection"] for row in inputs])

        self._logger.info("🏋️‍♀️ Start training DBSCAN...")
        fitted_clusterer = self._clusterer.fit(projections)
        cluster_labels = fitted_clusterer.labels_
        # Cast to a plain Python `int` so the stored value matches the documented
        # `int` output column (numpy integers can trip up some serializers).
        for row, cluster_label in zip(inputs, cluster_labels):
            row["cluster_label"] = int(cluster_label)
        self._logger.info(f"DBSCAN labels assigned: {len(set(cluster_labels))}")
        self._save_model(fitted_clusterer)
        yield inputs
Loading
Loading