diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 87f95e1ebd1986..94048f88acaa47 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -121,13 +121,15 @@ Arr, 'twas easy after all! ## Is there an automated pipeline for chat? -Yes, there is: [`ConversationalPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try -the `Zephyr` example again, but this time using the pipeline: +Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past, +we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality +has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using +a pipeline: ```python from transformers import pipeline -pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta") +pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta") messages = [ { "role": "system", @@ -135,17 +137,14 @@ messages = [ }, {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, ] -print(pipe(messages)) +print(pipe(messages, max_new_tokens=128)[0]['generated_text'][-1]) # Print the assistant's response ``` ```text -Conversation id: 76d886a0-74bd-454e-9804-0467041a63dc -system: You are a friendly chatbot who always responds in the style of a pirate -user: How many helicopters can a human eat in one sitting? -assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all. +{'role': 'assistant', 'content': "Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all."} ``` -[`ConversationalPipeline`] will take care of all the details of tokenization and calling `apply_chat_template` for you - +The pipeline will take care of all the details of tokenization and calling `apply_chat_template` for you - once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages! ## What are "generation prompts"? @@ -191,7 +190,7 @@ Can I ask a question?<|im_end|> Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model generates text it will write a bot response instead of doing something unexpected, like continuing the user's message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a -special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're +special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're supposed to be doing. Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any @@ -340,8 +339,8 @@ tokenizer.chat_template = template # Set the new template tokenizer.push_to_hub("model_name") # Upload your new template to the Hub! ``` -The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so -once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`]. +The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so +once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`]. If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat @@ -356,7 +355,7 @@ template. This will ensure that text generation tools can correctly figure out w Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a -model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline` +model does not have a chat template set, but there is a default template for its model class, the `TextGenerationPipeline` class and methods like `apply_chat_template` will use the class template instead. You can find out what the default template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute. @@ -407,7 +406,7 @@ I'm doing great!<|im_end|> ``` The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense, -particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited +particularly if you want your model to operate well with [`TextGenerationPipeline`]. However, you are not limited to these roles - templating is extremely flexible, and any string can be a role. ### I want to add some chat templates! How should I get started? @@ -418,7 +417,7 @@ not the model owner - if you're using a model with an empty chat template, or on template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to the model repository so that this attribute can be set properly! Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that -model, which means it is also automatically supported in places like `ConversationalPipeline`! +model, which means it is also automatically supported in places like `TextGenerationPipeline`! By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long - diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 3d42363f198357..ca091074effb51 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -1,4 +1,5 @@ import uuid +import warnings from typing import Any, Dict, List, Union from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging @@ -232,6 +233,10 @@ class ConversationalPipeline(Pipeline): """ def __init__(self, *args, **kwargs): + warnings.warn( + "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in v4.42.", + DeprecationWarning, + ) super().__init__(*args, **kwargs) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token = self.tokenizer.eos_token diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index ce7e180601f97e..df460a9334b1ca 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,5 +1,6 @@ import enum import warnings +from typing import Dict from ..utils import add_end_docstrings, is_tf_available, is_torch_available from .base import Pipeline, build_pipeline_init_args @@ -20,11 +21,24 @@ class ReturnType(enum.Enum): FULL_TEXT = 2 +class Chat: + """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages: Dict): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") + self.messages = messages + + @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) class TextGenerationPipeline(Pipeline): """ Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a - specified text prompt. + specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts, + where each dict contains "role" and "content" keys. Example: @@ -216,7 +230,15 @@ def __call__(self, text_inputs, **kwargs): - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. """ - return super().__call__(text_inputs, **kwargs) + if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)): + # We have one or more prompts in list-of-dicts format, so this is chat mode + if isinstance(text_inputs[0], dict): + return super().__call__(Chat(text_inputs), **kwargs) + else: + chats = [Chat(chat) for chat in text_inputs] # 🐈 🐈 🐈 + return super().__call__(chats, **kwargs) + else: + return super().__call__(text_inputs, **kwargs) def preprocess( self, @@ -229,14 +251,25 @@ def preprocess( max_length=None, **generate_kwargs, ): - inputs = self.tokenizer( - prefix + prompt_text, - return_tensors=self.framework, - truncation=truncation, - padding=padding, - max_length=max_length, - add_special_tokens=add_special_tokens, - ) + if isinstance(prompt_text, Chat): + inputs = self.tokenizer.apply_chat_template( + prompt_text.messages, + truncation=truncation, + padding=padding, + max_length=max_length, + add_generation_prompt=True, + return_dict=True, + return_tensors=self.framework, + ) + else: + inputs = self.tokenizer( + prefix + prompt_text, + truncation=truncation, + padding=padding, + max_length=max_length, + add_special_tokens=add_special_tokens, + return_tensors=self.framework, + ) inputs["prompt_text"] = prompt_text if handle_long_generation == "hole": @@ -331,7 +364,10 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_ all_text = text[prompt_length:] if return_type == ReturnType.FULL_TEXT: - all_text = prompt_text + all_text + if isinstance(prompt_text, str): + all_text = prompt_text + all_text + elif isinstance(prompt_text, Chat): + all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}] record = {"generated_text": all_text} records.append(record) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 3afa244740f178..3184ed6a150c3e 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1685,6 +1685,7 @@ def apply_chat_template( truncation: bool = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, + return_dict: bool = False, **tokenizer_kwargs, ) -> Union[str, List[int]]: """ @@ -1718,6 +1719,8 @@ def apply_chat_template( - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, *optional*, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. Returns: @@ -1746,15 +1749,26 @@ def apply_chat_template( if padding is True: padding = "max_length" # There's only one sequence here, so "longest" makes no sense if tokenize: - return self.encode( - rendered, - add_special_tokens=False, - padding=padding, - truncation=truncation, - max_length=max_length, - return_tensors=return_tensors, - **tokenizer_kwargs, - ) + if return_dict: + return self( + rendered, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + else: + return self.encode( + rendered, + padding=padding, + truncation=truncation, + max_length=max_length, + add_special_tokens=False, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) else: return rendered diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 0500e3b0353c4a..766f2a462a1930 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -131,6 +131,52 @@ def test_small_model_pt(self): ], ) + @require_torch + def test_small_chat_model_pt(self): + text_generator = pipeline( + task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + chat2 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a second test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) + expected_chat1 = chat1 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + self.assertEqual( + outputs, + [ + {"generated_text": expected_chat1}, + ], + ) + + outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10) + expected_chat2 = chat2 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + + self.assertEqual( + outputs, + [ + [{"generated_text": expected_chat1}], + [{"generated_text": expected_chat2}], + ], + ) + @require_tf def test_small_model_tf(self): text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf") @@ -172,6 +218,52 @@ def test_small_model_tf(self): ], ) + @require_tf + def test_small_chat_model_tf(self): + text_generator = pipeline( + task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + chat2 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a second test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) + expected_chat1 = chat1 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + self.assertEqual( + outputs, + [ + {"generated_text": expected_chat1}, + ], + ) + + outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10) + expected_chat2 = chat2 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + + self.assertEqual( + outputs, + [ + [{"generated_text": expected_chat1}], + [{"generated_text": expected_chat2}], + ], + ) + def get_test_pipeline(self, model, tokenizer, processor): text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer) return text_generator, ["This is a test", "Another test"]