Integration of Jinja2 Templating #875
Merged

Commits (13)
- df97e8e feat: Add support for jinja templating (teleprint-me)
- d9544d1 Merge branch 'main' into jinja2-templates (teleprint-me)
- 9c11d17 fix: Refactor chat formatter and update interface for jinja templates (teleprint-me)
- 101f5f2 Merge branch 'main' into jinja2-templates (teleprint-me)
- 7ebbd8d Merge branch 'abetlen:main' into jinja2-templates (teleprint-me)
- db909e6 Merge branch 'abetlen:main' into jinja2-templates (teleprint-me)
- 72b7e1f Add outline for Jinja2 templating integration documentation (teleprint-me)
- a42042a Add jinja2 as a dependency with version range for Hugging Face transf… (teleprint-me)
- d03eb84 Update jinja2 version constraint for mkdocs-material compatibility (teleprint-me)
- e5d18ce Fix attribute name in AutoChatFormatter (teleprint-me)
- 7c30c2e Merge branch 'abetlen:main' into jinja2-templates (teleprint-me)
- caae414 Merge branch 'abetlen:main' into jinja2-templates (teleprint-me)
- 49dcd51 Merge branch 'main' into jinja2-templates (teleprint-me)
@@ -0,0 +1,52 @@
# Templates

This document provides a comprehensive guide to the integration of Jinja2 templating into the `llama-cpp-python` project, with a focus on enhancing the chat functionality of the `llama-2` model.

## Introduction

- Brief explanation of the `llama-cpp-python` project's need for a templating system.
- Overview of the `llama-2` model's interaction with templating.
## Jinja2 Dependency Integration

- Rationale for choosing Jinja2 as the templating engine.
  - Compatibility with Hugging Face's `transformers`.
  - Desire for advanced templating features and simplicity.
- Detailed steps for adding `jinja2` to `pyproject.toml` for dependency management (see the sketch below).
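For illustration, a PEP 621-style dependency entry might look like the following; the exact table layout and version range in the PR may differ, and the constraint shown here is a stand-in suggested by the commit messages about `transformers` and `mkdocs-material` compatibility:

```toml
[project]
dependencies = [
    # Hypothetical constraint: wide enough for Hugging Face transformers,
    # capped below 4.0 to stay within mkdocs-material's jinja2 pin.
    "jinja2>=2.11.3,<4.0.0",
]
```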
## Template Management Refactor

- Summary of the refactor and the motivation behind it.
- Description of the new chat handler selection logic (see the sketch after this list):
  1. Preference for a user-specified `chat_handler`.
  2. Fallback to a user-specified `chat_format`.
  3. Defaulting to a chat format from a `.gguf` file if available.
  4. Utilizing the `llama2` default chat format as the final fallback.
- Ensuring backward compatibility throughout the refactor.
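A minimal sketch of that selection order; the function and parameter names here are hypothetical and only the `ChatFormatterFactory` comes from this PR:

```python
from llama_cpp.llama_jinja_format import ChatFormatterFactory


def resolve_chat_handler(chat_handler=None, chat_format=None, gguf_chat_format=None):
    """Pick a formatter using the fallback order described above."""
    factory = ChatFormatterFactory()
    if chat_handler is not None:       # 1. a user-specified handler wins
        return chat_handler
    if chat_format is not None:        # 2. a user-specified format name
        return factory.get_formatter_by_name(chat_format)
    if gguf_chat_format is not None:   # 3. a format read from .gguf metadata
        return factory.get_formatter_by_name(gguf_chat_format)
    return factory.get_formatter_by_name("llama-2")  # 4. final fallback
```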
## Implementation Details

- In-depth look at the new `AutoChatFormatter` class.
- Example code snippets showing how to utilize the Jinja2 environment and templates.
- Guidance on how to provide custom templates or use defaults (see the example below).
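For instance, using the `AutoChatFormatter` introduced in this PR with a custom template (the template string itself is illustrative):

```python
from llama_cpp.llama_jinja_format import AutoChatFormatter

# Illustrative custom template: one "role: content" line per message.
custom_template = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

formatter = AutoChatFormatter(template=custom_template)
response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # -> "user: Hello!\n"

# Omitting `template` falls back to the built-in llama-2 template.
default_formatter = AutoChatFormatter()
```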
## Testing and Validation

- Outline of the testing strategy to ensure seamless integration.
- Steps for validating backward compatibility with existing implementations.

## Benefits and Impact

- Analysis of the expected benefits, including consistency, performance gains, and improved developer experience.
- Discussion of the potential impact on current users and contributors.

## Future Work

- Exploration of how templating can evolve within the project.
- Consideration of additional features or optimizations for the templating engine.
- Mechanisms for community feedback on the templating system.

## Conclusion

- Final thoughts on the integration of Jinja2 templating.
- Call to action for community involvement and feedback.
@@ -0,0 +1,138 @@
"""
llama_cpp/llama_jinja_format.py
"""
import dataclasses
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import jinja2
from jinja2 import Template

# NOTE: We sacrifice readability for usability.
# It will fail to work as expected if we attempt to format it in a readable way.
llama2_template = """{% for message in messages %}{% if message['role'] == 'user' %}[INST] {{ message['content'] }} [/INST]\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}\n{% elif message['role'] == 'system' %}<<SYS>> {{ message['content'] }} <</SYS>>\n{% endif %}{% endfor %}"""


class MetaSingleton(type):
    """
    Metaclass for implementing the Singleton pattern.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Singleton(object, metaclass=MetaSingleton):
    """
    Base class for implementing the Singleton pattern.
    """

    def __init__(self):
        super(Singleton, self).__init__()


@dataclasses.dataclass
class ChatFormatterResponse:
    prompt: str
    stop: Optional[Union[str, List[str]]] = None


# Base Chat Formatter Protocol
class ChatFormatterInterface(Protocol):
    def __init__(self, template: Optional[object] = None):
        ...

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs,
    ) -> ChatFormatterResponse:
        ...

    @property
    def template(self) -> str:
        ...


class AutoChatFormatter(ChatFormatterInterface):
    def __init__(
        self,
        template: Optional[str] = None,
        template_class: Optional[Template] = None,
    ):
        if template is not None:
            self._template = template
        else:
            self._template = llama2_template  # default template

        self._environment = jinja2.Environment(
            loader=jinja2.BaseLoader(),
            trim_blocks=True,
            lstrip_blocks=True,
        ).from_string(
            self._template,
            template_class=template_class,
        )

    def __call__(
        self,
        messages: List[Dict[str, str]],
        **kwargs: Any,
    ) -> ChatFormatterResponse:
        formatted_sequence = self._environment.render(messages=messages, **kwargs)
        return ChatFormatterResponse(prompt=formatted_sequence)

    @property
    def template(self) -> str:
        return self._template


class FormatterNotFoundException(Exception):
    pass


class ChatFormatterFactory(Singleton):
    _chat_formatters: Dict[str, Callable[[], ChatFormatterInterface]] = {}

    def register_formatter(
        self,
        name: str,
        formatter_callable: Callable[[], ChatFormatterInterface],
        overwrite=False,
    ):
        if not overwrite and name in self._chat_formatters:
            raise ValueError(
                f"Formatter with name '{name}' is already registered. Use `overwrite=True` to overwrite it."
            )
        self._chat_formatters[name] = formatter_callable

    def unregister_formatter(self, name: str):
        if name in self._chat_formatters:
            del self._chat_formatters[name]
        else:
            raise ValueError(f"No formatter registered under the name '{name}'.")

    def get_formatter_by_name(self, name: str) -> ChatFormatterInterface:
        try:
            formatter_callable = self._chat_formatters[name]
            return formatter_callable()
        except KeyError:
            raise FormatterNotFoundException(
                f"Invalid chat format: {name} (valid formats: {list(self._chat_formatters.keys())})"
            )


# Define a chat format class
class Llama2Formatter(AutoChatFormatter):
    def __init__(self):
        super().__init__(llama2_template)


# With the Singleton pattern applied, regardless of where or how many times
# ChatFormatterFactory() is called, it will always return the same instance
# of the factory, ensuring that the factory's state is consistent throughout
# the application.
ChatFormatterFactory().register_formatter("llama-2", Llama2Formatter)
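A quick usage sketch of the factory and the formatter registered above:

```python
from llama_cpp.llama_jinja_format import ChatFormatterFactory

# The singleton factory returns the same instance everywhere it is called.
factory = ChatFormatterFactory()
formatter = factory.get_formatter_by_name("llama-2")

response = formatter([{"role": "user", "content": "Hello!"}])
print(response.prompt)  # -> "[INST] Hello! [/INST]\n"
```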
@@ -0,0 +1,50 @@
from typing import List

import pytest

from llama_cpp import ChatCompletionMessage
from llama_cpp.llama_jinja_format import Llama2Formatter


@pytest.fixture
def sequence_of_messages() -> List[ChatCompletionMessage]:
    return [
        ChatCompletionMessage(role="system", content="Welcome to CodeHelp Bot!"),
        ChatCompletionMessage(
            role="user", content="Hi there! I need some help with Python."
        ),
        ChatCompletionMessage(
            role="assistant", content="Of course! What do you need help with in Python?"
        ),
        ChatCompletionMessage(
            role="user",
            content="I'm trying to write a function to find the factorial of a number, but I'm stuck.",
        ),
        ChatCompletionMessage(
            role="assistant",
            content="I can help with that! Would you like a recursive or iterative solution?",
        ),
        ChatCompletionMessage(
            role="user", content="Let's go with a recursive solution."
        ),
    ]


def test_llama2_formatter(sequence_of_messages):
    expected_prompt = (
        "<<SYS>> Welcome to CodeHelp Bot! <</SYS>>\n"
        "[INST] Hi there! I need some help with Python. [/INST]\n"
        "Of course! What do you need help with in Python?\n"
        "[INST] I'm trying to write a function to find the factorial of a number, but I'm stuck. [/INST]\n"
        "I can help with that! Would you like a recursive or iterative solution?\n"
        "[INST] Let's go with a recursive solution. [/INST]\n"
    )

    llama2_formatter_instance = Llama2Formatter()
    formatter_response = llama2_formatter_instance(sequence_of_messages)
    assert (
        expected_prompt == formatter_response.prompt
    ), "The formatted prompt does not match the expected output."


# Optionally, include a test for the 'stop' if it's part of the functionality.
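A minimal sketch of such a test, given that `AutoChatFormatter.__call__` currently only sets `prompt` and leaves `stop` at its dataclass default:

```python
def test_llama2_formatter_stop_defaults_to_none(sequence_of_messages):
    # Llama2Formatter never sets `stop`, so the default of None applies.
    formatter_response = Llama2Formatter()(sequence_of_messages)
    assert formatter_response.stop is None
```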
During my exploration of existing `chat_template`s, I found that they usually rely on functions such as `raise_exception`. It looks like there might be some elegant ways to define such a method by leveraging the Jinja env (see https://stackoverflow.com/a/29262304).
Otherwise, I guess you can take heavy inspiration from HF's `transformers` implementation (c.f. the usage guide: https://huggingface.co/docs/transformers/main/chat_templating) of `AutoTokenizer.from_pretrained("xxx/model-name").apply_chat_template(chat, tokenize=False)`.
Examples of `chat_templates`:
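Following the linked answer, one way to expose such a helper is through the environment's globals; a minimal sketch, with a made-up template body for illustration:

```python
import jinja2
from jinja2.exceptions import TemplateError


def raise_exception(message: str) -> None:
    """Let templates surface errors as Python exceptions."""
    raise TemplateError(message)


env = jinja2.Environment(
    loader=jinja2.BaseLoader(), trim_blocks=True, lstrip_blocks=True
)
env.globals["raise_exception"] = raise_exception  # visible to every template

template = env.from_string(
    "{% if messages[0]['role'] != 'system' %}"
    "{{ raise_exception('A system message must come first.') }}"
    "{% endif %}"
    "{{ messages[0]['content'] }}"
)
print(template.render(messages=[{"role": "system", "content": "Hello!"}]))  # -> Hello!
```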
And here is a nice entry-point line in `transformers` to follow to see how they render this Jinja template (I basically did a `Ctrl + F` to find it): https://github.com/huggingface/transformers/blob/74a3cebfa51b539bfcfa79b33686cc090b7074e8/src/transformers/tokenization_utils_base.py#L1600