Integration of Jinja2 Templating #875

Merged — 13 commits, Jan 17, 2024

52 changes: 52 additions & 0 deletions docs/templates.md
@@ -0,0 +1,52 @@
# Templates

This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.

## Introduction

- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.

## Jinja2 Dependency Integration

- Rationale for choosing Jinja2 as the templating engine.
- Compatibility with Hugging Face's `transformers`.
- Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management.

## Template Management Refactor

- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic (a sketch in code follows this list):
1. Preference for a user-specified `chat_handler`.
2. Fallback to a user-specified `chat_format`.
3. Defaulting to a chat format from a `.gguf` file if available.
4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
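
As referenced above, a rough sketch of that selection order. The function and helper names here are illustrative only, not the project's actual API:

```python
def get_handler_for_format(chat_format: str):
    """Hypothetical helper: look up the chat handler registered for a format name."""
    return lambda messages, **kwargs: f"<handler for {chat_format}>"


def resolve_chat_handler(user_chat_handler=None, user_chat_format=None, gguf_chat_format=None):
    """Illustrative fallback order only; not the actual implementation."""
    if user_chat_handler is not None:        # 1. a user-specified chat_handler wins
        return user_chat_handler
    if user_chat_format is not None:         # 2. otherwise a user-specified chat_format
        return get_handler_for_format(user_chat_format)
    if gguf_chat_format is not None:         # 3. otherwise a chat format read from the .gguf file
        return get_handler_for_format(gguf_chat_format)
    return get_handler_for_format("llama2")  # 4. final fallback: the llama2 default
```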

## Implementation Details

- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates (see the sketch below).
- Guidance on how to provide custom templates or use defaults.
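
For instance, a minimal sketch of the kind of snippet this section would contain, using the `AutoChatFormatter` added in this PR with a made-up custom template:

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter

# A made-up ChatML-style template, purely for illustration.
custom_template = (
    "{% for message in messages %}"
    "<|{{ message['role'] }}|> {{ message['content'] }}\n"
    "{% endfor %}"
)

formatter = AutoChatFormatter(template=custom_template)  # omit `template` to use the llama-2 default
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # -> "<|user|> Hello!\n"
```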

## Testing and Validation

- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.

## Benefits and Impact

- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.

## Future Work

- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.

## Conclusion

- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.
138 changes: 138 additions & 0 deletions llama_cpp/llama_jinja_format.py
@@ -0,0 +1,138 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import jinja2
from jinja2 import Template

# NOTE: Readability is sacrificed here on purpose: the template must stay on a
# single line, otherwise the rendered output will not match the expected format.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)

Review comment:

While exploring existing chat_template implementations, I found that they usually rely on functions such as raise_exception.

It looks like there might be some elegant ways to define such a method by leveraging the Jinja environment (see https://stackoverflow.com/a/29262304).

Otherwise, you could take heavy inspiration from HF's transformers implementation (cf. the usage guide: https://huggingface.co/docs/transformers/main/chat_templating) of AutoTokenizer.from_pretrained("xxx/model-name").apply_chat_template(chat, tokenize=False).

Examples of chat_templates:

Review comment (@lopagela, Dec 2, 2023):

And here is a nice entry point in transformers to follow to see how they render this Jinja template (I basically did a Ctrl+F to find it): https://github.com/huggingface/transformers/blob/74a3cebfa51b539bfcfa79b33686cc090b7074e8/src/transformers/tokenization_utils_base.py#L1600
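
As a rough illustration of the reviewer's suggestion (not part of this PR), a raise_exception helper could be exposed to templates through the Jinja environment's globals; the helper name mirrors the convention used by Hugging Face chat templates:

```python
import jinja2


def raise_exception(message: str):
    # Helper that chat templates can call to abort rendering with an error.
    raise jinja2.exceptions.TemplateError(message)


env = jinja2.Environment(loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True)
env.globals["raise_exception"] = raise_exception

# A toy template that insists on a leading system message.
template = env.from_string(
    "{% if messages[0]['role'] != 'system' %}"
    "{{ raise_exception('A system message is required') }}"
    "{% endif %}ok"
)
print(template.render(messages=[{"role": "system", "content": "hi"}]))  # -> "ok"
```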

        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
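
A rough usage sketch (an editorial illustration, not part of the diff): with the registration above in place, a caller could look the formatter up by name and render a prompt like this:

```python
from llama_cpp.llama_jinja_format import ChatFormatterFactory

# The factory is a singleton, so this returns the same registry populated above.
formatter = ChatFormatterFactory().get_formatter_by_name("llama-2")
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # -> "[INST] Hello! [/INST]\n"
```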
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -11,10 +11,13 @@ license = { text = "MIT" }
authors = [
    { name = "Andrei Betlen", email = "abetlen@gmail.com" },
]
# mkdocs-material requires "jinja2~=3.0"
# transformers requires "jinja2>=2.11.3"
dependencies = [
    "typing-extensions>=4.5.0",
    "numpy>=1.20.0",
    "diskcache>=5.6.1",
    "jinja2>=2.11.3",
]
requires-python = ">=3.8"
classifiers = [
@@ -71,4 +74,3 @@ Changelog = "https://llama-cpp-python.readthedocs.io/en/latest/changelog/"

[tool.pytest.ini_options]
addopts = "--ignore=vendor"

50 changes: 50 additions & 0 deletions tests/test_llama_chat_format.py
@@ -0,0 +1,50 @@
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )

    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' if it's part of the functionality.
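
A hypothetical sketch of such a test (not part of this PR's test suite), assuming the llama-2 formatter leaves `ChatFormatterResponse.stop` at its default of `None`:

```python
def test_llama2_formatter_stop(sequence_of_messages):
    # The llama-2 formatter does not set a stop sequence,
    # so the response's stop field should remain None.
    formatter_response = Llama2Formatter()(sequence_of_messages)
    assert formatter_response.stop is None
```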