Skip to content

Commit

Permalink
Enable some Jinja extensions and add datetime capabilities (huggingface#32684)
Browse files Browse the repository at this point in the history

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Add new Jinja features:

- Do extension
- Break/continue in loops
- Call strftime to get current datetime in any format

* Fix strftime template

* Add template strip() just to be safe

* Remove the do extension to make porting easier, and also because it's the least useful

* Rename test

* strftime -> strftime_now

* Split test

* Update test to use strftime_now

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils

* Refactor everything out into chat_template_utils
  • Loading branch information
Rocketknight1 authored and zucchini-nlp committed Aug 30, 2024
1 parent 18f52a8 commit 01c8c31
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 92 deletions.
94 changes: 3 additions & 91 deletions src/transformers/tokenization_utils_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from collections.abc import Mapping, Sized
from contextlib import contextmanager
from dataclasses import dataclass
from functools import lru_cache
from inspect import isfunction
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

Expand Down Expand Up @@ -65,6 +64,7 @@
requires_backends,
to_py_obj,
)
from .utils.chat_template_utils import _compile_jinja_template, _render_with_assistant_indices


if TYPE_CHECKING:
Expand Down Expand Up @@ -1791,7 +1791,7 @@ def apply_chat_template(
)

# Compilation function uses a cache to avoid recompiling the same template
compiled_template = self._compile_jinja_template(chat_template)
compiled_template = _compile_jinja_template(chat_template)

if isinstance(conversation, (list, tuple)) and (
isinstance(conversation[0], (list, tuple)) or hasattr(conversation[0], "messages")
Expand Down Expand Up @@ -1831,7 +1831,7 @@ def apply_chat_template(
# Indicates it's a Conversation object
chat = chat.messages
if return_assistant_tokens_mask:
rendered_chat, generation_indices = self._render_with_assistant_indices(
rendered_chat, generation_indices = _render_with_assistant_indices(
compiled_template=compiled_template,
messages=chat,
tools=tool_schemas,
Expand Down Expand Up @@ -1888,94 +1888,6 @@ def apply_chat_template(
else:
return rendered

# Render `compiled_template` chunk by chunk and return the joined text together with the
# `generation_indices` list that is handed to the environment's tracker.
# This is the method form (takes `self`, which the body never reads); the call site in this diff
# switches to the module-level helper imported from `.utils.chat_template_utils`.
def _render_with_assistant_indices(
self, compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
# Both lists are shared with the tracker for the duration of the `with` block.
rendered_blocks = []
generation_indices = []
with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
# `generate` yields the output piece by piece; every piece is appended immediately so the
# tracker can measure how much text has been produced so far.
for block in compiled_template.generate(
messages=messages,
tools=tools,
documents=documents,
add_generation_prompt=add_generation_prompt,
**template_kwargs,
):
rendered_blocks.append(block)
rendered_chat = "".join(rendered_blocks)
# Returns a `(rendered_chat, generation_indices)` pair.
return rendered_chat, generation_indices

# Compile `chat_template` into a template bound to a sandboxed Jinja environment that provides a
# custom `tojson` filter, a `raise_exception` global and the `generation` tracking tag.
# NOTE(review): `lru_cache` on an instance method includes `self` in the cache key, so the cache
# holds a reference to every tokenizer instance that compiled a template.
@lru_cache
def _compile_jinja_template(self, chat_template):
# jinja2 is imported lazily so that a missing installation surfaces as the ImportError below.
try:
import jinja2
from jinja2 import nodes
from jinja2.exceptions import TemplateError
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
except ImportError:
raise ImportError("apply_chat_template requires jinja2 to be installed.")

# Reject jinja2 releases older than 3.1.0 with an explicit message.
if version.parse(jinja2.__version__) < version.parse("3.1.0"):
raise ImportError(
"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}."
)

# Exposed to templates as a global so they can abort rendering with a custom error message.
def raise_exception(message):
raise TemplateError(message)

def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
# We override the built-in tojson filter because Jinja's default filter escapes HTML characters
# We also expose some options like custom indents and separators
return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
# Makes `activate_tracker` reachable as `environment.activate_tracker`.
environment.extend(activate_tracker=self.activate_tracker)
self._rendered_blocks = None
self._generation_indices = None

def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
# Consume the opening tag, parse everything up to `endgeneration`, and wrap that body in a
# call block that is routed through `_generation_support` at render time.
lineno = next(parser.stream).lineno
body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
return nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

@jinja2.pass_eval_context
def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
# `caller()` renders the body of the generation block.
rv = caller()
if self.is_active():
# Only track generation indices if the tracker is active
# The span starts at the total length of the text collected so far.
start_index = len("".join(self._rendered_blocks))
end_index = start_index + len(rv)
self._generation_indices.append((start_index, end_index))
return rv

def is_active(self) -> bool:
# NOTE(review): this is a truthiness test, so it yields one of the lists (or None) rather than
# a bool, and it is falsy while both lists are still empty.
return self._rendered_blocks or self._generation_indices

@contextmanager
def activate_tracker(self, rendered_blocks: List[int], generation_indices: List[int]):
# Attach the caller's lists for the duration of the `with` block; they are detached again in
# `finally`, including when the reuse check below raises.
try:
if self.is_active():
raise ValueError("AssistantTracker should not be reused before closed")
self._rendered_blocks = rendered_blocks
self._generation_indices = generation_indices

yield
finally:
self._rendered_blocks = None
self._generation_indices = None

# Build the sandboxed environment, register the helpers defined above and compile the template.
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker])
jinja_env.filters["tojson"] = tojson
jinja_env.globals["raise_exception"] = raise_exception
return jinja_env.from_string(chat_template)

def get_chat_template(self, chat_template: Optional[str] = None, tools: Optional[List[Dict]] = None) -> str:
"""
Retrieve the chat template string used for tokenizing chat messages. This template is used
Expand Down
104 changes: 103 additions & 1 deletion src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,22 @@
import inspect
import json
import re
from typing import Any, Callable, Dict, Optional, Tuple, Union, get_args, get_origin, get_type_hints
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, get_args, get_origin, get_type_hints

from packaging import version

from .import_utils import is_jinja_available


if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
jinja2 = None


BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
Expand Down Expand Up @@ -314,3 +329,90 @@ def get_json_schema(func: Callable) -> Dict:
if return_dict is not None:
output["return"] = return_dict
return {"type": "function", "function": output}


def _render_with_assistant_indices(
    compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
):
    """
    Render a compiled chat template while the environment's assistant tracker is active.

    Args:
        compiled_template:
            Compiled template whose environment exposes `activate_tracker` (see `_compile_jinja_template`).
        messages, tools, documents, add_generation_prompt:
            Forwarded unchanged to the template as rendering variables.
        **template_kwargs:
            Any extra variables to make available to the template.

    Returns:
        A `(rendered_chat, generation_indices)` tuple: the full rendered text and the list that was
        handed to the tracker for recording generation spans.
    """
    chunks = []
    assistant_spans = []
    render_variables = dict(
        messages=messages,
        tools=tools,
        documents=documents,
        add_generation_prompt=add_generation_prompt,
        **template_kwargs,
    )
    tracker = compiled_template.environment.activate_tracker(chunks, assistant_spans)
    with tracker:
        # Append piece by piece (rather than collecting everything at the end) because the tracker
        # inspects `chunks` while the template is still being rendered.
        for piece in compiled_template.generate(**render_variables):
            chunks.append(piece)
        full_text = "".join(chunks)
    return full_text, assistant_spans


@lru_cache
def _compile_jinja_template(chat_template):
    """
    Compile a chat template string into a sandboxed Jinja template. Results are cached per template string.

    The environment is an `ImmutableSandboxedEnvironment` with `trim_blocks` and `lstrip_blocks` enabled. On top
    of stock Jinja it provides:

    - the `{% generation %}` ... `{% endgeneration %}` tag, which records the character span of its rendered body
      while the tracker is active,
    - `{% break %}` / `{% continue %}` through `jinja2.ext.loopcontrols`,
    - a `tojson` filter backed by `json.dumps`,
    - the globals `raise_exception(message)` and `strftime_now(format)`.

    Args:
        chat_template (`str`):
            Jinja source of the chat template.

    Returns:
        The compiled template. Its `environment.activate_tracker(rendered_blocks, generation_indices)` context
        manager turns on span recording for the duration of a render.

    Raises:
        ImportError: If jinja2 is not installed, or if the installed version is older than 3.1.0.
    """
    # Validate the dependency before touching any jinja2 name: the class body below needs `Extension` and
    # `jinja2.pass_eval_context`, which would otherwise fail with an unhelpful NameError / AttributeError.
    if jinja2 is None:
        raise ImportError("apply_chat_template requires jinja2 to be installed.")
    if version.parse(jinja2.__version__) < version.parse("3.1.0"):
        raise ImportError(
            f"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}."
        )

    class AssistantTracker(Extension):
        # This extension is used to track the indices of assistant-generated tokens in the rendered chat
        tags = {"generation"}

        def __init__(self, environment: ImmutableSandboxedEnvironment):
            # The class is only initiated by jinja.
            super().__init__(environment)
            # Makes the context manager reachable as `environment.activate_tracker`.
            environment.extend(activate_tracker=self.activate_tracker)
            self._rendered_blocks = None
            self._generation_indices = None

        def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
            # Consume the opening tag, parse the body up to `endgeneration`, and route the body through
            # `_generation_support` at render time.
            lineno = next(parser.stream).lineno
            body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
            return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)

        @jinja2.pass_eval_context
        def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
            rv = caller()
            if self.is_active():
                # Only track generation indices if the tracker is active
                start_index = len("".join(self._rendered_blocks))
                end_index = start_index + len(rv)
                self._generation_indices.append((start_index, end_index))
            return rv

        def is_active(self) -> bool:
            # Compare against None instead of relying on truthiness: the lists handed to `activate_tracker`
            # start out empty, and an empty list must still count as "active".
            return self._rendered_blocks is not None and self._generation_indices is not None

        @contextmanager
        def activate_tracker(self, rendered_blocks: List[str], generation_indices: List[Tuple[int, int]]):
            # Reject nested use *before* entering the try block, so the cleanup below cannot wipe the state of
            # the activation that is already in progress.
            if self.is_active():
                raise ValueError("AssistantTracker should not be reused before closed")
            self._rendered_blocks = rendered_blocks
            self._generation_indices = generation_indices
            try:
                yield
            finally:
                self._rendered_blocks = None
                self._generation_indices = None

    def raise_exception(message):
        raise jinja2.exceptions.TemplateError(message)

    def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
        # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
        # We also expose some options like custom indents and separators
        return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)

    def strftime_now(format):
        # Evaluated on every render, so caching the compiled template does not freeze the timestamp.
        return datetime.now().strftime(format)

    jinja_env = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
    )
    jinja_env.filters["tojson"] = tojson
    jinja_env.globals["raise_exception"] = raise_exception
    jinja_env.globals["strftime_now"] = strftime_now
    return jinja_env.from_string(chat_template)
45 changes: 45 additions & 0 deletions tests/test_tokenization_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,51 @@ def test_chat_template_batched(self):
dummy_conversations, chat_template=dummy_template, tokenize=True
) # Check that no error raised

@require_jinja
def test_jinja_loopcontrols(self):
    """Check that chat templates can use `{% break %}` to leave a loop early."""
    conversation = [
        {"role": "system", "content": "1"},
        {"role": "user", "content": "2"},
        {"role": "assistant", "content": "3"},
    ]

    # Emits the current message, then breaks out of the loop on the first iteration
    template_source = """
{%- for message in messages %}
{{- message.role + " " + message.content }}
{%- if loop.first %}
{%- break %}
{%- endif %}
{%- endfor %}"""
    template_source = template_source.strip()

    for tokenizer in self.get_tokenizers():
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            rendered = tokenizer.apply_chat_template(conversation, chat_template=template_source, tokenize=False)
            # Loop should break after first iter
            self.assertEqual(rendered, "system 1")

@require_jinja
def test_jinja_strftime(self):
    """Check that chat templates can call `strftime_now` and receive the current date in the requested format."""
    # Imported locally so the test does not rely on the module-level imports of the test file
    from datetime import date, datetime

    strftime_template = """{{- strftime_now("%Y-%m-%d") }}""".strip()

    dummy_conversation = [
        {"role": "system", "content": "1"},
        {"role": "user", "content": "2"},
        {"role": "assistant", "content": "3"},
    ]

    tokenizers = self.get_tokenizers()
    for tokenizer in tokenizers:
        with self.subTest(f"{tokenizer.__class__.__name__}"):
            # Bracket the render with the local date so the check stays valid across midnight
            earliest = date.today()
            strftime_output = tokenizer.apply_chat_template(
                dummy_conversation, chat_template=strftime_template, tokenize=False
            )
            latest = date.today()

            # Zero-padded YYYY-MM-DD is always 10 characters; strptime alone would also accept "2024-8-3"
            self.assertEqual(len(strftime_output), 10)
            # strptime raises ValueError (failing the test) if the output is not a real date
            rendered_date = datetime.strptime(strftime_output, "%Y-%m-%d").date()
            self.assertLessEqual(earliest, rendered_date)
            self.assertLessEqual(rendered_date, latest)

@require_jinja
def test_chat_template_return_assistant_tokens_mask(self):
dummy_template = (
Expand Down

0 comments on commit 01c8c31

Please sign in to comment.