Add setup_chat_format for adding new special tokens to model for training chat models (huggingface#1242)

* first draft
* 64
* sourab's suggestion
* wip tests
* make style happy
* add check
* docstring
* fix docstring
* Update tests/test_model_utils.py (Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>)
* move tests
* add todo for abstract class
* make style happy
* add slow tests and imports
* add documentation
* update sft_trainer.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Commit 34c47ed (parent 4fc5203), in a fork of huggingface/trl.
Showing 6 changed files with 175 additions and 1 deletion.
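The commit's core addition is the setup_chat_format helper shown in the diff below. As a minimal usage sketch (not taken from the commit itself; the checkpoint name is a placeholder, and it assumes the helper is re-exported from the top-level trl package), it is meant to be called once on a base model and tokenizer before chat fine-tuning:

from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import setup_chat_format  # assumes the commit re-exports the helper at package level

# any base (non-chat) causal LM works here; "facebook/opt-350m" is only a placeholder
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# register the ChatML special tokens, install the matching chat template,
# and grow the embedding matrix to cover the two new tokens
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml", resize_to_multiple_of=64)

Passing resize_to_multiple_of=64 pads the resized embedding matrix to a multiple of 64, the throughput optimization referenced by the Karpathy link in the source comment; leaving it at None resizes exactly to the new vocabulary size.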
New file added by the commit (@@ -0,0 +1,80 @@):

from dataclasses import dataclass
from typing import Literal, Optional, Tuple

from transformers import PreTrainedModel, PreTrainedTokenizer


# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
    """Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

    bos_token: str = "<|im_start|>"
    eos_token: str = "<|im_end|>"
    pad_token: str = "<|im_end|>"

    @property
    def system(self):
        return f"{self.bos_token}system"

    @property
    def user(self):
        return f"{self.bos_token}user"

    @property
    def assistant(self):
        return f"{self.bos_token}assistant"

    @property
    def chat_template(self):
        # the doubled braces in the f-strings are escapes, so the rendered
        # Jinja template contains literal "{{ ... }}" expression delimiters
        return (
            "{% for message in messages %}"
            f"{{{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}}}"
            "{% endfor %}"
            "{% if add_generation_prompt %}"
            f"{{{{ '{self.assistant}\n' }}}}"
            "{% endif %}"
        )


FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Set up a chat format by adding special tokens to the tokenizer, setting the correct chat template,
    and extending the model's embedding layer to account for the new special tokens.

    Args:
        model (`~transformers.PreTrainedModel`): The model to be modified.
        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.

    Returns:
        model (`~transformers.PreTrainedModel`): The modified model.
        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
    """
    # check that the requested format is available and retrieve it
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # set the special tokens and register them with the tokenizer
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
    # set the chat template on the tokenizer
    tokenizer.chat_template = chat_format.chat_template

    # resize the embedding layer, optionally padding its size to a multiple of `resize_to_multiple_of`
    # (e.g. 64), see https://x.com/karpathy/status/1621578354024677377
    model.resize_token_embeddings(
        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
    )

    return model, tokenizer
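To make the installed template concrete, a small follow-up sketch (continuing the hypothetical setup above, not part of the commit) renders a conversation with the tokenizer after setup_chat_format has run; the expected shape of the output follows the ChatML layout defined by ChatMlSpecialTokens:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi, how are you?"},
]

# render with the freshly installed ChatML template; add_generation_prompt appends
# the assistant header so the model can start its reply
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# Expected output, roughly:
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hi, how are you?<|im_end|>
# <|im_start|>assistant

Because setup_chat_format also sets eos_token and pad_token to <|im_end|>, generation naturally stops at the end-of-turn marker and padding reuses the same token id.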