Skip to content

Commit

Permalink
Make chat templates part of ProcessorMixin (#30744)
Browse files Browse the repository at this point in the history
* Let's try moving chat templates out of IDEFICS and into the generic ProcessorMixin

* Chat templates should not be mandatory

* Chat templates should not be mandatory

* Not all classes will have default chat templates

* stash commit

* Add chat template docstring

* Clean up docstring

* Add chat templates to LLaVA/LLaVA-next

* Docstring fixup

* Quick IDEFICS2 fixup

* Remove some old references to the Conversation class

* make fixup
  • Loading branch information
Rocketknight1 authored and itazap committed Jun 18, 2024
1 parent 4c3d517 commit feb348d
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 54 deletions.
54 changes: 5 additions & 49 deletions src/transformers/models/idefics2/processing_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Processor class for IDEFICS2.
"""

from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Union

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput, is_valid_image, load_image
Expand Down Expand Up @@ -56,13 +56,15 @@ class Idefics2Processor(ProcessorMixin):
The length of the image sequence i.e. the number of <image> tokens per image in the input.
This parameter is used to build the string from the input prompt and image tokens and should match the
config.perceiver_config.resampler_n_latents value for the model used.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "Idefics2ImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **kwargs):
def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, chat_template: str = None, **kwargs):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
Expand All @@ -78,10 +80,7 @@ def __init__(self, image_processor, tokenizer=None, image_seq_len: int = 64, **k
}
tokenizer.add_special_tokens(tokens_to_add)

# Stores a Jinja template that formats chat histories into tokenizable strings
self.chat_template = kwargs.pop("chat_template", None)

super().__init__(image_processor, tokenizer)
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def _extract_images_from_prompts(self, prompts):
prompt_images = []
Expand Down Expand Up @@ -252,49 +251,6 @@ def model_input_names(self):
image_processor_input_names = self.image_processor.model_input_names
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
chat_template: Optional[str] = None,
tokenize: bool = False,
**kwargs,
) -> str:
"""
Overrides the tokenizer's `apply_chat_template` method to apply the IDEFICS2 chat template by default
if no chat template is provided.
By default, the output isn't tokenized. This is because the IDEFICS2 chat template is designed to insert
the image token <image> into the sequence according to the message, but does not handle expanding the image
tokens to the sequence length or adding the surrounding tokens e.g. <fake_image_token>.
Args:
conversation (`Union[List[Dict, str, str]]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments for the tokenizer's `apply_chat_template` method.
"""

if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
else:
logger.warning_once(
"No chat template is set for this processor, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
chat_template = self.default_chat_template
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)

@property
def default_chat_template(self):
"""
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/llava/processing_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ class LlavaProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/llava_next/processing_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ class LlavaNextProcessor(ProcessorMixin):
The image processor is a required input.
tokenizer ([`LlamaTokenizerFast`], *optional*):
The tokenizer is a required input.
chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
in a chat into a tokenizable string.
"""

attributes = ["image_processor", "tokenizer"]
image_processor_class = "AutoImageProcessor"
tokenizer_class = "AutoTokenizer"

def __init__(self, image_processor=None, tokenizer=None):
super().__init__(image_processor, tokenizer)
def __init__(self, image_processor=None, tokenizer=None, chat_template=None):
super().__init__(image_processor, tokenizer, chat_template=chat_template)

def __call__(
self,
Expand Down
52 changes: 51 additions & 1 deletion src/transformers/processing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import warnings
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from .dynamic_module_utils import custom_object_save
from .tokenization_utils_base import PreTrainedTokenizerBase
Expand Down Expand Up @@ -60,13 +60,18 @@ class ProcessorMixin(PushToHubMixin):
"""

attributes = ["feature_extractor", "tokenizer"]
optional_attributes = ["chat_template"]
# Names need to be attr_class for attr in attributes
feature_extractor_class = None
tokenizer_class = None
_auto_class = None

# args have to match the attributes class attribute
def __init__(self, *args, **kwargs):
# First, extract optional attributes from kwargs if present
# Optional attributes can never be positional arguments
for optional_attribute in self.optional_attributes:
setattr(self, optional_attribute, kwargs.pop(optional_attribute, None))
# Sanitize args and kwargs
for key in kwargs:
if key not in self.attributes:
Expand Down Expand Up @@ -522,6 +527,51 @@ def model_input_names(self):
first_attribute = getattr(self, self.attributes[0])
return getattr(first_attribute, "model_input_names", None)

def apply_chat_template(
self,
conversation: Union[List[Dict[str, str]]],
chat_template: Optional[str] = None,
tokenize: bool = False,
**kwargs,
) -> str:
"""
Similar to the `apply_chat_template` method on tokenizers, this method applies a Jinja template to input
conversations to turn them into a single tokenizable string.
Args:
conversation (`List[Dict, str, str]`):
The conversation to format.
chat_template (`Optional[str]`, *optional*):
The Jinja template to use for formatting the conversation. If not provided, the default chat template
is used.
tokenize (`bool`, *optional*, defaults to `False`):
Whether to tokenize the output or not.
**kwargs:
Additional keyword arguments
"""

if chat_template is None:
if self.chat_template is not None:
chat_template = self.chat_template
elif getattr(self, "default_chat_template", None) is not None:
logger.warning_once(
"No chat template is set for this processor, falling back to a default class-level template. This is "
"very error-prone, because models are often trained with templates different from the class default! "
"Default chat templates are a legacy feature and will be removed in Transformers v4.43, at which "
"point any code depending on them will stop working. We recommend setting a valid chat template before "
"then to ensure that this model continues working without issues."
)
chat_template = self.default_chat_template
else:
raise ValueError(
"No chat template is set for this processor. Please either set the `chat_template` attribute, "
"or provide a chat template as an argument. See "
"https://huggingface.co/docs/transformers/main/en/chat_templating for more information."
)
return self.tokenizer.apply_chat_template(
conversation, chat_template=chat_template, tokenize=tokenize, **kwargs
)


ProcessorMixin.push_to_hub = copy_func(ProcessorMixin.push_to_hub)
if ProcessorMixin.push_to_hub.__doc__ is not None:
Expand Down

0 comments on commit feb348d

Please sign in to comment.