Llm generate upgrade #1034

Open

wants to merge 34 commits into base: develop

Changes from all commits (34 commits)
caea13f  Update GenerateOutput type (plaguss, Oct 11, 2024)
f7d7a0e  Add draft function to compute number of tokens given a tokenizer (plaguss, Oct 11, 2024)
dbcfafa  Refactor llm generation to return generations and statistics (plaguss, Oct 11, 2024)
a0bf204  Move statistics from the LLM to distilabel_metadata row (plaguss, Oct 14, 2024)
a51ce59  Update tests and LLM outputs to run with generations and statistics a… (plaguss, Oct 14, 2024)
fe5d4c5  Openai computed tokens (plaguss, Oct 14, 2024)
394984f  First version of async llms with statistics (plaguss, Oct 14, 2024)
9003fa0  Return generations with list of strings and token count from _raw_res… (plaguss, Oct 15, 2024)
4cf0e2f  Passing tests for inference endpoints (plaguss, Oct 15, 2024)
1b6d15c  Testing vLLM with statistics (plaguss, Oct 16, 2024)
8923880  Refactor statistics module to utils and output preparation to avoid c… (plaguss, Oct 16, 2024)
608e8b6  Refactor to remove code duplication (plaguss, Oct 16, 2024)
8c35af5  Fix async llms not returning properly the generations grouped by num_… (plaguss, Oct 17, 2024)
6d19de7  Fix async llms not processing multiple generations (plaguss, Oct 17, 2024)
9746d75  Fix vllm sorting mechanism and add mocked generate method to the test… (plaguss, Oct 18, 2024)
f108670  Checkpoint (plaguss, Oct 22, 2024)
c8063a4  Fix tests from merge responses and group generations (plaguss, Oct 23, 2024)
6f6769a  Move import to guarded type hint (plaguss, Oct 23, 2024)
4971f26  Fix tests to work with statistics (plaguss, Oct 23, 2024)
74f81ad  Return void list in case of no generations (plaguss, Oct 24, 2024)
8ff6e13  Update function to allow flatten inner list in values of dicts, and a… (plaguss, Oct 24, 2024)
d8f2a8b  Fix dummy magpie llm (plaguss, Oct 24, 2024)
70898da  Update tests for magpie (plaguss, Oct 24, 2024)
314e171  Create statistics entry in distilabel_metadata with the name of the step (plaguss, Oct 24, 2024)
e3e81d9  Update magpie code to work with the new llm.generate behaviour (plaguss, Oct 24, 2024)
241d899  Update tests with the llm generate output format (plaguss, Oct 24, 2024)
a657859  Merge branch 'develop' of https://github.com/argilla-io/distilabel in… (plaguss, Oct 24, 2024)
40e408d  Fix pending tests (plaguss, Oct 24, 2024)
6c2e1fd  Fix test failing with vllm version upgrade (plaguss, Oct 24, 2024)
d28a798  Another fix including tokenizer for our llm to work and to avoid outl… (plaguss, Oct 24, 2024)
1bc28ba  Fix dummy offline batch generation (plaguss, Oct 25, 2024)
13f42b0  Merge and fix conflict (plaguss, Oct 28, 2024)
edbea28  Compute tokens using the tokenizer if available (plaguss, Oct 28, 2024)
e97f901  Update docs to include references to the new outputs of the LLMs incl… (plaguss, Oct 28, 2024)
60 changes: 50 additions & 10 deletions docs/sections/how_to_guides/basic/llm/index.md
@@ -7,20 +7,45 @@ LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Ta
```python
from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct")
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct"
)
llm.load()

llm.generate_outputs(
inputs=[
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [
# {
# "generations": [
# "The capital of Spain is Madrid."
# ],
# "statistics": {
# "input_tokens": [
# 43
# ],
# "output_tokens": [
# 8
# ]
# }
# }
# ]
```

!!! NOTE
!!! Note
Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`.

!!! Tip "New in version 1.5.0"
Since version `1.5.0`, the LLM output is a list of dictionaries (one per item in `inputs`),
each containing `generations`, the text returned by the `LLM`, and a `statistics` field that stores statistics related to the generation. Initially this includes
`input_tokens` and `output_tokens` when available, obtained via the API or, if a tokenizer is available for the model used, computed with that tokenizer.
During pipeline processing, the corresponding `Task` moves this data to `distilabel_metadata`, so it can be used later, for example to compute the number of tokens per dataset.

To access the text that was previously returned directly, access the generations in the resulting dictionary: `result[0]["generations"]`.
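
For instance, the generated text and the token counts can be pulled from that structure directly. A minimal sketch reusing the `llm` defined above (the variable names are illustrative):

```python
result = llm.generate_outputs(
    inputs=[
        [{"role": "user", "content": "What's the capital of Spain?"}],
    ],
)
text = result[0]["generations"][0]
# Sum the per-generation token counts reported in `statistics`
input_tokens = sum(result[0]["statistics"]["input_tokens"])
output_tokens = sum(result[0]["statistics"]["output_tokens"])
```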

### Offline Batch Generation

By default, all `LLM`s generate text synchronously, i.e. sending inputs via the `generate_outputs` method blocks until the outputs are generated. Some `LLM`s (such as [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]) implement what we denote as _offline batch generation_, which allows sending the inputs to the LLM-as-a-service, which generates the outputs asynchronously and gives us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offer this feature as a way to save costs in exchange for waiting for the outputs to be generated.
@@ -56,7 +81,8 @@ llm.generate_outputs( # (4)
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
```

1. At first the `jobs_ids` attribute is `None`.
@@ -81,7 +107,8 @@ llm.generate_outputs(
[{"role": "user", "content": "What's the capital of Spain?"}],
],
)
# "The capital of Spain is Madrid."
# [{'generations': ['The capital of Spain is Madrid.'],
# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}]
```

### Within a Task
@@ -92,20 +119,30 @@ Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and th
from distilabel.models import OpenAILLM
from distilabel.steps.tasks import TextGeneration

llm = OpenAILLM(model="gpt-4")
llm = OpenAILLM(model="gpt-4o-mini")
task = TextGeneration(name="text_generation", llm=llm)

task.load()

next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}]))
# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}]
# [{'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {'raw_output_text_generation': 'The capital of Spain is Madrid.',
# 'raw_input_text_generation': [{'role': 'user',
# 'content': "What's the capital of Spain?"}],
# 'statistics_text_generation': {'input_tokens': 13, 'output_tokens': 7}},
# 'model_name': 'gpt-4o-mini'}]
```

!!! Note
As mentioned in the *Working with LLMs* section, the generation of an LLM is automatically moved to `distilabel_metadata` to avoid interfering with the common workflow, so the addition of `statistics` is an extra component available to the user, but nothing has to be changed in the defined pipelines.
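
For example, continuing with the `task` defined above, the statistics can be read back from `distilabel_metadata` without changing anything else in the pipeline (a minimal sketch; the exact token counts depend on the model):

```python
result = next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}]))
stats = result[0]["distilabel_metadata"]["statistics_text_generation"]
print(stats["input_tokens"], stats["output_tokens"])  # e.g. 13 7
```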

### Runtime Parameters

LLMs can have runtime parameters, such as `generation_kwargs`, provided via the `Pipeline.run()` method using the `params` argument.

!!! NOTE
!!! Note
Runtime parameters can differ between LLM subclasses, due to the different functionalities offered by the LLM providers.

```python
@@ -122,7 +159,7 @@ with Pipeline(name="text-generation-pipeline") as pipeline:

text_generation = TextGeneration(
name="text_generation",
llm=OpenAILLM(model="gpt-4"),
llm=OpenAILLM(model="gpt-4o-mini"),
)

load_dataset >> text_generation
@@ -200,9 +237,12 @@ To create custom LLMs, subclass either [`LLM`][distilabel.models.llms.LLM] for s

`generate` and `agenerate` keyword arguments (except `input` and `num_generations`) are considered `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method.

!!! NOTE
!!! Note
To have the arguments of `generate` and `agenerate` coerced to the expected types, the `validate_call` decorator is used; it automatically coerces the arguments and raises an error if the types are not correct. This is especially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI always provides the arguments as strings.

!!! Warning
Additional LLMs created in `distilabel` will have to take into account how the `statistics` are generated to properly include them in the LLM output.
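
As a rough guide, a custom `generate` is expected to return one dictionary per input containing both fields, mirroring the examples above. A minimal sketch of that shape (the placeholder text and zero counts are illustrative, not a real provider integration):

```python
from typing import Any, Dict, List


def generate(inputs: List[Any], num_generations: int = 1) -> List[Dict[str, Any]]:
    # One dict per input; `generations` and the statistics lists share the same length.
    return [
        {
            "generations": ["<generated text>"] * num_generations,
            "statistics": {
                "input_tokens": [0] * num_generations,  # fill from the API or a tokenizer
                "output_tokens": [0] * num_generations,
            },
        }
        for _ in inputs
    ]
```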

## Available LLMs

[Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library.
43 changes: 29 additions & 14 deletions docs/sections/how_to_guides/basic/task/index.md
@@ -21,26 +21,35 @@ task.load()

next(task.process([{"instruction": "What's the capital of Spain?"}]))
# [
# {
# 'instruction': "What's the capital of Spain?",
# 'generation': 'The capital of Spain is Madrid.',
# 'distilabel_metadata': {
# 'raw_output_text-generation': 'The capital of Spain is Madrid.',
# 'raw_input_text-generation': [
# {'role': 'user', 'content': "What's the capital of Spain?"}
# ]
# },
# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct'
# }
# {
# "instruction": "What's the capital of Spain?",
# "generation": "The capital of Spain is Madrid.",
# "distilabel_metadata": {
# "raw_output_text-generation": "The capital of Spain is Madrid.",
# "raw_input_text-generation": [
# {
# "role": "user",
# "content": "What's the capital of Spain?"
# }
# ],
# "statistics_text-generation": { # (1)
# "input_tokens": 18,
# "output_tokens": 8
# }
# },
# "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct"
# }
# ]
```

!!! NOTE
1. The `LLM` returns not only the text but also statistics related to the generation, which the task stores in a `statistics_{STEP_NAME}` field. When available, at least the input and output tokens are included.

!!! Note
`Step.load()` always needs to be executed when the step is used standalone. Within a pipeline, this is done automatically during pipeline execution.

As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`.

!!! Tip
!!! Tip "New in version 1.2.0"
Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task.

Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing
@@ -57,9 +66,12 @@
)
```

!!! Tip "New in version 1.5.0"
Since version `1.5.0`, `distilabel_metadata` includes a new `statistics` field out of the box. The generation from the LLM will contain not only the text but also statistics associated with it when available, such as the input and output tokens. The field is named `statistics_{STEP_NAME}` to avoid collisions between different steps in the pipeline, similar to how `raw_output_{STEP_NAME}` works.
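
As an illustration, the per-step statistics can then be aggregated over a generated dataset. A sketch assuming rows shaped like the output above, for a step named `text-generation` (the numbers are made up):

```python
rows = [
    {"distilabel_metadata": {"statistics_text-generation": {"input_tokens": 18, "output_tokens": 8}}},
    {"distilabel_metadata": {"statistics_text-generation": {"input_tokens": 20, "output_tokens": 12}}},
]
total_output_tokens = sum(
    row["distilabel_metadata"]["statistics_text-generation"]["output_tokens"] for row in rows
)
# total_output_tokens == 20
```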

### Task.print

!!! Info
!!! Info "New in version 1.4.0"
New since version `1.4.0`: the [`Task.print`][distilabel.steps.tasks.base._Task.print] method.

The `Tasks` include a handy method to show what the prompt formatted for an `LLM` would look like. Let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`.
@@ -271,3 +283,6 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
# Format the `LLM` output here
return {"output_field": output}
```

!!! Warning
Most `Tasks` reuse the `Task.process` method to process the generations, but if a new `Task` defines a custom `process` method, as happens for example with [`Magpie`][distilabel.steps.tasks.magpie.base.Magpie], it has to deal with the `statistics` returned by the `LLM`.
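
A minimal sketch of the kind of handling such a custom `process` needs (illustrative only, not the actual `Magpie` implementation; the step name is made up):

```python
# Each LLM output now bundles text and statistics; a custom `process` has to unpack both.
llm_outputs = [
    {
        "generations": ["The capital of Spain is Madrid."],
        "statistics": {"input_tokens": [13], "output_tokens": [7]},
    },
]
rows = [{"instruction": "What's the capital of Spain?"}]
for row, output in zip(rows, llm_outputs):
    row["generation"] = output["generations"][0]
    row.setdefault("distilabel_metadata", {})["statistics_my_custom_task"] = output["statistics"]
```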
29 changes: 22 additions & 7 deletions src/distilabel/models/llms/anthropic.py
@@ -30,13 +30,19 @@
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import AsyncLLM
from distilabel.models.llms.typing import GenerateOutput
from distilabel.models.llms.utils import prepare_output
from distilabel.steps.tasks.typing import (
FormattedInput,
InstructorStructuredOutputType,
)

if TYPE_CHECKING:
from pydantic import BaseModel

from anthropic import AsyncAnthropic
from anthropic.types import Message

from distilabel.llms.typing import LLMStatistics


_ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY"
@@ -260,17 +266,26 @@ async def agenerate( # type: ignore
if structured_output:
kwargs = self._prepare_kwargs(kwargs, structured_output)

generations = []

completion = await self._aclient.messages.create(**kwargs) # type: ignore
completion: Union["Message", "BaseModel"] = await self._aclient.messages.create(
**kwargs
) # type: ignore
if structured_output:
generations.append(completion.model_dump_json())
return generations
# raw_response = completion._raw_response
return prepare_output(
[completion.model_dump_json()],
**self._get_llm_statistics(completion._raw_response),
)

if (content := completion.content[0].text) is None:
self._logger.warning(
f"Received no response using Anthropic client (model: '{self.model}')."
f" Finish reason was: {completion.stop_reason}"
)
generations.append(content)
return generations
return prepare_output([content], **self._get_llm_statistics(completion))

@staticmethod
def _get_llm_statistics(completion: "Message") -> "LLMStatistics":
return {
"input_tokens": [completion.usage.input_tokens],
"output_tokens": [completion.usage.output_tokens],
}
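
For reference, an illustrative stand-in for what `prepare_output` is expected to build from a list of generations plus a statistics dict like the one above (not the actual implementation in `distilabel.models.llms.utils`):

```python
from typing import Any, Dict, List, Optional


def prepare_output_sketch(
    generations: List[str],
    input_tokens: Optional[List[int]] = None,
    output_tokens: Optional[List[int]] = None,
) -> Dict[str, Any]:
    # Bundle the text and token counts into the `generations`/`statistics` dict
    # documented above.
    return {
        "generations": generations,
        "statistics": {
            "input_tokens": input_tokens or [],
            "output_tokens": output_tokens or [],
        },
    }
```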
52 changes: 45 additions & 7 deletions src/distilabel/models/llms/base.py
@@ -21,6 +21,7 @@
import time
from abc import ABC, abstractmethod
from functools import cached_property
from itertools import islice
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
Expand All @@ -33,7 +34,6 @@
RuntimeParametersMixin,
)
from distilabel.utils.docstring import parse_google_docstring
from distilabel.utils.itertools import grouper
from distilabel.utils.notebook import in_notebook
from distilabel.utils.serialization import _Serializable

@@ -459,18 +459,16 @@ async def _agenerate(
)
for input in inputs
]
return await asyncio.gather(*tasks)
result = await asyncio.gather(*tasks)
return result

tasks = [
asyncio.create_task(self.agenerate(input=input, **kwargs))
for input in inputs
for _ in range(num_generations)
]
outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
return [
list(group)
for group in grouper(outputs, n=num_generations, incomplete="ignore")
]
outputs = await asyncio.gather(*tasks)
return merge_responses(outputs, n=num_generations)

def generate(
self,
@@ -590,3 +588,43 @@ def _prepare_kwargs(
},
)
return arguments


def merge_responses(
responses: List[Dict[str, Any]], n: int = 1
) -> List[Dict[str, Any]]:
"""Helper function to group the responses from `LLM.agenerate` method according
to the number of generations requested.

Args:
responses: the responses from the `LLM.agenerate` method.
n: number of responses to group together. Defaults to 1.

Returns:
List of merged responses, where each merged response contains n generations
and their corresponding statistics.
"""
if not responses:
return []

def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield list(islice(lst, i, i + n))

# Split responses into groups of size n
grouped_responses = list(chunks(responses, n))

result = []
for group in grouped_responses:
first = group[0]
merged = {
"generations": sum((r["generations"] for r in group), []),
"statistics": {
key: sum((r["statistics"][key] for r in group), [])
for key in first["statistics"]
},
}
result.append(merged)

return result
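
A quick usage sketch of `merge_responses`, grouping two single-generation responses for the same input (the values are illustrative):

```python
responses = [
    {"generations": ["Madrid."], "statistics": {"input_tokens": [13], "output_tokens": [3]}},
    {"generations": ["Madrid is the capital of Spain."], "statistics": {"input_tokens": [13], "output_tokens": [8]}},
]
merge_responses(responses, n=2)
# [{'generations': ['Madrid.', 'Madrid is the capital of Spain.'],
#   'statistics': {'input_tokens': [13, 13], 'output_tokens': [3, 8]}}]
```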