Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bugfix] Fix DataCollatorForChatML unexpected generation prompt #2450

Merged
merged 10 commits into from
Dec 11, 2024
75 changes: 45 additions & 30 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,54 +205,69 @@ def setUp(self):
ignore_index=self.ignore_index,
)

# See https://github.com/huggingface/trl/pull/2287#discussion_r1856594421
@unittest.skip("This test must be updated.")
def test_data_collator_for_chatml(self):
# Process the data
data = self.collator(self.examples)

# Verify basic shapes and types
self.assertIn("input_ids", data)
self.assertIn("attention_mask", data)
self.assertIn("labels", data)
self.assertIn("prompts", data)
self.assertIn("prompt_attention_mask", data)

# Decode input_ids and labels for verification
input_ids = data["input_ids"][0].tolist()
labels = data["labels"][0].tolist()
prompt_only = data["prompts"][0].tolist()

# Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
self.assertEqual(
first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
)
self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")

# Verify that the assistant's response token is present in input_ids and not in the prompt_only
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")
# Get the last assistant's response for comparison
last_message = self.examples[0][self.messages_key][-1]
self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant")
last_assistant_response = last_message["content"]

# Check if the last assistant's response tokens are not in prompt_only
response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")
# Verify that input_ids contain both prompt and response
decoded_input = self.tokenizer.decode(input_ids)
self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response")

# Verify that EOS token is at the end of input_ids
self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")
# Verify that prompts only contain the conversation up to the last response
decoded_prompt = self.tokenizer.decode(prompt_only)
self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response")

# Verify that the labels preserved the target string (last_assistant_response)
last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
# Verify labels are -100 for non-assistant parts
prompt_length = len(prompt_only)
self.assertTrue(
all(label == self.ignore_index for label in labels[:prompt_length]),
"Labels should be ignore_index for prompt tokens",
)

# Find the start and end of the last assistant's response in the labels
response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)
# Verify labels match assistant response after prompt
# Add a filter to remove any trailing tokens after the first <|im_end|>
last_assistant_response_with_end = last_assistant_response + self.tokenizer.eos_token
last_assistant_response_tokens = self.tokenizer.encode(
last_assistant_response_with_end, add_special_tokens=False
)

actual_response = labels[response_start : response_end - 1]
response_labels = []
for label in labels[prompt_length:]:
if label == self.ignore_index:
continue
response_labels.append(label)
if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"):
break
self.assertEqual(
actual_response,
response_labels,
last_assistant_response_tokens,
"The labels should preserve the last assistant's response tokens.",
"Labels should match assistant response tokens",
)

# Verify that EOS token is at the end of labels
self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
# Verify there isn't an extra generation prompt at the end
generation_prompt = "<|im_start|>assistant\n"
kashif marked this conversation as resolved.
Show resolved Hide resolved
last_occurrence = decoded_input.rfind(generation_prompt)
if last_occurrence != -1:
# Check that there isn't another occurrence after the expected one
next_occurrence = decoded_input.find(generation_prompt, last_occurrence + len(generation_prompt))
self.assertEqual(next_occurrence, -1, "Found an extra generation prompt at the end of the input")
kashif marked this conversation as resolved.
Show resolved Hide resolved


class TestBatchGeneration(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
if "input_ids" not in example:
message = example[self.messages_key]
formatted_message = self.tokenizer.apply_chat_template(
message, tokenize=False, add_generation_prompt=True
message, tokenize=False, add_generation_prompt=False
)
tokenized_message = self.tokenizer(
formatted_message,
Expand Down
Loading