Skip to content

Commit

Permalink
Add chat support to text generation pipeline (#28945)
Browse files Browse the repository at this point in the history
* Add chat support to text generation pipeline

* Better handling of single elements

* Deprecate ConversationalPipeline

* stash commit

* Add missing add_special_tokens kwarg

* Update chat templating docs to refer to TextGenerationPipeline instead of ConversationalPipeline

* Add ✨TF✨ tests

* @require_tf

* Add type hint

* Add specific deprecation version

* Remove unnecessary do_sample

* Remove todo - the discrepancy has been resolved

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/pipelines/text_generation.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
2 people authored and Ita Zaporozhets committed May 14, 2024
1 parent b8d5896 commit ab01c8d
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 35 deletions.
29 changes: 14 additions & 15 deletions docs/source/en/chat_templating.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,31 +121,30 @@ Arr, 'twas easy after all!

## Is there an automated pipeline for chat?

Yes, there is: [`ConversationalPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try
the `Zephyr` example again, but this time using the pipeline:
Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past,
we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality
has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using
a pipeline:

```python
from transformers import pipeline

pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta")
pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta")
messages = [
{
"role": "system",
"content": "You are a friendly chatbot who always responds in the style of a pirate",
},
{"role": "user", "content": "How many helicopters can a human eat in one sitting?"},
]
print(pipe(messages))
print(pipe(messages, max_new_tokens=128)[0]['generated_text'][-1]) # Print the assistant's response
```

```text
Conversation id: 76d886a0-74bd-454e-9804-0467041a63dc
system: You are a friendly chatbot who always responds in the style of a pirate
user: How many helicopters can a human eat in one sitting?
assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all.
{'role': 'assistant', 'content': "Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all."}
```

[`ConversationalPipeline`] will take care of all the details of tokenization and calling `apply_chat_template` for you -
The pipeline will take care of all the details of tokenization and calling `apply_chat_template` for you -
once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages!

## What are "generation prompts"?
Expand Down Expand Up @@ -191,7 +190,7 @@ Can I ask a question?<|im_end|>
Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model
generates text it will write a bot response instead of doing something unexpected, like continuing the user's
message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a
special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're
special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're
supposed to be doing.

Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any
Expand Down Expand Up @@ -340,8 +339,8 @@ tokenizer.chat_template = template # Set the new template
tokenizer.push_to_hub("model_name") # Upload your new template to the Hub!
```

The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so
once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`].
The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so
once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`].

<Tip>
If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat
Expand All @@ -356,7 +355,7 @@ template. This will ensure that text generation tools can correctly figure out w

Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards
compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a
model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline`
model does not have a chat template set, but there is a default template for its model class, the `TextGenerationPipeline`
class and methods like `apply_chat_template` will use the class template instead. You can find out what the default
template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute.

Expand Down Expand Up @@ -407,7 +406,7 @@ I'm doing great!<|im_end|>
```

The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense,
particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited
particularly if you want your model to operate well with [`TextGenerationPipeline`]. However, you are not limited
to these roles - templating is extremely flexible, and any string can be a role.

### I want to add some chat templates! How should I get started?
Expand All @@ -418,7 +417,7 @@ not the model owner - if you're using a model with an empty chat template, or on
template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to the model repository so that this attribute can be set properly!

Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that
model, which means it is also automatically supported in places like `ConversationalPipeline`!
model, which means it is also automatically supported in places like `TextGenerationPipeline`!

By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of
open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long -
Expand Down
5 changes: 5 additions & 0 deletions src/transformers/pipelines/conversational.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import uuid
import warnings
from typing import Any, Dict, List, Union

from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging
Expand Down Expand Up @@ -232,6 +233,10 @@ class ConversationalPipeline(Pipeline):
"""

def __init__(self, *args, **kwargs):
warnings.warn(
"`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in v4.42.",
DeprecationWarning,
)
super().__init__(*args, **kwargs)
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
Expand Down
58 changes: 47 additions & 11 deletions src/transformers/pipelines/text_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import warnings
from typing import Dict

from ..utils import add_end_docstrings, is_tf_available, is_torch_available
from .base import Pipeline, build_pipeline_init_args
Expand All @@ -20,11 +21,24 @@ class ReturnType(enum.Enum):
FULL_TEXT = 2


class Chat:
"""This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
to this format because the rest of the pipeline code tends to assume that lists of messages are
actually a batch of samples rather than messages in the same conversation."""

def __init__(self, messages: Dict):
for message in messages:
if not ("role" in message and "content" in message):
raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
self.messages = messages


@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
"""
Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a
specified text prompt.
specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts,
where each dict contains "role" and "content" keys.
Example:
Expand Down Expand Up @@ -216,7 +230,15 @@ def __call__(self, text_inputs, **kwargs):
- **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
ids of the generated text.
"""
return super().__call__(text_inputs, **kwargs)
if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)):
# We have one or more prompts in list-of-dicts format, so this is chat mode
if isinstance(text_inputs[0], dict):
return super().__call__(Chat(text_inputs), **kwargs)
else:
chats = [Chat(chat) for chat in text_inputs] # 🐈 🐈 🐈
return super().__call__(chats, **kwargs)
else:
return super().__call__(text_inputs, **kwargs)

def preprocess(
self,
Expand All @@ -229,14 +251,25 @@ def preprocess(
max_length=None,
**generate_kwargs,
):
inputs = self.tokenizer(
prefix + prompt_text,
return_tensors=self.framework,
truncation=truncation,
padding=padding,
max_length=max_length,
add_special_tokens=add_special_tokens,
)
if isinstance(prompt_text, Chat):
inputs = self.tokenizer.apply_chat_template(
prompt_text.messages,
truncation=truncation,
padding=padding,
max_length=max_length,
add_generation_prompt=True,
return_dict=True,
return_tensors=self.framework,
)
else:
inputs = self.tokenizer(
prefix + prompt_text,
truncation=truncation,
padding=padding,
max_length=max_length,
add_special_tokens=add_special_tokens,
return_tensors=self.framework,
)
inputs["prompt_text"] = prompt_text

if handle_long_generation == "hole":
Expand Down Expand Up @@ -331,7 +364,10 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_

all_text = text[prompt_length:]
if return_type == ReturnType.FULL_TEXT:
all_text = prompt_text + all_text
if isinstance(prompt_text, str):
all_text = prompt_text + all_text
elif isinstance(prompt_text, Chat):
all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}]

record = {"generated_text": all_text}
records.append(record)
Expand Down
32 changes: 23 additions & 9 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,7 @@ def apply_chat_template(
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
return_dict: bool = False,
**tokenizer_kwargs,
) -> Union[str, List[int]]:
"""
Expand Down Expand Up @@ -1718,6 +1719,8 @@ def apply_chat_template(
- `'pt'`: Return PyTorch `torch.Tensor` objects.
- `'np'`: Return NumPy `np.ndarray` objects.
- `'jax'`: Return JAX `jnp.ndarray` objects.
return_dict (`bool`, *optional*, defaults to `False`):
Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
**tokenizer_kwargs: Additional kwargs to pass to the tokenizer.
Returns:
Expand Down Expand Up @@ -1746,15 +1749,26 @@ def apply_chat_template(
if padding is True:
padding = "max_length" # There's only one sequence here, so "longest" makes no sense
if tokenize:
return self.encode(
rendered,
add_special_tokens=False,
padding=padding,
truncation=truncation,
max_length=max_length,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
if return_dict:
return self(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
else:
return self.encode(
rendered,
padding=padding,
truncation=truncation,
max_length=max_length,
add_special_tokens=False,
return_tensors=return_tensors,
**tokenizer_kwargs,
)
else:
return rendered

Expand Down
92 changes: 92 additions & 0 deletions tests/pipelines/test_pipelines_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,52 @@ def test_small_model_pt(self):
],
)

@require_torch
def test_small_chat_model_pt(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]
self.assertEqual(
outputs,
[
{"generated_text": expected_chat1},
],
)

outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10)
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]

self.assertEqual(
outputs,
[
[{"generated_text": expected_chat1}],
[{"generated_text": expected_chat2}],
],
)

@require_tf
def test_small_model_tf(self):
text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf")
Expand Down Expand Up @@ -172,6 +218,52 @@ def test_small_model_tf(self):
],
)

@require_tf
def test_small_chat_model_tf(self):
text_generator = pipeline(
task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf"
)
# Using `do_sample=False` to force deterministic output
chat1 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a test"},
{"role": "assistant", "content": "This is a reply"},
]
chat2 = [
{"role": "system", "content": "This is a system message."},
{"role": "user", "content": "This is a second test"},
{"role": "assistant", "content": "This is a reply"},
]
outputs = text_generator(chat1, do_sample=False, max_new_tokens=10)
expected_chat1 = chat1 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]
self.assertEqual(
outputs,
[
{"generated_text": expected_chat1},
],
)

outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10)
expected_chat2 = chat2 + [
{
"role": "assistant",
"content": " factors factors factors factors factors factors factors factors factors factors",
}
]

self.assertEqual(
outputs,
[
[{"generated_text": expected_chat1}],
[{"generated_text": expected_chat2}],
],
)

def get_test_pipeline(self, model, tokenizer, processor):
text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
return text_generator, ["This is a test", "Another test"]
Expand Down

0 comments on commit ab01c8d

Please sign in to comment.