Add setup_chat_format for adding new special tokens to model for training chat models #1242

Merged (15 commits) on Jan 18, 2024
23 changes: 23 additions & 0 deletions docs/source/sft_trainer.mdx
@@ -154,6 +154,29 @@ response_template_ids = tokenizer.encode(response_template_with_context, add_spe
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```

### Add Special Tokens for Chat Format

Adding special tokens to a language model is crucial for training chat models. These tokens are added between the different roles in a conversation, such as the user, assistant, and system, and they help the model recognize the structure and flow of a conversation. This setup is essential for enabling the model to generate coherent and contextually appropriate responses in a chat environment.
The [`setup_chat_format`] function in `trl` sets up a model and tokenizer for conversational AI tasks. This function:
- Adds special tokens to the tokenizer, e.g. `<|im_start|>` and `<|im_end|>`, to mark the start and end of each message in a conversation.
- Resizes the model’s embedding layer to accommodate the new tokens.
- Sets the `chat_template` of the tokenizer, which is used to format the input data into a chat-like format. The default is `chatml` from OpenAI.
- _Optionally_, you can pass `resize_to_multiple_of` to pad the embedding layer size up to a multiple of that value, e.g. 64.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import setup_chat_format

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Set up the chat format with the default 'chatml' format
model, tokenizer = setup_chat_format(model, tokenizer)
```
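
Once the format is set, the tokenizer's `apply_chat_template` method renders a list of messages with the ChatML special tokens. A minimal sketch (the conversation below is only illustrative):

```python
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What does setup_chat_format do?"},
]

# Render the conversation with the newly configured ChatML template
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# What does setup_chat_format do?<|im_end|>
# <|im_start|>assistant
```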

With our model and tokenizer set up, we can now fine-tune our model on a conversational dataset. Below is an example of how a dataset can be formatted for fine-tuning.
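
For instance, a single sample in a conversational ("messages") dataset could look like the following; this is a hypothetical sketch rather than a row from a specific dataset:

```python
sample = {
    "messages": [
        {"role": "system", "content": "You are a friendly chatbot."},
        {"role": "user", "content": "How far is the Moon from Earth?"},
        {"role": "assistant", "content": "On average, about 384,400 km."},
    ]
}
```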

### Dataset format support

The [`SFTTrainer`] supports popular dataset formats. This allows you to pass the dataset to the trainer directly, without any pre-processing. The following formats are supported:
44 changes: 44 additions & 0 deletions tests/slow/test_sft_slow.py
@@ -23,6 +23,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments

from trl import SFTTrainer, is_peft_available
from trl.models.utils import setup_chat_format

from ..testing_utils import require_bitsandbytes, require_peft, require_torch_gpu, require_torch_multi_gpu
from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS
@@ -345,3 +346,46 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr
            trainer.train()

            release_memory(model, trainer)

    @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
    @require_peft
    @require_bitsandbytes
    def test_sft_trainer_with_chat_format_qlora(self, model_name, packing):
        """
        Simply tests that using `setup_chat_format` with a transformers model + peft + bnb config in `SFTTrainer` loads and runs the trainer
        as expected.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            train_dataset = load_dataset("trl-internal-testing/dolly-chatml-sft", split="train")

            args = TrainingArguments(
                output_dir=tmp_dir,
                logging_strategy="no",
                report_to="none",
                per_device_train_batch_size=2,
                max_steps=10,
                fp16=True,
            )

            quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

            model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quantization_config)
            tokenizer = AutoTokenizer.from_pretrained(model_name)

            model, tokenizer = setup_chat_format(model, tokenizer)

            trainer = SFTTrainer(
                model,
                args=args,
                tokenizer=tokenizer,
                train_dataset=train_dataset,
                packing=packing,
                max_seq_length=self.max_seq_length,
                peft_config=self.peft_config,
            )

            self.assertTrue(isinstance(trainer.model, PeftModel))

            trainer.train()

            release_memory(model, trainer)
27 changes: 26 additions & 1 deletion tests/test_dataset_formatting.py
@@ -2,9 +2,10 @@
from typing import Callable

from datasets import Dataset, load_dataset
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl.extras.dataset_formatting import get_formatting_func_from_dataset
from trl.models.utils import ChatMlSpecialTokens, setup_chat_format


class DatasetFormattingTestCase(unittest.TestCase):
@@ -122,3 +123,27 @@ def test_get_formatting_func_from_dataset_with_unknown_format(self):
        dataset = Dataset.from_dict({"text": "test"})
        formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer)
        self.assertIsNone(formatting_func)


class SetupChatFormatTestCase(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        self.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")

    def test_setup_chat_format(self):
        original_tokenizer_len = len(self.tokenizer)
        modified_model, modified_tokenizer = setup_chat_format(
            self.model, self.tokenizer, format="chatml", resize_to_multiple_of=64
        )

        _chatml = ChatMlSpecialTokens()
        # Check if special tokens are correctly set
        self.assertTrue(modified_tokenizer.eos_token == "<|im_end|>")
        self.assertTrue(modified_tokenizer.pad_token == "<|im_end|>")
        self.assertTrue(modified_tokenizer.bos_token == "<|im_start|>")
        self.assertTrue(modified_tokenizer.eos_token == _chatml.eos_token)
        self.assertTrue(modified_tokenizer.pad_token == _chatml.pad_token)
        self.assertTrue(modified_tokenizer.bos_token == _chatml.bos_token)
        self.assertTrue(len(modified_tokenizer) == original_tokenizer_len + 2)
        self.assertTrue(self.model.get_input_embeddings().weight.shape[0] % 64 == 0)
        self.assertTrue(self.model.get_input_embeddings().weight.shape[0] == original_tokenizer_len + 64)
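        # Note: the two ChatML tokens are added on top of the original vocabulary, and because the llama
        # vocabulary size is itself a multiple of 64, resize_to_multiple_of=64 pads the embedding matrix
        # up to exactly original_tokenizer_len + 64 rows.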
1 change: 1 addition & 0 deletions trl/__init__.py
@@ -18,6 +18,7 @@
    AutoModelForSeq2SeqLMWithValueHead,
    PreTrainedModelWrapper,
    create_reference_model,
    setup_chat_format,
)
from .trainer import (
    DataCollatorForCompletionOnlyLM,
1 change: 1 addition & 0 deletions trl/models/__init__.py
@@ -15,6 +15,7 @@
# limitations under the License.
from .modeling_base import PreTrainedModelWrapper, create_reference_model
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
from .utils import setup_chat_format


SUPPORTED_ARCHITECTURES = (
80 changes: 80 additions & 0 deletions trl/models/utils.py
@@ -0,0 +1,80 @@
from dataclasses import dataclass
from typing import Literal, Optional, Tuple

from transformers import PreTrainedModel, PreTrainedTokenizer


# TODO: Add Abstract Base Class if more formats are added
@dataclass
class ChatMlSpecialTokens:
"""Dataclass for special tokens used in ChatML, including system, user, assistant, bos, eos, and pad tokens."""

bos_token: str = "<|im_start|>"
eos_token: str = "<|im_end|>"
pad_token: str = "<|im_end|>"

@property
def system(self):
return f"{self.bos_token}system"

@property
def user(self):
return f"{self.bos_token}user"

@property
def assistant(self):
return f"{self.bos_token}assistant"

@property
def chat_template(self):
return (
"{% for message in messages %}"
f"{{'{self.bos_token}' + message['role'] + '\n' + message['content'] + eos_token + '\n'}}"
"{% endfor %}"
"{% if add_generation_prompt %}"
f"{{ '{self.assistant}\n' }}"
"{% endif %}"
)
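
    # Illustrative rendering of the template above for a single user message "Hello!"
    # with add_generation_prompt=True:
    #   <|im_start|>user
    #   Hello!<|im_end|>
    #   <|im_start|>assistant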


FORMAT_MAPPING = {"chatml": ChatMlSpecialTokens}


def setup_chat_format(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizer,
    format: Optional[Literal["chatml"]] = "chatml",
    resize_to_multiple_of: Optional[int] = None,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens.

    Args:
        model (`~transformers.PreTrainedModel`): The model to be modified.
        tokenizer (`~transformers.PreTrainedTokenizer`): The tokenizer to be modified.
        format (`Optional[Literal["chatml"]]`): The format to be set. Defaults to "chatml".
        resize_to_multiple_of (`Optional[int]`): Number to resize the embedding layer to. Defaults to None.

    Returns:
        model (`~transformers.PreTrainedModel`): The modified model.
        tokenizer (`~transformers.PreTrainedTokenizer`): The modified tokenizer.
    """
    # check if format is available and retrieve it
    if format not in FORMAT_MAPPING:
        raise ValueError(f"Format {format} not available. Please use one of {FORMAT_MAPPING.keys()}")

    chat_format = FORMAT_MAPPING[format]()

    # set the special tokens on the tokenizer
    tokenizer.eos_token = chat_format.eos_token
    tokenizer.pad_token = chat_format.pad_token
    tokenizer.bos_token = chat_format.bos_token
    tokenizer.add_special_tokens({"additional_special_tokens": [chat_format.bos_token, chat_format.eos_token]})
    # set chat format for tokenizer
    tokenizer.chat_template = chat_format.chat_template

    # resize embedding layer to account for the new tokens; padding the vocabulary size up to a
    # multiple (e.g. 64) can improve throughput, see https://x.com/karpathy/status/1621578354024677377
    model.resize_token_embeddings(
        len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
    )

    return model, tokenizer