From ab4176b331de137d7561cd09e5f8b8119e9bc42e Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Aug 2024 14:58:14 +0100 Subject: [PATCH 01/14] Add new Jinja features: - Do extension - Break/continue in loops - Call strftime to get current datetime in any format --- src/transformers/tokenization_utils_base.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index dc24662e51da19..eddbe691201638 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -27,6 +27,7 @@ from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass +from datetime import datetime from functools import lru_cache from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -1929,6 +1930,9 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False) # We also expose some options like custom indents and separators return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + def strftime(format): + return datetime.now().strftime(format) + class AssistantTracker(Extension): # This extension is used to track the indices of assistant-generated tokens in the rendered chat tags = {"generation"} @@ -1971,9 +1975,12 @@ def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[ self._rendered_blocks = None self._generation_indices = None - jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker]) + jinja_env = ImmutableSandboxedEnvironment( + trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.do, jinja2.ext.loopcontrols] + ) jinja_env.filters["tojson"] = tojson jinja_env.globals["raise_exception"] = raise_exception + jinja_env.globals["strftime"] = strftime return jinja_env.from_string(chat_template) def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: From 8d64c3ae7dd0f7c675bbce7ccf796c9210d38798 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Aug 2024 15:14:29 +0100 Subject: [PATCH 02/14] Add new Jinja features: - Do extension - Break/continue in loops - Call strftime to get current datetime in any format --- tests/test_tokenization_common.py | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index d8ff702cbe1209..052328bfe6c954 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1153,6 +1153,39 @@ def test_chat_template_batched(self): dummy_conversations, chat_template=dummy_template, tokenize=True ) # Check that no error raised + @require_jinja + def test_jinja_extensions(self): + break_template = """ + {%- for message in messages %} + {{- message.role + " " + message.content }} + {%- if loop.first %} + {%- break %} + {%- endif %} + {%- endfor %}""" + + dummy_conversation = [ + {"role": "system", "content": "1"}, + {"role": "user", "content": "2"}, + {"role": "assistant", "content": "3"}, + ] + + strftime_template = """{{ strftime("%Y-%m-%d") }} """ + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): + break_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=break_template, tokenize=False + ) + self.assertEqual(break_output, "system 1") # Loop should break after first iter + + strftime_output = tokenizer.apply_chat_template( + dummy_conversation, chat_template=strftime_template, tokenize=False + ) + # Assert that we get a date formatted as expected + self.assertEqual(len(strftime_output), 10) + self.assertEqual(len(strftime_output.split("-")), 3) + @require_jinja def test_chat_template_return_assistant_tokens_mask(self): dummy_template = ( From a3d3fbadab497db2803c674fca6848eeca57f5fe Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Aug 2024 15:18:59 +0100 Subject: [PATCH 03/14] Fix strftime template --- tests/test_tokenization_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 052328bfe6c954..6d60250f751153 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1169,7 +1169,7 @@ def test_jinja_extensions(self): {"role": "assistant", "content": "3"}, ] - strftime_template = """{{ strftime("%Y-%m-%d") }} """ + strftime_template = """{{- strftime("%Y-%m-%d") }}""" tokenizers = self.get_tokenizers() for tokenizer in tokenizers: From b9d9fb4aa8baaf0c868eac93b962db03f0e06942 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Aug 2024 15:19:31 +0100 Subject: [PATCH 04/14] Add template strip() just to be safe --- tests/test_tokenization_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 6d60250f751153..3dc9d118789642 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1161,7 +1161,7 @@ def test_jinja_extensions(self): {%- if loop.first %} {%- break %} {%- endif %} - {%- endfor %}""" + {%- endfor %}""".strip() dummy_conversation = [ {"role": "system", "content": "1"}, @@ -1169,7 +1169,7 @@ def test_jinja_extensions(self): {"role": "assistant", "content": "3"}, ] - strftime_template = """{{- strftime("%Y-%m-%d") }}""" + strftime_template = """{{- strftime("%Y-%m-%d") }}""".strip() tokenizers = self.get_tokenizers() for tokenizer in tokenizers: From df2fa1e978db2d1a645a98ae18287c6cd01ebadb Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 14 Aug 2024 15:29:38 +0100 Subject: [PATCH 05/14] Remove the do extension to make porting easier, and also because it's the least useful --- src/transformers/tokenization_utils_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index eddbe691201638..15f05039513ef1 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1976,7 +1976,7 @@ def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[ self._generation_indices = None jinja_env = ImmutableSandboxedEnvironment( - trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.do, jinja2.ext.loopcontrols] + trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] ) jinja_env.filters["tojson"] = tojson jinja_env.globals["raise_exception"] = raise_exception From adf0e55fd23efa8fd3ebed7cb316dec47e0a9fb4 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Aug 2024 13:29:53 +0100 Subject: [PATCH 06/14] Rename test --- tests/test_tokenization_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 3dc9d118789642..f8b4482048d74e 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1154,7 +1154,7 @@ def test_chat_template_batched(self): ) # Check that no error raised @require_jinja - def test_jinja_extensions(self): + def test_jinja_extensions_are_enabled(self): break_template = """ {%- for message in messages %} {{- message.role + " " + message.content }} From 9e1c9b1391be7fc232b2684b3ddaa1392460cee8 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Aug 2024 13:30:27 +0100 Subject: [PATCH 07/14] strftime -> strftime_now --- src/transformers/tokenization_utils_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 15f05039513ef1..e2ef9345cbd661 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1930,7 +1930,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False) # We also expose some options like custom indents and separators return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) - def strftime(format): + def strftime_now(format): return datetime.now().strftime(format) class AssistantTracker(Extension): @@ -1980,7 +1980,7 @@ def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[ ) jinja_env.filters["tojson"] = tojson jinja_env.globals["raise_exception"] = raise_exception - jinja_env.globals["strftime"] = strftime + jinja_env.globals["strftime_now"] = strftime_now return jinja_env.from_string(chat_template) def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: From 66119c01aef3a07b13bb6cc30211319584a09db7 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Aug 2024 13:34:01 +0100 Subject: [PATCH 08/14] Split test --- tests/test_tokenization_common.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index f8b4482048d74e..bb4be88e031586 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1154,7 +1154,7 @@ def test_chat_template_batched(self): ) # Check that no error raised @require_jinja - def test_jinja_extensions_are_enabled(self): + def test_jinja_loopcontrols(self): break_template = """ {%- for message in messages %} {{- message.role + " " + message.content }} @@ -1169,8 +1169,6 @@ def test_jinja_extensions_are_enabled(self): {"role": "assistant", "content": "3"}, ] - strftime_template = """{{- strftime("%Y-%m-%d") }}""".strip() - tokenizers = self.get_tokenizers() for tokenizer in tokenizers: with self.subTest(f"{tokenizer.__class__.__name__}"): @@ -1179,9 +1177,23 @@ def test_jinja_extensions_are_enabled(self): ) self.assertEqual(break_output, "system 1") # Loop should break after first iter + @require_jinja + def test_jinja_strftime(self): + strftime_template = """{{- strftime("%Y-%m-%d") }}""".strip() + + dummy_conversation = [ + {"role": "system", "content": "1"}, + {"role": "user", "content": "2"}, + {"role": "assistant", "content": "3"}, + ] + + tokenizers = self.get_tokenizers() + for tokenizer in tokenizers: + with self.subTest(f"{tokenizer.__class__.__name__}"): strftime_output = tokenizer.apply_chat_template( dummy_conversation, chat_template=strftime_template, tokenize=False ) + # Assert that we get a date formatted as expected self.assertEqual(len(strftime_output), 10) self.assertEqual(len(strftime_output.split("-")), 3) From 97505afb1928d3b3d13dab391075108e5cee2bb1 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 15 Aug 2024 13:43:55 +0100 Subject: [PATCH 09/14] Update test to use strftime_now --- tests/test_tokenization_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index bb4be88e031586..f1bcfe3929be47 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1179,7 +1179,7 @@ def test_jinja_loopcontrols(self): @require_jinja def test_jinja_strftime(self): - strftime_template = """{{- strftime("%Y-%m-%d") }}""".strip() + strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip() dummy_conversation = [ {"role": "system", "content": "1"}, From 56ff91cfa3bddd3061546ecfd4e8a3783f28bab5 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 21 Aug 2024 16:57:37 +0100 Subject: [PATCH 10/14] Refactor everything out into chat_template_utils --- src/transformers/tokenization_utils_base.py | 101 +----------------- src/transformers/utils/chat_template_utils.py | 100 ++++++++++++++++- 2 files changed, 102 insertions(+), 99 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index e2ef9345cbd661..ca61fcf6764215 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -27,8 +27,6 @@ from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass -from datetime import datetime -from functools import lru_cache from inspect import isfunction from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -66,6 +64,7 @@ requires_backends, to_py_obj, ) +from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices if TYPE_CHECKING: @@ -1792,7 +1791,7 @@ def apply_chat_template( ) # Compilation function uses a cache to avoid recompiling the same template - compiled_template = self._compile_jinja_template(chat_template) + compiled_template = _compile_jinja_template(chat_template) if isinstance(conversation, (list, tuple)) and ( isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages") @@ -1832,7 +1831,7 @@ def apply_chat_template( # Indicates it's a Conversation object chat = chat.messages if return_assistant_tokens_mask: - rendered_chat, generation_indices = self._render_with_assistant_indices( + rendered_chat, generation_indices = _render_with_assistant_indices( compiled_template=compiled_template, messages=chat, tools=tool_schemas, @@ -1889,100 +1888,6 @@ def apply_chat_template( else: return rendered - def _render_with_assistant_indices( - self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs - ): - rendered_blocks = [] - generation_indices = [] - with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices): - for block in compiled_template.generate( - messages=messages, - tools=tools, - documents=documents, - add_generation_prompt=add_generation_prompt, - **template_kwargs, - ): - rendered_blocks.append(block) - rendered_chat = "".join(rendered_blocks) - return rendered_chat, generation_indices - - @lru_cache - def _compile_jinja_template(self, chat_template): - try: - import jinja2 - from jinja2 import nodes - from jinja2.exceptions import TemplateError - from jinja2.ext import Extension - from jinja2.sandbox import ImmutableSandboxedEnvironment - except ImportError: - raise ImportError("apply_chat_template requires jinja2 to be installed.") - - if version.parse(jinja2.__version__) < version.parse("3.1.0"): - raise ImportError( - "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." - ) - - def raise_exception(message): - raise TemplateError(message) - - def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): - # We override the built-in tojson filter because Jinja's default filter escapes HTML characters - # We also expose some options like custom indents and separators - return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) - - def strftime_now(format): - return datetime.now().strftime(format) - - class AssistantTracker(Extension): - # This extension is used to track the indices of assistant-generated tokens in the rendered chat - tags = {"generation"} - - def __init__(self, environment: ImmutableSandboxedEnvironment): - # The class is only initiated by jinja. - super().__init__(environment) - environment.extend(activate_tracker=self.activate_tracker) - self._rendered_blocks = None - self._generation_indices = None - - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: - lineno = next(parser.stream).lineno - body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) - - @jinja2.pass_eval_context - def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: - rv = caller() - if self.is_active(): - # Only track generation indices if the tracker is active - start_index = len("".join(self._rendered_blocks)) - end_index = start_index + len(rv) - self._generation_indices.append((start_index, end_index)) - return rv - - def is_active(self) -> bool: - return self._rendered_blocks or self._generation_indices - - @contextmanager - def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): - try: - if self.is_active(): - raise ValueError("AssistantTracker should not be reused before closed") - self._rendered_blocks = rendered_blocks - self._generation_indices = generation_indices - - yield - finally: - self._rendered_blocks = None - self._generation_indices = None - - jinja_env = ImmutableSandboxedEnvironment( - trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] - ) - jinja_env.filters["tojson"] = tojson - jinja_env.globals["raise_exception"] = raise_exception - jinja_env.globals["strftime_now"] = strftime_now - return jinja_env.from_string(chat_template) - def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str: """ Retrieve the chat template string used for tokenizing chat messages. This template is used diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 078a307b1c33d6..603bb151264605 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -15,7 +15,17 @@ import inspect import json import re -from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints +from contextlib import contextmanager +from datetime import datetime +from functools import lru_cache +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints + +import jinja2 +from jinja2 import nodes +from jinja2.exceptions import TemplateError +from jinja2.ext import Extension +from jinja2.sandbox import ImmutableSandboxedEnvironment +from packaging import version BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) @@ -314,3 +324,91 @@ def get_json_schema(func: Callable) -> Dict: if return_dict is not None: output["return"] = return_dict return {"type": "function", "function": output} + + +class AssistantTracker(Extension): + # This extension is used to track the indices of assistant-generated tokens in the rendered chat + tags = {"generation"} + + def __init__(self, environment: ImmutableSandboxedEnvironment): + # The class is only initiated by jinja. + super().__init__(environment) + environment.extend(activate_tracker=self.activate_tracker) + self._rendered_blocks = None + self._generation_indices = None + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) + + @jinja2.pass_eval_context + def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: + rv = caller() + if self.is_active(): + # Only track generation indices if the tracker is active + start_index = len("".join(self._rendered_blocks)) + end_index = start_index + len(rv) + self._generation_indices.append((start_index, end_index)) + return rv + + def is_active(self) -> bool: + return self._rendered_blocks or self._generation_indices + + @contextmanager + def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): + try: + if self.is_active(): + raise ValueError("AssistantTracker should not be reused before closed") + self._rendered_blocks = rendered_blocks + self._generation_indices = generation_indices + + yield + finally: + self._rendered_blocks = None + self._generation_indices = None + + +def _render_with_assistant_indices( + compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs +): + rendered_blocks = [] + generation_indices = [] + with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices): + for block in compiled_template.generate( + messages=messages, + tools=tools, + documents=documents, + add_generation_prompt=add_generation_prompt, + **template_kwargs, + ): + rendered_blocks.append(block) + rendered_chat = "".join(rendered_blocks) + return rendered_chat, generation_indices + + +@lru_cache +def _compile_jinja_template(chat_template): + if version.parse(jinja2.__version__) < version.parse("3.1.0"): + raise ImportError( + "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." + ) + + def raise_exception(message): + raise TemplateError(message) + + def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): + # We override the built-in tojson filter because Jinja's default filter escapes HTML characters + # We also expose some options like custom indents and separators + return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys) + + def strftime_now(format): + return datetime.now().strftime(format) + + jinja_env = ImmutableSandboxedEnvironment( + trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] + ) + jinja_env.filters["tojson"] = tojson + jinja_env.globals["raise_exception"] = raise_exception + jinja_env.globals["strftime_now"] = strftime_now + return jinja_env.from_string(chat_template) From 34f0cbe9adcf2fbabd9a171df210363d1163770c Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 21 Aug 2024 17:09:48 +0100 Subject: [PATCH 11/14] Refactor everything out into chat_template_utils --- src/transformers/utils/chat_template_utils.py | 106 +++++++++--------- 1 file changed, 56 insertions(+), 50 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 603bb151264605..463dd1ff214769 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -20,13 +20,20 @@ from functools import lru_cache from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints -import jinja2 -from jinja2 import nodes -from jinja2.exceptions import TemplateError -from jinja2.ext import Extension -from jinja2.sandbox import ImmutableSandboxedEnvironment from packaging import version +from .import_utils import is_jinja_available + + +if is_jinja_available(): + import jinja2 + # from jinja2 import nodes + # from jinja2.exceptions import TemplateError + # from jinja2.ext import Extension + # from jinja2.sandbox import ImmutableSandboxedEnvironment +else: + jinja2 = None + BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) # Extracts the initial segment of the docstring, containing the function description @@ -326,49 +333,6 @@ def get_json_schema(func: Callable) -> Dict: return {"type": "function", "function": output} -class AssistantTracker(Extension): - # This extension is used to track the indices of assistant-generated tokens in the rendered chat - tags = {"generation"} - - def __init__(self, environment: ImmutableSandboxedEnvironment): - # The class is only initiated by jinja. - super().__init__(environment) - environment.extend(activate_tracker=self.activate_tracker) - self._rendered_blocks = None - self._generation_indices = None - - def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: - lineno = next(parser.stream).lineno - body = parser.parse_statements(["name:endgeneration"], drop_needle=True) - return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) - - @jinja2.pass_eval_context - def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: - rv = caller() - if self.is_active(): - # Only track generation indices if the tracker is active - start_index = len("".join(self._rendered_blocks)) - end_index = start_index + len(rv) - self._generation_indices.append((start_index, end_index)) - return rv - - def is_active(self) -> bool: - return self._rendered_blocks or self._generation_indices - - @contextmanager - def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): - try: - if self.is_active(): - raise ValueError("AssistantTracker should not be reused before closed") - self._rendered_blocks = rendered_blocks - self._generation_indices = generation_indices - - yield - finally: - self._rendered_blocks = None - self._generation_indices = None - - def _render_with_assistant_indices( compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs ): @@ -389,13 +353,55 @@ def _render_with_assistant_indices( @lru_cache def _compile_jinja_template(chat_template): + class AssistantTracker(jinja2.ext.Extension): + # This extension is used to track the indices of assistant-generated tokens in the rendered chat + tags = {"generation"} + + def __init__(self, environment: jinja2.sandbox.ImmutableSandboxedEnvironment): + # The class is only initiated by jinja. + super().__init__(environment) + environment.extend(activate_tracker=self.activate_tracker) + self._rendered_blocks = None + self._generation_indices = None + + def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock: + lineno = next(parser.stream).lineno + body = parser.parse_statements(["name:endgeneration"], drop_needle=True) + return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno) + + @jinja2.pass_eval_context + def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str: + rv = caller() + if self.is_active(): + # Only track generation indices if the tracker is active + start_index = len("".join(self._rendered_blocks)) + end_index = start_index + len(rv) + self._generation_indices.append((start_index, end_index)) + return rv + + def is_active(self) -> bool: + return self._rendered_blocks or self._generation_indices + + @contextmanager + def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]): + try: + if self.is_active(): + raise ValueError("AssistantTracker should not be reused before closed") + self._rendered_blocks = rendered_blocks + self._generation_indices = generation_indices + + yield + finally: + self._rendered_blocks = None + self._generation_indices = None + if version.parse(jinja2.__version__) < version.parse("3.1.0"): raise ImportError( "apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}." ) def raise_exception(message): - raise TemplateError(message) + raise jinja2.exceptions.TemplateError(message) def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False): # We override the built-in tojson filter because Jinja's default filter escapes HTML characters @@ -405,7 +411,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False) def strftime_now(format): return datetime.now().strftime(format) - jinja_env = ImmutableSandboxedEnvironment( + jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] ) jinja_env.filters["tojson"] = tojson From d7edd37bfb828976bf1b1c150827fb36fa073dd4 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 21 Aug 2024 17:18:10 +0100 Subject: [PATCH 12/14] Refactor everything out into chat_template_utils --- src/transformers/utils/chat_template_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 463dd1ff214769..4f0f6b652c9c66 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -27,12 +27,10 @@ if is_jinja_available(): import jinja2 - # from jinja2 import nodes - # from jinja2.exceptions import TemplateError - # from jinja2.ext import Extension - # from jinja2.sandbox import ImmutableSandboxedEnvironment + from jinja2.ext import Extension else: jinja2 = None + Extension = None BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) From c4abca43a2c63d69fac9df33b24959ec6eeca3d7 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 21 Aug 2024 17:18:37 +0100 Subject: [PATCH 13/14] Refactor everything out into chat_template_utils --- src/transformers/utils/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 4f0f6b652c9c66..4ef9be5193ca5c 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -27,7 +27,7 @@ if is_jinja_available(): import jinja2 - from jinja2.ext import Extension + from jinja2.ext import Extension # jinja2.ext.Extension fails because the module is not correctly initialized else: jinja2 = None Extension = None From 9240137897096698f5292c7dd38d0651c8a33dc8 Mon Sep 17 00:00:00 2001 From: Matt Date: Wed, 21 Aug 2024 17:24:02 +0100 Subject: [PATCH 14/14] Refactor everything out into chat_template_utils --- src/transformers/utils/chat_template_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/utils/chat_template_utils.py b/src/transformers/utils/chat_template_utils.py index 4ef9be5193ca5c..aabaf4a3666506 100644 --- a/src/transformers/utils/chat_template_utils.py +++ b/src/transformers/utils/chat_template_utils.py @@ -27,10 +27,10 @@ if is_jinja_available(): import jinja2 - from jinja2.ext import Extension # jinja2.ext.Extension fails because the module is not correctly initialized + from jinja2.ext import Extension + from jinja2.sandbox import ImmutableSandboxedEnvironment else: jinja2 = None - Extension = None BASIC_TYPES = (int, float, str, bool, Any, type(None), ...) @@ -351,11 +351,11 @@ def _render_with_assistant_indices( @lru_cache def _compile_jinja_template(chat_template): - class AssistantTracker(jinja2.ext.Extension): + class AssistantTracker(Extension): # This extension is used to track the indices of assistant-generated tokens in the rendered chat tags = {"generation"} - def __init__(self, environment: jinja2.sandbox.ImmutableSandboxedEnvironment): + def __init__(self, environment: ImmutableSandboxedEnvironment): # The class is only initiated by jinja. super().__init__(environment) environment.extend(activate_tracker=self.activate_tracker) @@ -409,7 +409,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False) def strftime_now(format): return datetime.now().strftime(format) - jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment( + jinja_env = ImmutableSandboxedEnvironment( trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols] ) jinja_env.filters["tojson"] = tojson