diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index dc24662e51da19..ca61fcf6764215 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -27,7 +27,6 @@ from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass -from functools import lru_cache from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -65,6 +64,7 @@ requires_backends, to_py_obj, ) +from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices if TYPE_CHECKING: @@ -1791,7 +1791,7 @@ def apply_chat_template( ) # Compilation function uses a cache to avoid recompiling the same template - compiled_template = self._compile_jinja_template(chat_template) + compiled_template = _compile_jinja_template(chat_template) if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") @@ -1831,7 +1831,7 @@ def apply_chat_template( # Indicates it's a Conversation object chat = chat.messages if return_assistant_tokens_mask: - rendered_chat, generation_indices = self._render_with_assistant_indices( + rendered_chat, generation_indices = _render_with_assistant_indices( compiled_template=compiled_template, messages=chat, tools=tool_schemas, @@ -1888,94 +1888,6 @@ def apply_chat_template( else: return rendered - def _render_with_assistant_indices( - self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs - ): - rendered_blocks = [] - generation_indices = [] - with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices): - for block in compiled_template.generate( - messages=messages, - tools=tools, - documents=documents, - add_generation_prompt=add_generation_prompt, - **template_kwargs, - ): - rendered_blocks.append(block) - rendered_chat = "".join(rendered_blocks) - return rendered_chat, generation_indices - - @lru_cache - def _compile_jinja_template(self, chat_template): - try: - import jinja2 - from jinja2 import nodes - from jinja2.exceptions import TemplateError - from jinja2.ext import Extension - from jinja2.sandbox import ImmutableSandboxedEnvironment - except ImportError: - raise ImportError("apply_chat_template requires jinja2 to be installed.") - - if version.parse(jinja2.__version__) < version.parse("3.1.0"): - raise ImportError( - "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." - ) - - def raise_exception(message): - raise TemplateError(message) - - def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): - # We override the built-in tojson filter because Jinja's default filter escapes HTML characters - # We also expose some options like custom indents and separators - return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) - - class AssistantTracker(Extension): - # This extension is used to track the indices of assistant-generated tokens in the rendered chat - tags = {"generation"} - - def __init__(self, environment: ImmutableSandboxedEnvironment): - # The class is only initiated by jinja. - super().__init__(environment) - environment.extend(activate_tracker=self.activate_tracker) - self._rendered_blocks = None - self._generation_indices = None - - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: - lineno = next(parser.stream).lineno - body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) - - @jinja2.pass_eval_context - def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: - rv = caller() - if self.is_active(): - # Only track generation indices if the tracker is active - start_index = len("".join(self._rendered_blocks)) - end_index = start_index + len(rv) - self._generation_indices.append((start_index, end_index)) - return rv - - def is_active(self) -> bool: - return self._rendered_blocks or self._generation_indices - - @contextmanager - def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): - try: - if self.is_active(): - raise ValueError("AssistantTracker should not be reused before closed") - self._rendered_blocks = rendered_blocks - self._generation_indices = generation_indices - - yield - finally: - self._rendered_blocks = None - self._generation_indices = None - - jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker]) - jinja_env.filters["tojson"] = tojson - jinja_env.globals["raise_exception"] = raise_exception - return jinja_env.from_string(chat_template) - def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: """ Retrieve the chat template string used for tokenizing chat messages. This template is used diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 078a307b1c33d6..aabaf4a3666506 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -15,7 +15,22 @@ import inspect import json import re -from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints +from contextlib import contextmanager +from datetime import datetime +from functools import lru_cache +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints + +from packaging import version + +from .import_utils import is_jinja_available + + +if is_jinja_available(): + import jinja2 + from jinja2.ext import Extension + from jinja2.sandbox import ImmutableSandboxedEnvironment +else: + jinja2 = None BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) @@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict: if return_dict is not None: output["return"] = return_dict return {"type": "function", "function": output} + + +def _render_with_assistant_indices( + compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs +): + rendered_blocks = [] + generation_indices = [] + with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices): + for block in compiled_template.generate( + messages=messages, + tools=tools, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ): + rendered_blocks.append(block) + rendered_chat = "".join(rendered_blocks) + return rendered_chat, generation_indices + + +@lru_cache +def _compile_jinja_template(chat_template): + class AssistantTracker(Extension): + # This extension is used to track the indices of assistant-generated tokens in the rendered chat + tags = {"generation"} + + def __init__(self, environment: ImmutableSandboxedEnvironment): + # The class is only initiated by jinja. + super().__init__(environment) + environment.extend(activate_tracker=self.activate_tracker) + self._rendered_blocks = None + self._generation_indices = None + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) + + @jinja2.pass_eval_context + def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: + rv = caller() + if self.is_active(): + # Only track generation indices if the tracker is active + start_index = len("".join(self._rendered_blocks)) + end_index = start_index + len(rv) + self._generation_indices.append((start_index, end_index)) + return rv + + def is_active(self) -> bool: + return self._rendered_blocks or self._generation_indices + + @contextmanager + def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): + try: + if self.is_active(): + raise ValueError("AssistantTracker should not be reused before closed") + self._rendered_blocks = rendered_blocks + self._generation_indices = generation_indices + + yield + finally: + self._rendered_blocks = None + self._generation_indices = None + + if version.parse(jinja2.__version__) < version.parse("3.1.0"): + raise ImportError( + "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." + ) + + def raise_exception(message): + raise jinja2.exceptions.TemplateError(message) + + def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): + # We override the built-in tojson filter because Jinja's default filter escapes HTML characters + # We also expose some options like custom indents and separators + return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + + def strftime_now(format): + return datetime.now().strftime(format) + + jinja_env = ImmutableSandboxedEnvironment( + trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] + ) + jinja_env.filters["tojson"] = tojson + jinja_env.globals["raise_exception"] = raise_exception + jinja_env.globals["strftime_now"] = strftime_now + return jinja_env.from_string(chat_template) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index d8ff702cbe1209..f1bcfe3929be47 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1153,6 +1153,51 @@ def test_chat_template_batched(self): dummy_conversations, chat_template=dummy_template, tokenize=True ) # Check that no error raised + @require_jinja + def test_jinja_loopcontrols(self): + break_template = """ + {%- for message in messages %} + {{- message.role + " " + message.content }} + {%- if loop.first %} + {%- break %} + {%- endif %} + {%- endfor %}""".strip() + + dummy_conversation = [ + {"role": "system", "content": "1"}, + {"role": "user", "content": "2"}, + {"role": "assistant", "content": "3"}, + ] + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + break_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=break_template, tokenize=False + ) + self.assertEqual(break_output, "system 1") # Loop should break after first iter + + @require_jinja + def test_jinja_strftime(self): + strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip() + + dummy_conversation = [ + {"role": "system", "content": "1"}, + {"role": "user", "content": "2"}, + {"role": "assistant", "content": "3"}, + ] + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + strftime_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=strftime_template, tokenize=False + ) + + # Assert that we get a date formatted as expected + self.assertEqual(len(strftime_output), 10) + self.assertEqual(len(strftime_output.split("-")), 3) + @require_jinja def test_chat_template_return_assistant_tokens_mask(self): dummy_template = (