Skip to content

Commit

Permalink
Update incorrect data processing in DataCollatorForChatML (#2172)
Browse files Browse the repository at this point in the history
* Update incorrect data processing in DataCollatorForChatML

Fix the extra BOS token and the absence of an EOS token in the returned input_ids, and potentially the absence of a target string in the returned labels.

* Update trl/trainer/utils.py

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>

* style

* move comment

* add test for DataCollatorForChatML

* update comment with more details

* update assert reports and comments, and adds verification that the last token of input_ids should be EOS token

* new line at the end of file for code quality

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* update tests

* fix test

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* formatting

* fix typo

* simplify

* Revert "simplify"

This reverts commit 7e4006c.

* tokenize full messages

* dont add eos

* eos is in the last token

* simplify DataCollatorForChatML

* Update tests/test_utils.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
  • Loading branch information
5 people committed Oct 10, 2024
1 parent 00b537e commit 22567cd
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 49 deletions.
77 changes: 76 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@
import unittest

import torch
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl.trainer.model_config import ModelConfig
from trl.trainer.utils import decode_and_strip_padding, get_peft_config, pad
from trl.trainer.utils import DataCollatorForChatML, decode_and_strip_padding, get_peft_config, pad


if is_peft_available():
Expand Down Expand Up @@ -126,3 +127,77 @@ def test_example_without_padding(self):
inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt")
decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer)
self.assertEqual(decoded, ["Hello", "Hello"])


class TestDataCollatorForChatML(unittest.TestCase):
def setUp(self):
# Initialize the tokenizer
self.tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

# Define token IDs
self.bos_token_id = self.tokenizer.bos_token_id if self.tokenizer.bos_token_id is not None else 1
self.eos_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else 2
# Token ID for "true", the last assistant's response in the example:
self.ignore_index = -100
self.max_length = 1024
self.messages_key = "messages"

# Example input
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
self.examples = dataset.to_list()

# Initialize the data collator
self.collator = DataCollatorForChatML(
tokenizer=self.tokenizer,
max_length=self.max_length,
ignore_index=self.ignore_index,
)

def test_data_collator_for_chatml(self):
# Process the data
data = self.collator(self.examples)

# Decode input_ids and labels for verification
input_ids = data["input_ids"][0].tolist()
labels = data["labels"][0].tolist()
prompt_only = data["prompts"][0].tolist()

# Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
self.assertEqual(
first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
)
self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")

# Verify that the assistant's response token is present in input_ids and not in the prompt_only
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")

# Check if the last assistant's response tokens are not in prompt_only
response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")

# Verify that EOS token is at the end of input_ids
self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")

# Verify that the labels preserved the target string (last_assistant_response)
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)

# Find the start and end of the last assistant's response in the labels
response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)

actual_response = labels[response_start : response_end - 1]
self.assertEqual(
actual_response,
last_assistant_response_tokens,
"The labels should preserve the last assistant's response tokens.",
)

# Verify that EOS token is at the end of labels
self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
99 changes: 51 additions & 48 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class DataCollatorForChatML:
tokenizer: PreTrainedTokenizerBase
ignore_index: int = -100
max_length: int = None
prompt_key: str = "prompt"
messages_key: str = "messages"

def __post_init__(self):
Expand All @@ -250,67 +251,69 @@ def __post_init__(self):
self.max_length = min(self.tokenizer.model_max_length, 1024)

def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
prompts = []
completions = []

for example in examples:
messages = example[self.messages_key]
formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)

# Split the formatted chat into prompt and completion
assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
last_assistant_message = assistant_messages[-1]["content"]
prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
completion = last_assistant_message

prompts.append(prompt)
completions.append(completion)

# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
tokenized_completions = self.tokenizer(
completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)

# Combine prompts and completions
input_ids = []
attention_mask = []
prompts_input_ids = []
prompt_attention_mask = []
labels = []

for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
combined_input_ids = prompt + completion
combined_attention_mask = [1] * len(combined_input_ids)
for example in examples:
formatted_prompt = example.get(self.prompt_key, None)
if formatted_prompt is None:
prompt = example[self.messages_key][:-1]
formatted_prompt = self.tokenizer.apply_chat_template(
prompt, tokenize=False, add_generation_prompt=True
)

# Create labels for one-token ahead task, masking the prompt
combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
combined_labels.append(self.tokenizer.eos_token_id) # Add EOS token as final target
if "input_ids" not in example:
message = example[self.messages_key]
formatted_message = self.tokenizer.apply_chat_template(
message, tokenize=False, add_generation_prompt=True
)
tokenized_message = self.tokenizer(
formatted_message,
truncation=True,
max_length=self.max_length,
padding=False,
return_tensors=None,
add_special_tokens=False,
)
input_ids.append(tokenized_message["input_ids"])
attention_mask.append(tokenized_message["attention_mask"])
else:
input_ids.append(example["input_ids"])
attention_mask.append(example["attention_mask"])

tokenized_prompt = self.tokenizer(
formatted_prompt,
truncation=True,
max_length=len(input_ids[-1]),
padding=False,
return_tensors=None,
add_special_tokens=False,
)

input_ids.append(combined_input_ids)
attention_mask.append(combined_attention_mask)
labels.append(combined_labels)
prompts_input_ids.append(tokenized_prompt["input_ids"])
prompt_attention_mask.append(tokenized_prompt["attention_mask"])

# first convert to list of tensors
input_ids = [torch.tensor(ids) for ids in input_ids]
attention_mask = [torch.tensor(mask) for mask in attention_mask]
labels = [torch.tensor(label) for label in labels]
# Create the labels that will have all but the completion tokens of the example["input_ids"] set to ignore_index
label = [self.ignore_index] * len(input_ids[-1])
completion_start_idx = len(tokenized_prompt["input_ids"])
label[completion_start_idx:] = input_ids[-1][completion_start_idx:]
labels.append(label)

# pad the input_ids, attention_mask and labels to the same length across the batch
# convert to list of tensors and pad
input_ids = [torch.tensor(ids, dtype=torch.long) for ids in input_ids]
attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in attention_mask]
labels = [torch.tensor(label, dtype=torch.long) for label in labels]
input_ids = pad(input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)
attention_mask = pad(attention_mask, padding_side="left", padding_value=0)
labels = pad(labels, padding_side="left", padding_value=self.ignore_index)

# pad the tokenized_prompts on the left to the same length convert to tensor first
prompts_input_ids = [torch.tensor(ids) for ids in tokenized_prompts["input_ids"]]
prompts_input_ids = [torch.tensor(ids, dtype=torch.long) for ids in prompts_input_ids]
prompt_attention_mask = [torch.tensor(mask, dtype=torch.long) for mask in prompt_attention_mask]
prompts_input_ids = pad(prompts_input_ids, padding_side="left", padding_value=self.tokenizer.pad_token_id)

# prompt attention mask
prompt_attention_mask = pad(
[torch.tensor([1] * len(ids)) for ids in tokenized_prompts["input_ids"]],
padding_side="left",
padding_value=0,
)
prompt_attention_mask = pad(prompt_attention_mask, padding_side="left", padding_value=0)

return {
"input_ids": input_ids,
Expand Down

0 comments on commit 22567cd

Please sign in to comment.