diff --git a/src/transformers/models/idefics2/processing_idefics2.py b/src/transformers/models/idefics2/processing_idefics2.py
index e9f9f923373623..4edb1813b8e0d2 100644
--- a/src/transformers/models/idefics2/processing_idefics2.py
+++ b/src/transformers/models/idefics2/processing_idefics2.py
@@ -16,7 +16,7 @@
 Processor class for IDEFICS2.
 """
 
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 from ...feature_extraction_utils import BatchFeature
 from ...image_utils import ImageInput, is_valid_image, load_image
@@ -56,13 +56,15 @@ class Idefics2Processor(ProcessorMixin):
             The length of the image sequence i.e. the number of tokens per image in the input.
             This parameter is used to build the string from the input prompt and image tokens and should match the
             config.perceiver_config.resampler_n_latents value for the model used.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "Idefics2ImageProcessor"
    tokenizer_class = "AutoTokenizer"

-    def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs):
+    def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: str = None, **kwargs):
        if image_processor is None:
            raise ValueError("You need to specify an `image_processor`.")
        if tokenizer is None:
@@ -78,10 +80,7 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **k
        }
        tokenizer.add_special_tokens(tokens_to_add)

-        # Stores a Jinja template that formats chat histories into tokenizable strings
-        self.chat_template = kwargs.pop("chat_template", None)
-
-        super().__init__(image_processor, tokenizer)
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def _extract_images_from_prompts(self, prompts):
        prompt_images = []
@@ -252,49 +251,6 @@ def model_input_names(self):
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

-    def apply_chat_template(
-        self,
-        conversation: Union[List[Dict[str, str]]],
-        chat_template: Optional[str] = None,
-        tokenize: bool = False,
-        **kwargs,
-    ) -> str:
-        """
-        Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default
-        if no chat template is provided.
-
-        By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert
-        the image token <image> into the sequence according to the message, but does not handle expanding the image
-        tokens to the sequence length or adding the surrounding tokens e.g. <fake_token_around_image>.
-
-        Args:
-            conversation (`Union[List[Dict, str, str]]`):
-                The conversation to format.
-            chat_template (`Optional[str]`, *optional*):
-                The Jinja template to use for formatting the conversation. If not provided, the default chat template
-                is used.
-            tokenize (`bool`, *optional*, defaults to `False`):
-                Whether to tokenize the output or not.
-            **kwargs:
-                Additional keyword arguments for the tokenizer's `apply_chat_template` method.
-        """
-
-        if chat_template is None:
-            if self.chat_template is not None:
-                chat_template = self.chat_template
-            else:
-                logger.warning_once(
-                    "No chat template is set for this processor, falling back to a default class-level template. This is "
-                    "very error-prone, because models are often trained with templates different from the class default! "
-                    "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
-                    "point any code depending on them will stop working. We recommend setting a valid chat template before "
-                    "then to ensure that this model continues working without issues."
-                )
-                chat_template = self.default_chat_template
-        return self.tokenizer.apply_chat_template(
-            conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
-        )
-
    @property
    def default_chat_template(self):
        """
diff --git a/src/transformers/models/llava/processing_llava.py b/src/transformers/models/llava/processing_llava.py
index 7016cd50096977..96d38c53c947af 100644
--- a/src/transformers/models/llava/processing_llava.py
+++ b/src/transformers/models/llava/processing_llava.py
@@ -37,14 +37,16 @@ class LlavaProcessor(ProcessorMixin):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

-    def __init__(self, image_processor=None, tokenizer=None):
-        super().__init__(image_processor, tokenizer)
+    def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
diff --git a/src/transformers/models/llava_next/processing_llava_next.py b/src/transformers/models/llava_next/processing_llava_next.py
index 91cd544ab6484e..6c2ca2f9028409 100644
--- a/src/transformers/models/llava_next/processing_llava_next.py
+++ b/src/transformers/models/llava_next/processing_llava_next.py
@@ -37,14 +37,16 @@ class LlavaNextProcessor(ProcessorMixin):
            The image processor is a required input.
        tokenizer ([`LlamaTokenizerFast`], *optional*):
            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
""" attributes = ["image_processor", "tokenizer"] image_processor_class = "AutoImageProcessor" tokenizer_class = "AutoTokenizer" - def __init__(self, image_processor=None, tokenizer=None): - super().__init__(image_processor, tokenizer) + def __init__(self, image_processor=None, tokenizer=None, chat_template=None): + super().__init__(image_processor, tokenizer, chat_template=chat_template) def __call__( self, diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py index d76fa4dccccfee..a21d265b9d1bda 100644 --- a/src/transformers/processing_utils.py +++ b/src/transformers/processing_utils.py @@ -22,7 +22,7 @@ import os import warnings from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from .dynamic_module_utils import custom_object_save from .tokenization_utils_base import PreTrainedTokenizerBase @@ -60,6 +60,7 @@ class ProcessorMixin(PushToHubMixin): """ attributes = ["feature_extractor", "tokenizer"] + optional_attributes = ["chat_template"] # Names need to be attr_class for attr in attributes feature_extractor_class = None tokenizer_class = None @@ -67,6 +68,10 @@ class ProcessorMixin(PushToHubMixin): # args have to match the attributes class attribute def __init__(self, *args, **kwargs): + # First, extract optional attributes from kwargs if present + # Optional attributes can never be positional arguments + for optional_attribute in self.optional_attributes: + setattr(self, optional_attribute, kwargs.pop(optional_attribute, None)) # Sanitize args and kwargs for key in kwargs: if key not in self.attributes: @@ -522,6 +527,51 @@ def model_input_names(self): first_attribute = getattr(self, self.attributes[0]) return getattr(first_attribute, "model_input_names", None) + def apply_chat_template( + self, + conversation: Union[List[Dict[str, str]]], + chat_template: Optional[str] = None, + tokenize: bool = False, + **kwargs, + ) -> str: + """ + Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input + conversations to turn them into a single tokenizable string. + + Args: + conversation (`List[Dict, str, str]`): + The conversation to format. + chat_template (`Optional[str]`, *optional*): + The Jinja template to use for formatting the conversation. If not provided, the default chat template + is used. + tokenize (`bool`, *optional*, defaults to `False`): + Whether to tokenize the output or not. + **kwargs: + Additional keyword arguments + """ + + if chat_template is None: + if self.chat_template is not None: + chat_template = self.chat_template + elif getattr(self, "default_chat_template", None) is not None: + logger.warning_once( + "No chat template is set for this processor, falling back to a default class-level template. This is " + "very error-prone, because models are often trained with templates different from the class default! " + "Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which " + "point any code depending on them will stop working. We recommend setting a valid chat template before " + "then to ensure that this model continues working without issues." + ) + chat_template = self.default_chat_template + else: + raise ValueError( + "No chat template is set for this processor. Please either set the `chat_template` attribute, " + "or provide a chat template as an argument. See " + "https://huggingface.co/docs/transformers/main/en/chat_templating for more information." 
+ ) + return self.tokenizer.apply_chat_template( + conversation, chat_template=chat_template, tokenize=tokenize, **kwargs + ) + ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub) if ProcessorMixin.push_to_hub.__doc__ is not None:
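
Taken together, these changes move chat-template handling out of individual processors and into `ProcessorMixin`: the template is resolved from the explicit `chat_template` argument first, then the instance's `chat_template` attribute, then the legacy class-level `default_chat_template` (with a deprecation warning), and otherwise a `ValueError` is raised. The sketch below shows the intended call pattern; it is a minimal illustration, not part of the patch. The checkpoint name and the plain-string message contents are examples only — real multimodal templates (e.g. IDEFICS2's) may expect structured message contents.

from transformers import AutoProcessor

# Example checkpoint, for illustration only; any processor whose saved config
# or Hub repo provides a `chat_template` behaves the same way after this patch.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# A template can also be attached (or overridden) directly on the instance,
# since `chat_template` is now an optional attribute on ProcessorMixin.
processor.chat_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "What is shown in this image?"},
    {"role": "assistant", "content": "A cat sitting on a windowsill."},
]

# By default the result is an untokenized string; the call forwards to the
# tokenizer's `apply_chat_template` once a template has been resolved.
prompt = processor.apply_chat_template(messages)
print(prompt)

Because the method simply delegates to the tokenizer after template resolution, tokenizer-side kwargs such as `tokenize=True`, `add_generation_prompt`, or `return_tensors` pass through unchanged.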