From e7c4172b04ef21900cca814c0486b69013936652 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 17:27:45 +0000 Subject: [PATCH 01/14] Add chat support to text generation pipeline --- src/transformers/pipelines/text_generation.py | 57 +++++++++++++++---- src/transformers/tokenization_utils_base.py | 31 +++++++--- .../test_pipelines_text_generation.py | 46 +++++++++++++++ 3 files changed, 114 insertions(+), 20 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 839395d7fe0528..16e9d412900f0f 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -20,11 +20,24 @@ class ReturnType(enum.Enum): FULL_TEXT = 2 +class Chat: + """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats + to this format because the rest of the pipeline code tends to assume that lists of messages are + actually a batch of samples rather than messages in the same conversation.""" + + def __init__(self, messages): + for message in messages: + if not ("role" in message and "content" in message): + raise ValueError("When passing chat dicts as input, each chat must have a 'role' and 'content' key.") + self.messages = messages + + @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True)) class TextGenerationPipeline(Pipeline): """ Language generation pipeline using any `ModelWithLMHead`. This pipeline predicts the words that will follow a - specified text prompt. + specified text prompt. It can also accept one or more chats. Each chat takes the form of a list of dicts, + where each dict contains "role" and "content" keys. Example: @@ -216,7 +229,15 @@ def __call__(self, text_inputs, **kwargs): - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token ids of the generated text. """ - return super().__call__(text_inputs, **kwargs) + if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)): + # We have one or more prompts in list-of-dicts format, so this is chat mode + chats = text_inputs + if isinstance(chats[0], dict): + chats = [chats] + chats = [Chat(chat) for chat in chats] # 🐈 🐈 🐈 + return super().__call__(chats, **kwargs) + else: + return super().__call__(text_inputs, **kwargs) def preprocess( self, @@ -229,14 +250,25 @@ def preprocess( max_length=None, **generate_kwargs, ): - inputs = self.tokenizer( - prefix + prompt_text, - return_tensors=self.framework, - truncation=truncation, - padding=padding, - max_length=max_length, - add_special_tokens=add_special_tokens, - ) + if isinstance(prompt_text, Chat): + inputs = self.tokenizer.apply_chat_template( + prompt_text.messages, + padding=padding, + add_generation_prompt=True, + return_tensors=self.framework, + max_length=max_length, + truncation=truncation, + return_dict=True, + ) + else: + inputs = self.tokenizer( + prefix + prompt_text, + return_tensors=self.framework, + truncation=truncation, + padding=padding, + max_length=max_length, + add_special_tokens=add_special_tokens, + ) inputs["prompt_text"] = prompt_text if handle_long_generation == "hole": @@ -331,7 +363,10 @@ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_ all_text = text[prompt_length:] if return_type == ReturnType.FULL_TEXT: - all_text = prompt_text + all_text + if isinstance(prompt_text, str): + all_text = prompt_text + all_text + elif isinstance(prompt_text, Chat): + all_text = prompt_text.messages + [{"role": "assistant", "content": all_text}] record = {"generated_text": all_text} records.append(record) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index d389af676fd0c8..13f422c1e07f8a 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1685,6 +1685,7 @@ def apply_chat_template( truncation: bool = False, max_length: Optional[int] = None, return_tensors: Optional[Union[str, TensorType]] = None, + return_dict: bool = False, **tokenizer_kwargs, ) -> Union[str, List[int]]: """ @@ -1718,6 +1719,8 @@ def apply_chat_template( - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + return_dict (`bool`, *optional*, defaults to `False`): + Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`. **tokenizer_kwargs: Additional kwargs to pass to the tokenizer. Returns: @@ -1746,15 +1749,25 @@ def apply_chat_template( if padding is True: padding = "max_length" # There's only one sequence here, so "longest" makes no sense if tokenize: - return self.encode( - rendered, - add_special_tokens=False, - padding=padding, - truncation=truncation, - max_length=max_length, - return_tensors=return_tensors, - **tokenizer_kwargs, - ) + if return_dict: + return self( + rendered, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) + else: + return self.encode( + rendered, + add_special_tokens=False, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + **tokenizer_kwargs, + ) else: return rendered diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 0500e3b0353c4a..0de1d94ccda5a2 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -131,6 +131,52 @@ def test_small_model_pt(self): ], ) + @require_torch + def test_small_chat_model_pt(self): + text_generator = pipeline( + task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="pt" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + chat2 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a second test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) + expected_chat1 = chat1 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + self.assertEqual( + outputs, + [ + [{"generated_text": expected_chat1}], + ], + ) + + outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10) + expected_chat2 = chat2 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + + self.assertEqual( + outputs, + [ + [{"generated_text": expected_chat1}], + [{"generated_text": expected_chat2}], + ], + ) + @require_tf def test_small_model_tf(self): text_generator = pipeline(task="text-generation", model="sshleifer/tiny-ctrl", framework="tf") From a2e190ab963bb6a7baa4a41e8c2a999db18c795a Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 18:29:49 +0000 Subject: [PATCH 02/14] Better handling of single elements --- src/transformers/pipelines/text_generation.py | 10 +++++----- tests/pipelines/test_pipelines_text_generation.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 16e9d412900f0f..0ce1f54f81c415 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -231,11 +231,11 @@ def __call__(self, text_inputs, **kwargs): """ if isinstance(text_inputs, (list, tuple)) and isinstance(text_inputs[0], (list, tuple, dict)): # We have one or more prompts in list-of-dicts format, so this is chat mode - chats = text_inputs - if isinstance(chats[0], dict): - chats = [chats] - chats = [Chat(chat) for chat in chats] # 🐈 🐈 🐈 - return super().__call__(chats, **kwargs) + if isinstance(text_inputs[0], dict): + return super().__call__(Chat(text_inputs), **kwargs) + else: + chats = [Chat(chat) for chat in text_inputs] # 🐈 🐈 🐈 + return super().__call__(chats, **kwargs) else: return super().__call__(text_inputs, **kwargs) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 0de1d94ccda5a2..411affa1c5370e 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -157,7 +157,7 @@ def test_small_chat_model_pt(self): self.assertEqual( outputs, [ - [{"generated_text": expected_chat1}], + {"generated_text": expected_chat1}, ], ) From a5e1ccc86c0a343c23689acfa0a73a2df8449a23 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 9 Feb 2024 18:32:56 +0000 Subject: [PATCH 03/14] Deprecate ConversationalPipeline --- src/transformers/pipelines/conversational.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 3d42363f198357..9f7870864170ca 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -1,4 +1,5 @@ import uuid +import warnings from typing import Any, Dict, List, Union from ..utils import add_end_docstrings, is_tf_available, is_torch_available, logging @@ -232,6 +233,10 @@ class ConversationalPipeline(Pipeline): """ def __init__(self, *args, **kwargs): + warnings.warn( + "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in a future release.", + DeprecationWarning, + ) super().__init__(*args, **kwargs) if self.tokenizer.pad_token_id is None: self.tokenizer.pad_token = self.tokenizer.eos_token From 4444a1ee8d5fe2d01b9aadaf8b152aacf705e82a Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 15:30:42 +0000 Subject: [PATCH 04/14] stash commit --- docs/source/en/chat_templating.md | 8 ++++---- src/transformers/pipelines/text_generation.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index e0ffd9ad1589f3..cb79b2ced0792e 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -121,13 +121,13 @@ Arr, 'twas easy after all! ## Is there an automated pipeline for chat? -Yes, there is: [`ConversationalPipeline`]. This pipeline is designed to make it easy to use chat models. Let's try +Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. Let's try the `Zephyr` example again, but this time using the pipeline: ```python from transformers import pipeline -pipe = pipeline("conversational", "HuggingFaceH4/zephyr-7b-beta") +pipe = pipeline("text-generation", "HuggingFaceH4/zephyr-7b-beta") messages = [ { "role": "system", @@ -135,7 +135,7 @@ messages = [ }, {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, ] -print(pipe(messages)) +print(pipe(messages, max_new_tokens=128, do_sample=False)[0]['generated_text'][-1]) # Print the assistant's response ``` ```text @@ -145,7 +145,7 @@ user: How many helicopters can a human eat in one sitting? assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all. ``` -[`ConversationalPipeline`] will take care of all the details of tokenization and calling `apply_chat_template` for you - +The pipeline will take care of all the details of tokenization and calling `apply_chat_template` for you - once the model has a chat template, all you need to do is initialize the pipeline and pass it the list of messages! ## What are "generation prompts"? diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 0ce1f54f81c415..6175f6187a3172 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -80,6 +80,7 @@ class TextGenerationPipeline(Pipeline): """ def __init__(self, *args, **kwargs): + # TODO Matt: Find out why ConversationalPipeline and TextGeneration give slighly different (but valid) results super().__init__(*args, **kwargs) self.check_model_type( TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES From 6772615213b5d6eed885831d32e3b5f25438fa1c Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 18:04:50 +0000 Subject: [PATCH 05/14] Add missing add_special_tokens kwarg --- src/transformers/tokenization_utils_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 13f422c1e07f8a..384145d84a580e 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1756,6 +1756,7 @@ def apply_chat_template( truncation=truncation, max_length=max_length, return_tensors=return_tensors, + add_special_tokens=False, **tokenizer_kwargs, ) else: From de3d88abe41a030d02ca96187c7803f74bef28df Mon Sep 17 00:00:00 2001 From: Matt Date: Mon, 12 Feb 2024 18:13:04 +0000 Subject: [PATCH 06/14] Update chat templating docs to refer to TextGenerationPipeline instead of ConversationalPipeline --- docs/source/en/chat_templating.md | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index cb79b2ced0792e..30fce8405e700f 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -121,8 +121,10 @@ Arr, 'twas easy after all! ## Is there an automated pipeline for chat? -Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. Let's try -the `Zephyr` example again, but this time using the pipeline: +Yes, there is! Our text generation pipelines support chat inputs, which makes it easy to use chat models. In the past, +we used to use a dedicated "ConversationalPipeline" class, but this has now been deprecated and its functionality +has been merged into the [`TextGenerationPipeline`]. Let's try the `Zephyr` example again, but this time using +a pipeline: ```python from transformers import pipeline @@ -139,10 +141,7 @@ print(pipe(messages, max_new_tokens=128, do_sample=False)[0]['generated_text'][- ``` ```text -Conversation id: 76d886a0-74bd-454e-9804-0467041a63dc -system: You are a friendly chatbot who always responds in the style of a pirate -user: How many helicopters can a human eat in one sitting? -assistant: Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all. +{'role': 'assistant', 'content': "Matey, I'm afraid I must inform ye that humans cannot eat helicopters. Helicopters are not food, they are flying machines. Food is meant to be eaten, like a hearty plate o' grog, a savory bowl o' stew, or a delicious loaf o' bread. But helicopters, they be for transportin' and movin' around, not for eatin'. So, I'd say none, me hearties. None at all."} ``` The pipeline will take care of all the details of tokenization and calling `apply_chat_template` for you - @@ -191,7 +190,7 @@ Can I ask a question?<|im_end|> Note that this time, we've added the tokens that indicate the start of a bot response. This ensures that when the model generates text it will write a bot response instead of doing something unexpected, like continuing the user's message. Remember, chat models are still just language models - they're trained to continue text, and chat is just a -special kind of text to them! You need to guide them with the appropriate control tokens so they know what they're +special kind of text to them! You need to guide them with appropriate control tokens, so they know what they're supposed to be doing. Not all models require generation prompts. Some models, like BlenderBot and LLaMA, don't have any @@ -340,8 +339,8 @@ tokenizer.chat_template = template # Set the new template tokenizer.push_to_hub("model_name") # Upload your new template to the Hub! ``` -The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`ConversationalPipeline`] class, so -once you set the correct chat template, your model will automatically become compatible with [`ConversationalPipeline`]. +The method [`~PreTrainedTokenizer.apply_chat_template`] which uses your chat template is called by the [`TextGenerationPipeline`] class, so +once you set the correct chat template, your model will automatically become compatible with [`TextGenerationPipeline`]. If you're fine-tuning a model for chat, in addition to setting a chat template, you should probably add any new chat @@ -356,7 +355,7 @@ template. This will ensure that text generation tools can correctly figure out w Before the introduction of chat templates, chat handling was hardcoded at the model class level. For backwards compatibility, we have retained this class-specific handling as default templates, also set at the class level. If a -model does not have a chat template set, but there is a default template for its model class, the `ConversationalPipeline` +model does not have a chat template set, but there is a default template for its model class, the `TextGenerationPipeline` class and methods like `apply_chat_template` will use the class template instead. You can find out what the default template for your tokenizer is by checking the `tokenizer.default_chat_template` attribute. @@ -407,7 +406,7 @@ I'm doing great!<|im_end|> ``` The "user", "system" and "assistant" roles are the standard for chat, and we recommend using them when it makes sense, -particularly if you want your model to operate well with [`ConversationalPipeline`]. However, you are not limited +particularly if you want your model to operate well with [`TextGenerationPipeline`]. However, you are not limited to these roles - templating is extremely flexible, and any string can be a role. ### I want to add some chat templates! How should I get started? @@ -418,7 +417,7 @@ not the model owner - if you're using a model with an empty chat template, or on template, please open a [pull request](https://huggingface.co/docs/hub/repositories-pull-requests-discussions) to the model repository so that this attribute can be set properly! Once the attribute is set, that's it, you're done! `tokenizer.apply_chat_template` will now work correctly for that -model, which means it is also automatically supported in places like `ConversationalPipeline`! +model, which means it is also automatically supported in places like `TextGenerationPipeline`! By ensuring that models have this attribute, we can make sure that the whole community gets to use the full power of open-source models. Formatting mismatches have been haunting the field and silently harming performance for too long - From 7eb468db62bfb42654abe1590ace4ceebf7392b5 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Feb 2024 16:43:04 +0000 Subject: [PATCH 07/14] =?UTF-8?q?Add=20=E2=9C=A8TF=E2=9C=A8=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../test_pipelines_text_generation.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index 411affa1c5370e..edd0ae41aa335c 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -218,6 +218,52 @@ def test_small_model_tf(self): ], ) + @require_torch + def test_small_chat_model_tf(self): + text_generator = pipeline( + task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf" + ) + # Using `do_sample=False` to force deterministic output + chat1 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + chat2 = [ + {"role": "system", "content": "This is a system message."}, + {"role": "user", "content": "This is a second test"}, + {"role": "assistant", "content": "This is a reply"}, + ] + outputs = text_generator(chat1, do_sample=False, max_new_tokens=10) + expected_chat1 = chat1 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + self.assertEqual( + outputs, + [ + {"generated_text": expected_chat1}, + ], + ) + + outputs = text_generator([chat1, chat2], do_sample=False, max_new_tokens=10) + expected_chat2 = chat2 + [ + { + "role": "assistant", + "content": " factors factors factors factors factors factors factors factors factors factors", + } + ] + + self.assertEqual( + outputs, + [ + [{"generated_text": expected_chat1}], + [{"generated_text": expected_chat2}], + ], + ) + def get_test_pipeline(self, model, tokenizer, processor): text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer) return text_generator, ["This is a test", "Another test"] From 6fae42d19d818f2c4d8ac3ecef3acebdf7ea1ac7 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Feb 2024 16:59:01 +0000 Subject: [PATCH 08/14] @require_tf --- tests/pipelines/test_pipelines_text_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index edd0ae41aa335c..766f2a462a1930 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -218,7 +218,7 @@ def test_small_model_tf(self): ], ) - @require_torch + @require_tf def test_small_chat_model_tf(self): text_generator = pipeline( task="text-generation", model="rocketknight1/tiny-gpt2-with-chatml-template", framework="tf" From 3164035847224b24f352864befb41038f86ec118 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 14:21:51 +0000 Subject: [PATCH 09/14] Add type hint --- src/transformers/pipelines/text_generation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 6175f6187a3172..997ea93bb3e34d 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -1,5 +1,6 @@ import enum import warnings +from typing import Dict from ..utils import add_end_docstrings, is_tf_available, is_torch_available from .base import Pipeline, build_pipeline_init_args @@ -25,10 +26,10 @@ class Chat: to this format because the rest of the pipeline code tends to assume that lists of messages are actually a batch of samples rather than messages in the same conversation.""" - def __init__(self, messages): + def __init__(self, messages: Dict): for message in messages: if not ("role" in message and "content" in message): - raise ValueError("When passing chat dicts as input, each chat must have a 'role' and 'content' key.") + raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.") self.messages = messages From 01fc1a67af29530c179e361699d9988015ec1bf1 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 14:47:57 +0000 Subject: [PATCH 10/14] Add specific deprecation version --- src/transformers/pipelines/conversational.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/pipelines/conversational.py b/src/transformers/pipelines/conversational.py index 9f7870864170ca..ca091074effb51 100644 --- a/src/transformers/pipelines/conversational.py +++ b/src/transformers/pipelines/conversational.py @@ -234,7 +234,7 @@ class ConversationalPipeline(Pipeline): def __init__(self, *args, **kwargs): warnings.warn( - "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in a future release.", + "`ConversationalPipeline` is now deprecated, and the functionality has been moved to the standard `text-generation` pipeline, which now accepts lists of message dicts as well as strings. This class will be removed in v4.42.", DeprecationWarning, ) super().__init__(*args, **kwargs) From 1b3f53f2ea110a7bfad452e5fbe4b7883477c13f Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 15:01:22 +0000 Subject: [PATCH 11/14] Remove unnecessary do_sample --- docs/source/en/chat_templating.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/en/chat_templating.md b/docs/source/en/chat_templating.md index 30fce8405e700f..23c3107580c074 100644 --- a/docs/source/en/chat_templating.md +++ b/docs/source/en/chat_templating.md @@ -137,7 +137,7 @@ messages = [ }, {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, ] -print(pipe(messages, max_new_tokens=128, do_sample=False)[0]['generated_text'][-1]) # Print the assistant's response +print(pipe(messages, max_new_tokens=128)[0]['generated_text'][-1]) # Print the assistant's response ``` ```text From bbd8cfc899935062b244bfdce5c357d8d1cc0068 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Feb 2024 15:03:53 +0000 Subject: [PATCH 12/14] Remove todo - the discrepancy has been resolved --- src/transformers/pipelines/text_generation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 997ea93bb3e34d..5b4ee8830d7383 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -81,7 +81,6 @@ class TextGenerationPipeline(Pipeline): """ def __init__(self, *args, **kwargs): - # TODO Matt: Find out why ConversationalPipeline and TextGeneration give slighly different (but valid) results super().__init__(*args, **kwargs) self.check_model_type( TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES From ecea9b511f53fcf134c4bf3eca6fe8ba502f105f Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 13:48:53 +0000 Subject: [PATCH 13/14] Update src/transformers/tokenization_utils_base.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 384145d84a580e..118c6e11077285 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1755,17 +1755,17 @@ def apply_chat_template( padding=padding, truncation=truncation, max_length=max_length, - return_tensors=return_tensors, add_special_tokens=False, + return_tensors=return_tensors, **tokenizer_kwargs, ) else: return self.encode( rendered, - add_special_tokens=False, padding=padding, truncation=truncation, max_length=max_length, + add_special_tokens=False, return_tensors=return_tensors, **tokenizer_kwargs, ) From f985755486269c57774a55200910979a0d931cf5 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 16 Feb 2024 13:49:01 +0000 Subject: [PATCH 14/14] Update src/transformers/pipelines/text_generation.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/pipelines/text_generation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/pipelines/text_generation.py b/src/transformers/pipelines/text_generation.py index 5b4ee8830d7383..8e454655c788cf 100644 --- a/src/transformers/pipelines/text_generation.py +++ b/src/transformers/pipelines/text_generation.py @@ -254,21 +254,21 @@ def preprocess( if isinstance(prompt_text, Chat): inputs = self.tokenizer.apply_chat_template( prompt_text.messages, + truncation=truncation, padding=padding, - add_generation_prompt=True, - return_tensors=self.framework, max_length=max_length, - truncation=truncation, + add_generation_prompt=True, return_dict=True, + return_tensors=self.framework, ) else: inputs = self.tokenizer( prefix + prompt_text, - return_tensors=self.framework, truncation=truncation, padding=padding, max_length=max_length, add_special_tokens=add_special_tokens, + return_tensors=self.framework, ) inputs["prompt_text"] = prompt_text