Enable some Jinja extensions and add datetime capabilities #32684

Merged
merged 14 commits into from
Aug 23, 2024
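In short, this PR switches chat templating to a Jinja environment with the loopcontrols extension enabled and adds a `strftime_now` global, so templates can insert the current date and use `{% break %}` / `{% continue %}`. A rough sketch of what that allows (the template text and the "gpt2" checkpoint below are illustrative assumptions, not taken from the PR):

```python
# Illustrative sketch only: demo_template and the checkpoint name are made up;
# strftime_now and {% break %} are the features this PR enables.
from transformers import AutoTokenizer

demo_template = (
    "{{- 'Today is ' + strftime_now('%Y-%m-%d') + '\n' }}"
    "{%- for message in messages %}"
    "{{- message['role'] + ': ' + message['content'] + '\n' }}"
    "{%- if message['role'] == 'user' %}{%- break %}{%- endif %}"
    "{%- endfor %}"
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer works; the template is passed explicitly
chat = [
    {"role": "system", "content": "Be brief."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},  # never rendered: the loop breaks after the user turn
]
print(tokenizer.apply_chat_template(chat, chat_template=demo_template, tokenize=False))
```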
94 changes: 3 additions & 91 deletions src/transformers/tokenization_utils_base.py
@@ -27,7 +27,6 @@
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

@@ -65,6 +64,7 @@
requires_backends,
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices


if TYPE_CHECKING:
@@ -1791,7 +1791,7 @@ def apply_chat_template(
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
compiled_template = _compile_jinja_template(chat_template)

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
@@ -1831,7 +1831,7 @@ def apply_chat_template(
# Indicates it's a Conversation object
chat = chat.messages
if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices(
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
messages=chat,
tools=tool_schemas,
@@ -1888,94 +1888,6 @@ def apply_chat_template(
else:
return rendered

def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices

@lru_cache
def _compile_jinja_template(self, chat_template):
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")

if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

def raise_exception(message):
raise TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only instantiated by Jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)

def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
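The net effect of the changes to this file is a relocation: `_render_with_assistant_indices` and `_compile_jinja_template` move out of the tokenizer base class into `transformers.utils.chat_template_utils`, and `apply_chat_template` now calls the module-level helpers imported above. A minimal sketch of the relocated compile helper in isolation (these are private, underscore-prefixed utilities, so this is for illustration only and assumes jinja2>=3.1.0, which the helper itself enforces):

```python
# Sketch only: _compile_jinja_template is a private helper, not public API.
from transformers.utils.chat_template_utils import _compile_jinja_template

compiled = _compile_jinja_template("{{ messages[0]['content'] | upper }}")
print(compiled.render(messages=[{"role": "user", "content": "hello"}]))  # prints "HELLO"
```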
104 changes: 103 additions & 1 deletion src/transformers/utils/chat_template_utils.py
@@ -15,7 +15,22 @@
import inspect
import json
import re
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints

from packaging import version

from .import_utils import is_jinja_available


if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
jinja2 = None


BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
@@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict:
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}


def _render_with_assistant_indices(
compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
return rendered_chat, generation_indices


@lru_cache
def _compile_jinja_template(chat_template):
class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only instantiated by Jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

def raise_exception(message):
raise jinja2.exceptions.TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

def strftime_now(format):
return datetime.now().strftime(format)

jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
jinja_env.globals["strftime_now"] = strftime_now
return jinja_env.from_string(chat_template)
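The `AssistantTracker` extension above backs the `{% generation %}` ... `{% endgeneration %}` tag: text rendered inside that block is recorded as (start, end) character offsets, which `_render_with_assistant_indices` returns so that `apply_chat_template(..., return_assistant_tokens_mask=True)` can build the assistant tokens mask. A hedged sketch of the mechanics using the private helpers (the template text is illustrative, not from the PR):

```python
# Illustrative sketch: the template is made up and the helpers are private.
from transformers.utils.chat_template_utils import (
    _compile_jinja_template,
    _render_with_assistant_indices,
)

template = (
    "{%- for message in messages %}"
    "{%- if message['role'] == 'assistant' %}"
    "{% generation %}{{- message['content'] }}{% endgeneration %}"
    "{%- else %}"
    "{{- message['content'] }}"
    "{%- endif %}"
    "{%- endfor %}"
)

compiled = _compile_jinja_template(template)
rendered, indices = _render_with_assistant_indices(
    compiled_template=compiled,
    messages=[{"role": "user", "content": "Hi "}, {"role": "assistant", "content": "Hello!"}],
    tools=None,
    documents=None,
    add_generation_prompt=False,
)
print(rendered)  # "Hi Hello!"
print(indices)   # expected [(3, 9)]: the character span rendered inside {% generation %}
```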
45 changes: 45 additions & 0 deletions tests/test_tokenization_common.py
@@ -1153,6 +1153,51 @@ def test_chat_template_batched(self):
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised

@require_jinja
def test_jinja_loopcontrols(self):
break_template = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}""".strip()

dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]

tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
break_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=break_template, tokenize=False
)
self.assertEqual(break_output, "system 1") # Loop should break after first iter
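For context, without the loopcontrols extension this PR enables, a sandboxed Jinja environment rejects `{% break %}` at parse time, which is what the test above exercises. A hypothetical check, not part of the PR:

```python
# Hypothetical illustration: a plain sandboxed environment without
# jinja2.ext.loopcontrols cannot even parse {% break %}.
from jinja2 import TemplateSyntaxError
from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment()  # no loopcontrols extension registered
try:
    env.from_string("{% for x in [1, 2] %}{% break %}{% endfor %}")
except TemplateSyntaxError as exc:
    print("break rejected without the extension:", exc)
```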

@require_jinja
def test_jinja_strftime(self):
(Collaborator review comment: "Nice - far clearer and cleaner tests!")
strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()

dummy_conversation = [
{"role": "system", "content": "1"},
{"role": "user", "content": "2"},
{"role": "assistant", "content": "3"},
]

tokenizers = self.get_tokenizers()
for tokenizer in tokenizers:
with self.subTest(f"{tokenizer.__class__.__name__}"):
strftime_output = tokenizer.apply_chat_template(
dummy_conversation, chat_template=strftime_template, tokenize=False
)

# Assert that we get a date formatted as expected
self.assertEqual(len(strftime_output), 10)
self.assertEqual(len(strftime_output.split("-")), 3)

@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (