diff --git a/pyproject.toml b/pyproject.toml index d5f4e795ed..01f5e903b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,6 +95,11 @@ vllm = [ sentence-transformers = ["sentence-transformers >= 3.0.0"] faiss-cpu = ["faiss-cpu >= 1.8.0"] faiss-gpu = ["faiss-gpu >= 1.7.2"] +text-clustering = [ + "umap-learn >= 0.5.6", + "scikit-learn >= 1.4.1", + "matplotlib >= 3.8.3" # For the figure (even though it's optional) +] # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] diff --git a/scripts/install_dependencies.sh b/scripts/install_dependencies.sh index 3f7669deec..767f6e6dd0 100755 --- a/scripts/install_dependencies.sh +++ b/scripts/install_dependencies.sh @@ -6,7 +6,7 @@ python_version=$(python -c "import sys; print(sys.version_info[:2])") python -m pip install uv -uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash]" +uv pip install --system -e ".[anthropic,argilla,cohere,groq,hf-inference-endpoints,hf-transformers,litellm,llama-cpp,ollama,openai,outlines,vertexai,mistralai,instructor,sentence-transformers,faiss-cpu,minhash,text-clustering]" if [ "${python_version}" != "(3, 12)" ]; then uv pip install --system -e .[ray] diff --git a/src/distilabel/steps/__init__.py b/src/distilabel/steps/__init__.py index 0d0f33d9a6..79c10a268e 100644 --- a/src/distilabel/steps/__init__.py +++ b/src/distilabel/steps/__init__.py @@ -21,6 +21,9 @@ StepInput, StepResources, ) +from distilabel.steps.clustering.dbscan import DBSCAN +from distilabel.steps.clustering.text_clustering import TextClustering +from distilabel.steps.clustering.umap import UMAP from distilabel.steps.columns.combine import CombineOutputs from distilabel.steps.columns.expand import ExpandColumns from distilabel.steps.columns.group import CombineColumns, GroupColumns @@ -67,6 +70,9 @@ "GroupColumns", "KeepColumns", "MergeColumns", + "DBSCAN", + "UMAP", + "TextClustering", "step", "DeitaFiltering", "EmbeddingGeneration", diff --git a/src/distilabel/steps/clustering/__init__.py b/src/distilabel/steps/clustering/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/src/distilabel/steps/clustering/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/src/distilabel/steps/clustering/dbscan.py b/src/distilabel/steps/clustering/dbscan.py new file mode 100644 index 0000000000..03ac5dcb3e --- /dev/null +++ b/src/distilabel/steps/clustering/dbscan.py @@ -0,0 +1,177 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from typing import TYPE_CHECKING, Any, List, Optional + +import numpy as np +from pydantic import Field, PrivateAttr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps import ( + GlobalStep, + StepInput, +) + +if TYPE_CHECKING: + from sklearn.cluster import DBSCAN as _DBSCAN + + from distilabel.steps.typing import StepOutput + + +class DBSCAN(GlobalStep): + r"""DBSCAN (Density-Based Spatial Clustering of Applications with Noise) finds core + samples in regions of high density and expands clusters from them. This algorithm + is good for data which contains clusters of similar density. + + This is a `GlobalStep` that clusters the embeddings using the DBSCAN algorithm + from `sklearn`. Visit `TextClustering` step for an example of use. + The trained model is saved as an artifact when creating a distiset + and pushing it to the Hugging Face Hub. + + Input columns: + - projection (`List[float]`): Vector representation of the text to cluster, + normally the output from the `UMAP` step. + + Output columns: + - cluster_label (`int`): Integer representing the label of a given cluster. -1 + means it wasn't clustered. + + Categories: + - clustering + - text-classification + + References: + - [`DBSCAN demo of sklearn`](https://scikit-learn.org/stable/auto_examples/cluster/plot_dbscan.html#demo-of-dbscan-clustering-algorithm) + - [`sklearn dbscan`](https://scikit-learn.org/stable/modules/clustering.html#dbscan) + + Attributes: + - eps: The maximum distance between two samples for one to be considered as in the + neighborhood of the other. This is not a maximum bound on the distances of + points within a cluster. This is the most important DBSCAN parameter to + choose appropriately for your data set and distance function. + - min_samples: The number of samples (or total weight) in a neighborhood for a point + to be considered as a core point. This includes the point itself. If `min_samples` + is set to a higher value, DBSCAN will find denser clusters, whereas if it is set + to a lower value, the found clusters will be more sparse. + - metric: The metric to use when calculating distance between instances in a feature + array. If metric is a string or callable, it must be one of the options allowed + by `sklearn.metrics.pairwise_distances` for its metric parameter. + - n_jobs: The number of parallel jobs to run. + + Runtime parameters: + - `eps`: The maximum distance between two samples for one to be considered as in the + neighborhood of the other. This is not a maximum bound on the distances of + points within a cluster. This is the most important DBSCAN parameter to + choose appropriately for your data set and distance function. + - `min_samples`: The number of samples (or total weight) in a neighborhood for a point + to be considered as a core point. This includes the point itself. If `min_samples` + is set to a higher value, DBSCAN will find denser clusters, whereas if it is set + to a lower value, the found clusters will be more sparse. + - `metric`: The metric to use when calculating distance between instances in a feature + array. If metric is a string or callable, it must be one of the options allowed + by `sklearn.metrics.pairwise_distances` for its metric parameter. + - `n_jobs`: The number of parallel jobs to run. + """ + + eps: Optional[RuntimeParameter[float]] = Field( + default=0.3, + description=( + "The maximum distance between two samples for one to be considered " + "as in the neighborhood of the other. This is not a maximum bound " + "on the distances of points within a cluster. This is the most " + "important DBSCAN parameter to choose appropriately for your data set " + "and distance function." + ), + ) + min_samples: Optional[RuntimeParameter[int]] = Field( + default=30, + description=( + "The number of samples (or total weight) in a neighborhood for a point to " + "be considered as a core point. This includes the point itself. If " + "`min_samples` is set to a higher value, DBSCAN will find denser clusters, " + "whereas if it is set to a lower value, the found clusters will be more " + "sparse." + ), + ) + metric: Optional[RuntimeParameter[str]] = Field( + default="euclidean", + description=( + "The metric to use when calculating distance between instances in a " + "feature array. If metric is a string or callable, it must be one of " + "the options allowed by `sklearn.metrics.pairwise_distances` for " + "its metric parameter." + ), + ) + n_jobs: Optional[RuntimeParameter[int]] = Field( + default=8, description="The number of parallel jobs to run." + ) + + _clusterer: Optional["_DBSCAN"] = PrivateAttr(None) + + def load(self) -> None: + super().load() + if importlib.util.find_spec("sklearn") is None: + raise ImportError( + "`sklearn` package is not installed. Please install it using `pip install scikit-learn`." + ) + from sklearn.cluster import DBSCAN as _DBSCAN + + self._clusterer = _DBSCAN( + eps=self.eps, + min_samples=self.min_samples, + metric=self.metric, + n_jobs=self.n_jobs, + ) + + def unload(self) -> None: + self._clusterer = None + + @property + def inputs(self) -> List[str]: + return ["projection"] + + @property + def outputs(self) -> List[str]: + return ["cluster_label"] + + def _save_model(self, model: Any) -> None: + import joblib + + def save_model(path): + with open(str(path / "DBSCAN.joblib"), "wb") as f: + joblib.dump(model, f) + + self.save_artifact( + name="DBSCAN_model", + write_function=lambda path: save_model(path), + metadata={ + "eps": self.eps, + "min_samples": self.min_samples, + "metric": self.metric, + }, + ) + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + projections = np.array([input["projection"] for input in inputs]) + + self._logger.info("🏋️‍♀️ Start training DBSCAN...") + fitted_clusterer = self._clusterer.fit(projections) + cluster_labels = fitted_clusterer.labels_ + # Sets the cluster labels for each input, -1 means it wasn't clustered + for input, cluster_label in zip(inputs, cluster_labels): + input["cluster_label"] = cluster_label + self._logger.info(f"DBSCAN labels assigned: {len(set(cluster_labels))}") + self._save_model(fitted_clusterer) + yield inputs diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py new file mode 100644 index 0000000000..4bf583c167 --- /dev/null +++ b/src/distilabel/steps/clustering/text_clustering.py @@ -0,0 +1,327 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +import json +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +from pydantic import Field + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps import StepInput +from distilabel.steps.tasks import TextClassification +from distilabel.steps.tasks.base import GlobalTask +from distilabel.utils.itertools import batched + +if TYPE_CHECKING: + from distilabel.steps.typing import StepOutput + + +class TextClustering(TextClassification, GlobalTask): + """Task that clusters a set of texts and generates summary labels for each cluster. + + This is a `GlobalTask` that inherits from `TextClassification`, this means that all + the attributes from that class are available here. Also, in this case we deal + with all the inputs at once, instead of using batches. The `input_batch_size` is + used here to send the examples to the LLM in batches (a subtle difference with the + more common `Task` definitions). + The task looks in each cluster for a given number of representative examples (the number + is set by the `samples_per_cluster` attribute), and sends them to the LLM to get a label/s + that represent the cluster. The labels are then assigned to each text in the cluster. + The clusters and projections used in the step, are assumed to be obtained from the `UMAP` + + `DBSCAN` steps, but could be generated for similar steps, as long as they represent the + same concepts. + This step runs a pipeline like the one in this repository: + https://github.com/huggingface/text-clustering + + Input columns: + - text (`str`): The reference text we want to obtain labels for. + - projection (`List[float]`): Vector representation of the text to cluster, + normally the output from the `UMAP` step. + - cluster_label (`int`): Integer representing the label of a given cluster. -1 + means it wasn't clustered. + + Output columns: + - summary_label (`str`): The label or list of labels for the text. + - model_name (`str`): The name of the model used to generate the label/s. + + Categories: + - clustering + - text-classification + + References: + - [`text-clustering repository`](https://github.com/huggingface/text-clustering) + + Attributes: + - savefig: Whether to generate and save a figure with the clustering of the texts. + - samples_per_cluster: The number of examples to use in the LLM as a sample of the cluster. + + Examples: + Generate labels for a set of texts using clustering: + + ```python + from distilabel.llms import InferenceEndpointsLLM + from distilabel.steps import UMAP, DBSCAN, TextClustering + from distilabel.pipeline import Pipeline + + ds_name = "argilla-warehouse/personahub-fineweb-edu-4-clustering-100k" + + with Pipeline(name="Text clustering dataset") as pipeline: + batch_size = 500 + + ds = load_dataset(ds_name, split="train").select(range(10000)) + loader = make_generator_step(ds, batch_size=batch_size, repo_id=ds_name) + + umap = UMAP(n_components=2, metric="cosine") + dbscan = DBSCAN(eps=0.3, min_samples=30) + + text_clustering = TextClustering( + llm=InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + ), + n=3, # 3 labels per example + query_title="Examples of Personas", + samples_per_cluster=10, + context=( + "Describe the main themes, topics, or categories that could describe the " + "following types of personas. All the examples of personas must share " + "the same set of labels." + ), + default_label="None", + savefig=True, + input_batch_size=8, + input_mappings={"text": "persona"}, + use_default_structured_output=True, + ) + + loader >> umap >> dbscan >> text_clustering + ``` + """ + + savefig: Optional[RuntimeParameter[bool]] = Field( + default=True, + description="Whether to generate and save a figure with the clustering of the texts.", + ) + samples_per_cluster: int = Field( + default=10, + description="The number of examples to use in the LLM as a sample of the cluster.", + ) + + @property + def inputs(self) -> List[str]: + """The input for the task are the same as those for `TextClassification` plus + the `projection` and `cluster_label` columns (which can be obtained from + UMAP + DBSCAN steps). + """ + return super().inputs + ["projection", "cluster_label"] + + @property + def outputs(self) -> List[str]: + """The output for the task is the `summary_label` and the `model_name`.""" + return ["summary_label", "model_name"] + + def load(self) -> None: + super().load() + if self.savefig and (importlib.util.find_spec("matplotlib") is None): + raise ImportError( + "`matplotlib` package is not installed. Please install it using `pip install matplotlib`." + ) + + def _save_figure( + self, + data: pd.DataFrame, + cluster_centers: Dict[str, Tuple[float, float]], + cluster_summaries: Dict[int, str], + ) -> None: + """Saves the figure starting from the dataframe, using matplotlib. + + Args: + data: pd.DataFrame with the columns 'X', 'Y' and 'labels' representing + the projections and the label of each text respectively. + cluster_centers: Dictionary mapping from each label the center of a cluster, + to help with the placement of the annotations. + cluster_summaries: The summaries of the clusters, obtained from the LLM. + """ + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(12, 8), dpi=300) + unique_labels = data["labels"].unique() + # Map of colors for each label (-1 is black) + colormap = dict( + zip(unique_labels, plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))) + ) + colormap[-1] = np.array([0, 0, 0, 0]) + data["color"] = data["labels"].map(colormap) + + data.plot( + kind="scatter", + x="X", + y="Y", + c="color", + s=0.75, + alpha=0.8, + linewidth=0.4, + ax=ax, + colorbar=False, + ) + + for label in cluster_summaries.keys(): + if label == -1: + continue + summary = str(cluster_summaries[label]) # These are obtained from the LLM + position = cluster_centers[label] + t = ax.text( + position[0], + position[1], + summary, + horizontalalignment="center", + verticalalignment="center", + fontsize=4, + ) + t.set_bbox( + { + "facecolor": "white", + "alpha": 0.9, + "linewidth": 0, + "boxstyle": "square,pad=0.1", + } + ) + + ax.set_axis_off() + # Save the plot as an artifact of the step + self.save_artifact( + name="Text clusters", + write_function=lambda path: fig.savefig(path / "figure_clustering.png"), + metadata={"type": "image", "library": "matplotlib"}, + ) + plt.close() + + def _create_figure( + self, + inputs: StepInput, + label2docs: Dict[int, List[str]], + cluster_summaries: Dict[int, str], + ) -> None: + """Creates a figure of the clustered texts and save it as an artifact. + + Args: + inputs: The inputs of the step, as we will extract information from them again. + label2docs: Map from each label to the list of documents (texts) that belong to that cluster. + cluster_summaries: The summaries of the clusters, obtained from the LLM. + labels: The labels of the clusters (integers representing each predicted class). + """ + self._logger.info("🖼️ Creating figure for the clusters...") + + labels = [] + projections = [] + id2cluster = {} + for i, input in enumerate(inputs): + label = input["cluster_label"] + id2cluster[i] = label + labels.append(label) + projections.append(input["projection"]) + + projections = np.array(projections) + + # Contains the placement of the cluster centers in the figure + cluster_centers: Dict[str, Tuple[float, float]] = {} + for label in label2docs.keys(): + x = np.mean([projections[doc, 0] for doc in label2docs[label]]) + y = np.mean([projections[doc, 1] for doc in label2docs[label]]) + cluster_centers[label] = (x, y) + + df = pd.DataFrame( + data={ + "X": projections[:, 0], + "Y": projections[:, 1], + "labels": labels, + } + ) + + self._save_figure( + df, cluster_centers=cluster_centers, cluster_summaries=cluster_summaries + ) + + def _prepare_input_texts( + self, + inputs: StepInput, + label2docs: Dict[int, List[int]], + unique_labels: List[int], + ) -> List[Dict[str, Union[str, int]]]: + """Prepares a batch of inputs to send to the LLM, with the examples of each cluster. + + Args: + inputs: Inputs from the step. + label2docs: Map from each label to the list of documents (texts) that + belong to that cluster. + unique_labels: The unique labels of the clusters. + + Returns: + The input texts to send to the LLM, with the examples of each cluster + prepared to be used in the prompt, and an additional key to store the + labels (that will be needed to find the data after the batches are + returned from the LLM). + """ + input_texts = [] + for label in range(unique_labels): # The label -1 is implicitly excluded + # Get the ids but remove possible duplicates, which could happen with bigger probability + # the bigger the number of examples requested, and the smaller the subset of examples + ids = set( + np.random.choice(label2docs[label], size=self.samples_per_cluster) + ) # Grab the number of examples + examples = [inputs[i]["text"] for i in ids] + input_text = { + "text": "\n\n".join( + [f"Example {i}:\n{t}" for i, t in enumerate(examples, start=1)] + ), + "__LABEL": label, + } + input_texts.append(input_text) + return input_texts + + def process(self, inputs: StepInput) -> "StepOutput": + labels = [input["cluster_label"] for input in inputs] + # -1 because -1 is the label for the unclassified + unique_labels = len(set(labels)) - 1 + # This will be the output of the LLM, the set of labels for each cluster + cluster_summaries: Dict[int, str] = {-1: self.default_label} + + # Map from label to list of documents, will use them to select examples from each cluster + label2docs = defaultdict(list) + for i, label in enumerate(labels): + label2docs[label].append(i) + + input_texts = self._prepare_input_texts(inputs, label2docs, unique_labels) + + # Send the texts in batches to the LLM, and get the labels for each cluster + for i, batched_inputs in enumerate(batched(input_texts, self.input_batch_size)): + self._logger.info(f"📦 Processing internal batch of inputs {i}...") + results = super().process(batched_inputs) + for result in next(results): # Extract the elements from the generator + cluster_summaries[result["__LABEL"]] = result["labels"] + + # Assign the labels to each text + for input in inputs: + input["summary_label"] = json.dumps( + cluster_summaries[input["cluster_label"]] + ) + + if self.savefig: + self._create_figure(inputs, label2docs, cluster_summaries) + + yield inputs diff --git a/src/distilabel/steps/clustering/umap.py b/src/distilabel/steps/clustering/umap.py new file mode 100644 index 0000000000..daeb37486d --- /dev/null +++ b/src/distilabel/steps/clustering/umap.py @@ -0,0 +1,164 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.util +from typing import TYPE_CHECKING, Any, List, Optional + +import numpy as np +from pydantic import Field, PrivateAttr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps import ( + GlobalStep, + StepInput, +) + +if TYPE_CHECKING: + from umap import UMAP as _UMAP + + from distilabel.steps.typing import StepOutput + + +class UMAP(GlobalStep): + r"""UMAP is a general purpose manifold learning and dimension reduction algorithm. + + This is a `GlobalStep` that reduces the dimensionality of the embeddings using. Visit + the `TextClustering` step for an example of use. The trained model is saved as an artifact + when creating a distiset and pushing it to the Hugging Face Hub. + + Input columns: + - embedding (`List[float]`): The original embeddings we want to reduce the dimension. + + Output columns: + - projection (`List[float]`): Embedding reduced to the number of components specified, + the size of the new embeddings will be determined by the `n_components`. + + Categories: + - clustering + - text-classification + + References: + - [`UMAP repository`](https://github.com/lmcinnes/umap/tree/master) + - [`UMAP documentation`](https://umap-learn.readthedocs.io/en/latest/) + + Attributes: + - n_components: The dimension of the space to embed into. This defaults to 2 to + provide easy visualization (that's probably what you want), but can + reasonably be set to any integer value in the range 2 to 100. + - metric: The metric to use to compute distances in high dimensional space. + Visit UMAP's documentation for more information. Defaults to `euclidean`. + - n_jobs: The number of parallel jobs to run. Defaults to `8`. + - random_state: The random state to use for the UMAP algorithm. + + Runtime parameters: + - `n_components`: The dimension of the space to embed into. This defaults to 2 to + provide easy visualization (that's probably what you want), but can + reasonably be set to any integer value in the range 2 to 100. + - `metric`: The metric to use to compute distances in high dimensional space. + Visit UMAP's documentation for more information. Defaults to `euclidean`. + - `n_jobs`: The number of parallel jobs to run. Defaults to `8`. + - `random_state`: The random state to use for the UMAP algorithm. + + Citations: + ``` + @misc{mcinnes2020umapuniformmanifoldapproximation, + title={UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction}, + author={Leland McInnes and John Healy and James Melville}, + year={2020}, + eprint={1802.03426}, + archivePrefix={arXiv}, + primaryClass={stat.ML}, + url={https://arxiv.org/abs/1802.03426}, + } + ``` + """ + + n_components: Optional[RuntimeParameter[int]] = Field( + default=2, + description=( + "The dimension of the space to embed into. This defaults to 2 to " + "provide easy visualization, but can reasonably be set to any " + "integer value in the range 2 to 100." + ), + ) + metric: Optional[RuntimeParameter[str]] = Field( + default="euclidean", + description=( + "The metric to use to compute distances in high dimensional space. " + "Visit UMAP's documentation for more information." + ), + ) + n_jobs: Optional[RuntimeParameter[int]] = Field( + default=8, description="The number of parallel jobs to run." + ) + random_state: Optional[RuntimeParameter[int]] = Field( + default=None, description="The random state to use for the UMAP algorithm." + ) + + _umap: Optional["_UMAP"] = PrivateAttr(None) + + def load(self) -> None: + super().load() + if importlib.util.find_spec("umap") is None: + raise ImportError( + "`umap` package is not installed. Please install it using `pip install umap-learn`." + ) + from umap import UMAP as _UMAP + + self._umap = _UMAP( + n_components=self.n_components, + metric=self.metric, + n_jobs=self.n_jobs, + random_state=self.random_state, + ) + + def unload(self) -> None: + self._umap = None + + @property + def inputs(self) -> List[str]: + return ["embedding"] + + @property + def outputs(self) -> List[str]: + return ["projection"] + + def _save_model(self, model: Any) -> None: + import joblib + + def save_model(path): + with open(str(path / "UMAP.joblib"), "wb") as f: + joblib.dump(model, f) + + self.save_artifact( + name="UMAP_model", + write_function=lambda path: save_model(path), + metadata={ + "n_components": self.n_components, + "metric": self.metric, + }, + ) + + def process(self, inputs: StepInput) -> "StepOutput": # type: ignore + # Shape of the embeddings is (n_samples, n_features) + embeddings = np.array([input["embedding"] for input in inputs]) + + self._logger.info("🏋️‍♀️ Start UMAP training...") + mapper = self._umap.fit(embeddings) + # Shape of the projection will be (n_samples, n_components) + for input, projection in zip(inputs, mapper.embedding_): + input["projection"] = projection + + self._save_model(mapper) + yield inputs diff --git a/src/distilabel/steps/generators/utils.py b/src/distilabel/steps/generators/utils.py index 27455119bd..49d27748b4 100644 --- a/src/distilabel/steps/generators/utils.py +++ b/src/distilabel/steps/generators/utils.py @@ -32,7 +32,7 @@ def make_generator_step( input_mappings: Optional[Dict[str, str]] = None, output_mappings: Optional[Dict[str, str]] = None, resources: StepResources = StepResources(), - repo_id: str = "placeholder", + repo_id: Optional[str] = "default_name", ) -> "GeneratorStep": """Helper method to create a `GeneratorStep` from a dataset, to simplify @@ -46,7 +46,6 @@ def make_generator_step( repo_id: The repository ID to use in the `LoadDataFromHub` step. This shouldn't be necessary, but in case of error, the dataset will try to be loaded using `load_dataset` internally. If that case happens, the `repo_id` will be used. - Defaults to `"placeholder"`. Raises: ValueError: If the format is different from the ones supported. diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py index 7bd96c3ce0..eb90c6dbad 100644 --- a/src/distilabel/steps/tasks/__init__.py +++ b/src/distilabel/steps/tasks/__init__.py @@ -43,6 +43,7 @@ from distilabel.steps.tasks.self_instruct import SelfInstruct from distilabel.steps.tasks.sentence_transformers import GenerateSentencePair from distilabel.steps.tasks.structured_generation import StructuredGeneration +from distilabel.steps.tasks.text_classification import TextClassification from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration from distilabel.steps.tasks.typing import ChatItem, ChatType from distilabel.steps.tasks.ultrafeedback import UltraFeedback @@ -75,6 +76,7 @@ "SelfInstruct", "GenerateSentencePair", "StructuredGeneration", + "TextClassification", "ChatGeneration", "TextGeneration", "ChatItem", diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index dcf704a807..a0afb74c32 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -235,10 +235,11 @@ def check_dependency(module_name: str) -> None: dependency = "outlines" structured_output = {"schema": schema} + if isinstance(self.llm, InferenceEndpointsLLM): + structured_output.update({"format": "json"}) # To determine instructor or outlines format - if not ( - isinstance(self.llm, AsyncLLM) - and not isinstance(self.llm, InferenceEndpointsLLM) + elif isinstance(self.llm, AsyncLLM) and not isinstance( + self.llm, InferenceEndpointsLLM ): dependency = "instructor" structured_output.update({"format": "json"}) diff --git a/src/distilabel/steps/tasks/text_classification.py b/src/distilabel/steps/tasks/text_classification.py new file mode 100644 index 0000000000..5d04b3b2db --- /dev/null +++ b/src/distilabel/steps/tasks/text_classification.py @@ -0,0 +1,378 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from textwrap import indent +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import orjson +from jinja2 import Template +from pydantic import BaseModel, Field, PositiveInt, PrivateAttr +from typing_extensions import override + +from distilabel.steps.tasks import Task + +if TYPE_CHECKING: + from distilabel.steps.tasks.typing import ChatType + + +TEXT_CLASSIFICATION_TEMPLATE: str = """\ +# Instruction +Please classify the {{ query_title.lower() }} by assigning the most appropriate labels. +Do not explain your reasoning or provide any additional commentary. +If the text is ambiguous or lacks sufficient information for classification, respond with "{{ default_label }}". +{{ labels_message }}{{ context}} +{{ available_labels }} +{{ examples }} + +## {{ query_title }} +``` +{{ text }} +``` + +## Output Format +Now, please give me the labels in JSON format, do not include any other text in your response: +``` +{ + "labels": {{ labels_format }} +} +``` +""".rstrip() + + +class TextClassification(Task): + r"""Classifies text into one or more categories or labels. + + This task can be used for text classification problems, where the goal is to assign + one or multiple labels to a given text. + It uses structured generation as per the reference paper by default, + it can help to generate more concise labels. See section 4.1 in the reference. + + Input columns: + - text (`str`): The reference text we want to obtain labels for. + + Output columns: + - labels (`Union[str, List[str]]`): The label or list of labels for the text. + - model_name (`str`): The name of the model used to generate the label/s. + + Categories: + - text-classification + + References: + - [`Let Me Speak Freely? A Study on the Impact of Format Restrictions on Performance of Large Language Models`](https://arxiv.org/abs/2408.02442) + + Attributes: + system_prompt: A prompt to display to the user before the task starts. Contains a default + message to make the model behave like a classifier specialist. + n: Number of labels to generate If only 1 is required, corresponds to a label + classification problem, if >1 it will intend return the "n" labels most representative + for the text. Defaults to 1. + context: Context to use when generating the labels. By default contains a generic message, + but can be used to customize the context for the task. + examples: List of examples to help the model understand the task, few shots. + available_labels: List of available labels to choose from when classifying the text, or + a dictionary with the labels and their descriptions. + default_label: Default label to use when the text is ambiguous or lacks sufficient information for + classification. Can be a list in case of multiple labels (n>1). + + Examples: + Assigning a sentiment to a text: + + ```python + from distilabel.steps.tasks import TextClassification + from distilabel.llms.huggingface import InferenceEndpointsLLM + + llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + ) + + text_classification = TextClassification( + llm=llm, + context="You are an AI system specialized in assigning sentiment to movies.", + available_labels=["positive", "negative"], + ) + + text_classification.load() + + result = next( + text_classification.process( + [{"text": "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."}] + ) + ) + # result + # [{'text': 'This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three.', + # 'labels': 'positive', + # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": "positive"\n}', + # 'raw_input_text_classification_0': [{'role': 'system', + # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'}, + # {'role': 'user', + # 'content': '# Instruction\nPlease classify the user query by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide the label that best describes the text.\nYou are an AI system specialized in assigning sentiment to movie the user queries.\n## Labeling the user input\nUse the available labels to classify the user query. Analyze the context of each label specifically:\navailable_labels = [\n "positive", # The text shows positive sentiment\n "negative", # The text shows negative sentiment\n]\n\n\n## User Query\n```\nThis was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three.\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": "label"\n}\n```'}]}, + # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] + ``` + + Assigning predefined labels with specified descriptions: + + ```python + from distilabel.steps.tasks import TextClassification + + text_classification = TextClassification( + llm=llm, + n=1, + context="Determine the intent of the text.", + available_labels={ + "complaint": "A statement expressing dissatisfaction or annoyance about a product, service, or experience. It's a negative expression of discontent, often with the intention of seeking a resolution or compensation.", + "inquiry": "A question or request for information about a product, service, or situation. It's a neutral or curious expression seeking clarification or details.", + "feedback": "A statement providing evaluation, opinion, or suggestion about a product, service, or experience. It can be positive, negative, or neutral, and is often intended to help improve or inform.", + "praise": "A statement expressing admiration, approval, or appreciation for a product, service, or experience. It's a positive expression of satisfaction or delight, often with the intention of encouraging or recommending." + }, + query_title="Customer Query", + ) + + text_classification.load() + + result = next( + text_classification.process( + [{"text": "Can you tell me more about your return policy?"}] + ) + ) + # result + # [{'text': 'Can you tell me more about your return policy?', + # 'labels': 'inquiry', + # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": "inquiry"\n}', + # 'raw_input_text_classification_0': [{'role': 'system', + # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'}, + # {'role': 'user', + # 'content': '# Instruction\nPlease classify the customer query by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide the label that best describes the text.\nDetermine the intent of the text.\n## Labeling the user input\nUse the available labels to classify the user query. Analyze the context of each label specifically:\navailable_labels = [\n "complaint", # A statement expressing dissatisfaction or annoyance about a product, service, or experience. It\'s a negative expression of discontent, often with the intention of seeking a resolution or compensation.\n "inquiry", # A question or request for information about a product, service, or situation. It\'s a neutral or curious expression seeking clarification or details.\n "feedback", # A statement providing evaluation, opinion, or suggestion about a product, service, or experience. It can be positive, negative, or neutral, and is often intended to help improve or inform.\n "praise", # A statement expressing admiration, approval, or appreciation for a product, service, or experience. It\'s a positive expression of satisfaction or delight, often with the intention of encouraging or recommending.\n]\n\n\n## Customer Query\n```\nCan you tell me more about your return policy?\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": "label"\n}\n```'}]}, + # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] + ``` + + Free multi label classification without predefined labels: + + ```python + from distilabel.steps.tasks import TextClassification + + text_classification = TextClassification( + llm=llm, + n=3, + context=( + "Describe the main themes, topics, or categories that could describe the " + "following type of persona." + ), + query_title="Example of Persona", + ) + + text_classification.load() + + result = next( + text_classification.process( + [{"text": "A historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States."}] + ) + ) + # result + # [{'text': 'A historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States.', + # 'labels': ['Historical Researcher', + # 'Cultural Specialist', + # 'Ethnic Studies Expert'], + # 'distilabel_metadata': {'raw_output_text_classification_0': '{\n "labels": ["Historical Researcher", "Cultural Specialist", "Ethnic Studies Expert"]\n}', + # 'raw_input_text_classification_0': [{'role': 'system', + # 'content': 'You are an AI system specialized in generating labels to classify pieces of text. Your sole purpose is to analyze the given text and provide appropriate classification labels.'}, + # {'role': 'user', + # 'content': '# Instruction\nPlease classify the example of persona by assigning the most appropriate labels.\nDo not explain your reasoning or provide any additional commentary.\nIf the text is ambiguous or lacks sufficient information for classification, respond with "Unclassified".\nProvide a list of 3 labels that best describe the text.\nDescribe the main themes, topics, or categories that could describe the following type of persona.\nUse clear, widely understood terms for labels.Avoid overly specific or obscure labels unless the text demands it.\n\n\n## Example of Persona\n```\nA historian or curator of Mexican-American history and culture focused on the cultural, social, and historical impact of the Mexican presence in the United States.\n```\n\n## Output Format\nNow, please give me the labels in JSON format, do not include any other text in your response:\n```\n{\n "labels": ["label_0", "label_1", "label_2"]\n}\n```'}]}, + # 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] + ``` + """ + + system_prompt: Optional[str] = ( + "You are an AI system specialized in generating labels to classify pieces of text. " + "Your sole purpose is to analyze the given text and provide appropriate classification labels." + ) + n: PositiveInt = Field( + default=1, + description="Number of labels to generate. Defaults to 1.", + ) + context: Optional[str] = Field( + default="Generate concise, relevant labels that accurately represent the text's main themes, topics, or categories.", + description="Context to use when generating the labels.", + ) + examples: Optional[List[str]] = Field( + default=None, + description="List of examples to help the model understand the task, few shots.", + ) + available_labels: Optional[Union[List[str], Dict[str, str]]] = Field( + default=None, + description=( + "List of available labels to choose from when classifying the text, or " + "a dictionary with the labels and their descriptions." + ), + ) + default_label: Optional[Union[str, List[str]]] = Field( + default="Unclassified", + description=( + "Default label to use when the text is ambiguous or lacks sufficient information for " + "classification. Can be a list in case of multiple labels (n>1)." + ), + ) + query_title: str = Field( + default="User Query", + description="Title of the query used to show the example/s to classify.", + ) + use_default_structured_output: bool = True + + _template: Optional[Template] = PrivateAttr(default=None) + + def load(self) -> None: + super().load() + self._template = Template(TEXT_CLASSIFICATION_TEMPLATE) + self._labels_format: str = ( + '"label"' + if self.n == 1 + else "[" + ", ".join([f'"label_{i}"' for i in range(self.n)]) + "]" + ) + self._labels_message: str = ( + "Provide the label that best describes the text." + if self.n == 1 + else f"Provide a list of {self.n} labels that best describe the text." + ) + self._available_labels_message: str = self._get_available_labels_message() + self._examples: str = self._get_examples_message() + + def _get_available_labels_message(self) -> str: + """Prepares the message to display depending on the available labels (if any), + and whether the labels have a specific context. + """ + if self.available_labels is None: + return ( + "Use clear, widely understood terms for labels." + "Avoid overly specific or obscure labels unless the text demands it." + ) + + msg = ( + "## Labeling the user input\n" + "Use the available labels to classify the user query{label_context}:\n" + "available_labels = {available_labels}" + ) + if isinstance(self.available_labels, list): + specific_msg = ( + "[\n" + + indent( + "".join([f'"{label}",\n' for label in self.available_labels]), + prefix=" " * 4, + ) + + "]" + ) + return msg.format(label_context="", available_labels=specific_msg) + + elif isinstance(self.available_labels, dict): + specific_msg = "" + for label, description in self.available_labels.items(): + specific_msg += indent( + f'"{label}", # {description}' + "\n", prefix=" " * 4 + ) + + specific_msg = "[\n" + specific_msg + "]" + return msg.format( + label_context=". Analyze the context of each label specifically", + available_labels=specific_msg, + ) + + def _get_examples_message(self) -> str: + """Prepares the message to display depending on the examples provided.""" + if self.examples is None: + return "" + + examples_msg = "\n".join([f"- {ex}" for ex in self.examples]) + + return ( + "\n## Examples\n" + "Here are some examples to help you understand the task:\n" + f"{examples_msg}" + ) + + @property + def inputs(self) -> List[str]: + """The input for the task is the `instruction`.""" + return ["text"] + + @property + def outputs(self) -> List[str]: + """The output for the task is the `generation` and the `model_name`.""" + return ["labels", "model_name"] + + def format_input(self, input: Dict[str, Any]) -> "ChatType": + """The input is formatted as a `ChatType` assuming that the instruction + is the first interaction from the user within a conversation.""" + messages = [ + { + "role": "user", + "content": self._template.render( # type: ignore + context=f"\n{self.context}", + labels_message=self._labels_message, + available_labels=self._available_labels_message, + examples=self._examples, + default_label=self.default_label, + labels_format=self._labels_format, + query_title=self.query_title, + text=input["text"], + ), + }, + ] + if self.system_prompt: + messages.insert(0, {"role": "system", "content": self.system_prompt}) + return messages + + def format_output( + self, output: Union[str, None], input: Union[Dict[str, Any], None] = None + ) -> Dict[str, Any]: + """The output is formatted as a dictionary with the `generation`. The `model_name` + will be automatically included within the `process` method of `Task`.""" + return self._format_structured_output(output) + + @override + def get_structured_output(self) -> Dict[str, Any]: + """Creates the json schema to be passed to the LLM, to enforce generating + a dictionary with the output which can be directly parsed as a python dictionary. + + Returns: + JSON Schema of the response to enforce. + """ + if self.n > 1: + + class MultiLabelSchema(BaseModel): + labels: List[str] + + return MultiLabelSchema.model_json_schema() + + class SingleLabelSchema(BaseModel): + labels: str + + return SingleLabelSchema.model_json_schema() + + def _format_structured_output( + self, output: str + ) -> Dict[str, Union[str, List[str]]]: + """Parses the structured response, which should correspond to a dictionary + with the `labels`, and either a string or a list of strings with the labels. + + Args: + output: The output from the `LLM`. + + Returns: + Formatted output. + """ + try: + return orjson.loads(output) + except orjson.JSONDecodeError: + if self.n > 1: + return {"labels": [None for _ in range(self.n)]} + return {"labels": None} diff --git a/src/distilabel/utils/itertools.py b/src/distilabel/utils/itertools.py index 2555f3b262..34accced2b 100644 --- a/src/distilabel/utils/itertools.py +++ b/src/distilabel/utils/itertools.py @@ -12,11 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys from itertools import zip_longest from typing import Any, Iterable, Literal, Tuple, TypeVar T = TypeVar("T") +# https://docs.python.org/3/library/itertools.html#itertools.batched +if sys.version_info >= (3, 12): + from itertools import batched +else: + from itertools import islice + + def batched(iterable: Iterable[T], n: int) -> Iterable[T]: + # batched('ABCDEFG', 3) → ABC DEF G + if n < 1: + raise ValueError("n must be at least one") + iterator = iter(iterable) + while batch := tuple(islice(iterator, n)): + yield batch + # Copy pasted from https://docs.python.org/3/library/itertools.html#itertools-recipes # Just added the type hints and use `if`s instead of `match` diff --git a/src/distilabel/utils/mkdocs/components_gallery.py b/src/distilabel/utils/mkdocs/components_gallery.py index a7dba7e7da..3798b6f90a 100644 --- a/src/distilabel/utils/mkdocs/components_gallery.py +++ b/src/distilabel/utils/mkdocs/components_gallery.py @@ -87,6 +87,8 @@ "text-generation": ":material-text-box-edit:", "text-manipulation": ":material-receipt-text-edit:", "columns": ":material-table-column:", + "text-classification": ":material-label:", + "clustering": ":material-scatter-plot:", } diff --git a/tests/unit/steps/clustering/__init__.py b/tests/unit/steps/clustering/__init__.py new file mode 100644 index 0000000000..20ce00bda7 --- /dev/null +++ b/tests/unit/steps/clustering/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/tests/unit/steps/clustering/test_dbscan.py b/tests/unit/steps/clustering/test_dbscan.py new file mode 100644 index 0000000000..d4f62a3fae --- /dev/null +++ b/tests/unit/steps/clustering/test_dbscan.py @@ -0,0 +1,39 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from distilabel.steps.clustering.dbscan import DBSCAN + + +class TestDBSCAN: + def test_process(self) -> None: + step = DBSCAN(n_jobs=1, eps=0.5, min_samples=5) + step.load() + + results = next( + step.process( + inputs=[ + {"projection": [0.1, -0.4]}, + {"projection": [-0.3, 0.9]}, + {"projection": [0.6, 0.2]}, + {"projection": [-0.2, -0.6]}, + {"projection": [0.9, 0.1]}, + {"projection": [0.4, -0.7]}, + {"projection": [-0.5, 0.3]}, + {"projection": [0.7, 0.5]}, + {"projection": [-0.1, -0.9]}, + ] + ) + ) + assert all(result["cluster_label"] == -1 for result in results) diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py new file mode 100644 index 0000000000..4b2da96d40 --- /dev/null +++ b/tests/unit/steps/clustering/test_text_clustering.py @@ -0,0 +1,75 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import TYPE_CHECKING + +import pytest + +from distilabel.steps.clustering.text_clustering import TextClustering +from tests.unit.conftest import DummyAsyncLLM + +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import FormattedInput + + +class ClusteringLLM(DummyAsyncLLM): + n: int = 1 + + async def agenerate( # type: ignore + self, input: "FormattedInput", num_generations: int = 1 + ) -> "GenerateOutput": + if self.n == 1: + return [json.dumps({"labels": "label"}) for _ in range(num_generations)] + return [ + json.dumps({"labels": ["label" for _ in range(self.n)]}) + for _ in range(self.n) + ] + + +class TestTextClustering: + @pytest.mark.parametrize("n", [1, 3]) + def test_process(self, n: int) -> None: + step = TextClustering( + llm=ClusteringLLM(n=n), + n=n, + samples_per_cluster=2, + savefig=False, + ) + step.load() + + results = next( + step.process( + inputs=[ + {"projection": [0.1, -0.4], "cluster_label": -1, "text": "hello"}, + {"projection": [-0.3, 0.9], "cluster_label": -1, "text": "hello"}, + {"projection": [0.6, 0.2], "cluster_label": 0, "text": "hello"}, + {"projection": [-0.2, -0.6], "cluster_label": 0, "text": "hello"}, + {"projection": [0.9, 0.1], "cluster_label": 0, "text": "hello"}, + {"projection": [0.4, -0.7], "cluster_label": 1, "text": "hello"}, + {"projection": [-0.5, 0.3], "cluster_label": 1, "text": "hello"}, + {"projection": [0.7, 0.5], "cluster_label": 2, "text": "hello"}, + {"projection": [-0.1, -0.9], "cluster_label": 2, "text": "hello"}, + ] + ) + ) + for r in results: + if r["cluster_label"] == -1: + assert r["summary_label"] == json.dumps("Unclassified") + else: + if n == 1: + assert r["summary_label"] == json.dumps("label") + else: + assert r["summary_label"] == json.dumps(["label"] * n) diff --git a/tests/unit/steps/clustering/test_umap.py b/tests/unit/steps/clustering/test_umap.py new file mode 100644 index 0000000000..3ab252fd24 --- /dev/null +++ b/tests/unit/steps/clustering/test_umap.py @@ -0,0 +1,42 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np + +from distilabel.steps.clustering.umap import UMAP + + +class TestUMAP: + def test_process(self) -> None: + n_components = 2 + step = UMAP(n_jobs=1, n_components=n_components) + step.load() + + results = next( + step.process( + inputs=[ + {"embedding": [0.1, -0.4, 0.7, 0.2]}, + {"embedding": [-0.3, 0.9, 0.1, -0.5]}, + {"embedding": [0.6, 0.2, -0.1, 0.8]}, + {"embedding": [-0.2, -0.6, 0.4, 0.3]}, + {"embedding": [0.9, 0.1, -0.3, -0.2]}, + {"embedding": [0.4, -0.7, 0.6, 0.1]}, + {"embedding": [-0.5, 0.3, -0.2, 0.9]}, + {"embedding": [0.7, 0.5, -0.4, -0.1]}, + {"embedding": [-0.1, -0.9, 0.8, 0.6]}, + ] + ) + ) + assert all(isinstance(result["projection"], np.ndarray) for result in results) + assert all(len(result["projection"]) == n_components for result in results) diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py new file mode 100644 index 0000000000..e5af171b33 --- /dev/null +++ b/tests/unit/steps/tasks/test_text_classification.py @@ -0,0 +1,140 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import pytest + +from distilabel.steps.tasks.text_classification import TextClassification +from tests.unit.conftest import DummyAsyncLLM + +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import FormattedInput + + +class TextClassificationLLM(DummyAsyncLLM): + n: int = 1 + + async def agenerate( # type: ignore + self, input: "FormattedInput", num_generations: int = 1 + ) -> "GenerateOutput": + if self.n == 1: + return [json.dumps({"labels": "label"}) for _ in range(num_generations)] + return [ + json.dumps({"labels": [f"label_{i}" for i in range(self.n)]}) + for _ in range(num_generations) + ] + + +class TestTextClassification: + @pytest.mark.parametrize( + "n, context, examples, available_labels, default_label, query_title", + [ + (1, "context", None, None, "Unclassified", "User Query"), + (1, "", ["example"], ["label1", "label2"], "default", "User Query"), + ( + 1, + "", + ["example"], + {"label1": "explanation 1", "label2": "explanation 2"}, + "default", + "User Query", + ), + ( + 3, + "", + ["example", "other example"], + None, + "default", + "User Query", + ), + ], + ) + def test_format_input( + self, + n: int, + context: str, + examples: Optional[List[str]], + available_labels: Optional[Union[List[str], Dict[str, str]]], + default_label: Optional[Union[str, List[str]]], + query_title: str, + ) -> None: + task = TextClassification( + llm=DummyAsyncLLM(), + n=n, + context=context, + examples=examples, + available_labels=available_labels, + default_label=default_label, + query_title=query_title, + ) + task.load() + + result = task.format_input({"text": "SAMPLE_TEXT"}) + content = result[1]["content"] + + assert f'respond with "{default_label}"' in content + assert "## User Query\n```\nSAMPLE_TEXT\n```" in content + assert f'respond with "{default_label}"' in content + if n == 1: + assert "Provide the label that best describes the text." in content + assert '```\n{\n "labels": "label"\n}\n```' in content + else: + assert ( + f"Provide a list of {n} labels that best describe the text." in content + ) + assert ( + '```\n{\n "labels": ["label_0", "label_1", "label_2"]\n}\n```' + in content + ) + if available_labels: + if isinstance(available_labels, list): + assert 'Use the available labels to classify the user query:\navailable_labels = [\n "label1",\n "label2"\n]' + if isinstance(available_labels, dict): + assert 'Use the available labels to classify the user query:\navailable_labels = [\n "label1", # explanation 1\n "label2", # explanation 2\n]' + + if examples: + assert ( + "## Examples\nHere are some examples to help you understand the task:\n- example\n" + in content + ) + else: + assert "## Examples" not in content + assert ( + f"Please classify the {query_title.lower()} by assigning the most appropriate labels." + in content + ) + assert f"## {query_title}" in content + + @pytest.mark.parametrize( + "n, expected", + [ + (1, json.dumps({"labels": "label"})), + (3, json.dumps({"labels": ["label_0", "label_1", "label_2"]})), + ], + ) + def test_process(self, n: int, expected: str) -> None: + task = TextClassification( + llm=TextClassificationLLM(n=n), n=n, use_default_structured_output=True + ) + task.load() + result = next(task.process([{"text": "SAMPLE_TEXT"}])) + assert result[0]["text"] == "SAMPLE_TEXT" + assert result[0]["labels"] == json.loads(expected)["labels"] + assert ( + result[0]["distilabel_metadata"]["raw_output_text_classification_0"] + == expected + )