Commit
[bugfix] Fix DataCollatorForChatML unexpected generation prompt (#2450)
* [bugfix] Fix DataCollatorForChatML unexpected generation prompt

* Update utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

* Update test_utils.py

* Update tests/test_utils.py

* Update tests/test_utils.py

---------

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
NIL-zhuang and kashif authored Dec 11, 2024
1 parent 460e780 commit c9c4f18
Showing 2 changed files with 51 additions and 31 deletions.
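For context, DataCollatorForChatML renders each example with the tokenizer's chat template before tokenizing it. With a ChatML-style template, a rendered training example looks roughly like the illustrative sketch below (the conversation content is made up); the bug fixed here is that the collator called the template with add_generation_prompt=True, which appended the empty assistant header shown on the last line even though the assistant's reply is already part of the example:

    <|im_start|>user
    What color is the sky?<|im_end|>
    <|im_start|>assistant
    It is blue.<|im_end|>
    <|im_start|>assistant      <-- unexpected generation prompt appended before this fix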
80 changes: 50 additions & 30 deletions tests/test_utils.py
@@ -205,54 +205,74 @@ def setUp(self):
             ignore_index=self.ignore_index,
         )

-    # See https://github.com/huggingface/trl/pull/2287#discussion_r1856594421
-    @unittest.skip("This test must be updated.")
     def test_data_collator_for_chatml(self):
         # Process the data
         data = self.collator(self.examples)

         # Verify basic shapes and types
         self.assertIn("input_ids", data)
         self.assertIn("attention_mask", data)
         self.assertIn("labels", data)
         self.assertIn("prompts", data)
         self.assertIn("prompt_attention_mask", data)

         # Decode input_ids and labels for verification
         input_ids = data["input_ids"][0].tolist()
         labels = data["labels"][0].tolist()
         prompt_only = data["prompts"][0].tolist()

-        # Verify that input_ids start with optional padding tokens and a single BOS token and there are no extra ones
-        first_non_pad = next(token for token in input_ids if token != self.tokenizer.pad_token_id)
-        self.assertEqual(
-            first_non_pad, self.bos_token_id, "The first non-padding token of input_ids should be BOS token."
-        )
-        self.assertEqual(input_ids.count(self.bos_token_id), 1, "There should be exactly one BOS token in input_ids.")
-
-        # Verify that the assistant's response token is present in input_ids and not in the prompt_only
-        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
-        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
-        response_in_input_ids = all(token in input_ids for token in last_assistant_response_tokens)
-        self.assertTrue(response_in_input_ids, "The assistant's response should be present in input_ids.")
+        # Get the last assistant's response for comparison
+        last_message = self.examples[0][self.messages_key][-1]
+        self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant")
+        last_assistant_response = last_message["content"]

-        # Check if the last assistant's response tokens are not in prompt_only
-        response_in_prompt = all(token in prompt_only for token in last_assistant_response_tokens)
-        self.assertFalse(response_in_prompt, "The assistant's response should not be present in prompt_only.")
+        # Verify that input_ids contain both prompt and response
+        decoded_input = self.tokenizer.decode(input_ids)
+        self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response")

-        # Verify that EOS token is at the end of input_ids
-        self.assertEqual(input_ids[-1], self.eos_token_id, "The last token of input_ids should be EOS token.")
+        # Verify that prompts only contain the conversation up to the last response
+        decoded_prompt = self.tokenizer.decode(prompt_only)
+        self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response")

-        # Verify that the labels preserved the target string (last_assistant_response)
-        last_assistant_response = self.examples[0][self.messages_key][-1]["content"]
-        last_assistant_response_tokens = self.tokenizer.encode(last_assistant_response, add_special_tokens=False)
+        # Verify labels are -100 for non-assistant parts
+        prompt_length = len(prompt_only)
+        self.assertTrue(
+            all(label == self.ignore_index for label in labels[:prompt_length]),
+            "Labels should be ignore_index for prompt tokens",
+        )

-        # Find the start and end of the last assistant's response in the labels
-        response_start = next(i for i, label in enumerate(labels) if label != self.ignore_index)
-        response_end = next(i for i in range(len(labels) - 1, -1, -1) if labels[i] != self.ignore_index)
+        # Verify labels match assistant response after prompt
+        # Add a filter to remove any trailing tokens after the first <|im_end|>
+        last_assistant_response_with_end = last_assistant_response + self.tokenizer.eos_token
+        last_assistant_response_tokens = self.tokenizer.encode(
+            last_assistant_response_with_end, add_special_tokens=False
+        )

-        actual_response = labels[response_start : response_end - 1]
+        response_labels = []
+        for label in labels[prompt_length:]:
+            if label == self.ignore_index:
+                continue
+            response_labels.append(label)
+            if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"):
+                break
         self.assertEqual(
-            actual_response,
+            response_labels,
             last_assistant_response_tokens,
-            "The labels should preserve the last assistant's response tokens.",
+            "Labels should match assistant response tokens",
         )

-        # Verify that EOS token is at the end of labels
-        self.assertEqual(labels[-1], self.eos_token_id, "The last token of labels should be EOS token.")
+        # Verify there isn't a generation prompt at the end
+        generation_prompt = "<|im_start|>assistant"
+        self.assertFalse(
+            decoded_input.strip().endswith(generation_prompt),
+            f"Input should not end with generation prompt '{generation_prompt}'",
+        )
+
+        self.assertEqual(
+            response_labels,
+            last_assistant_response_tokens,
+            "Labels should match assistant response tokens",
+        )


 class TestBatchGeneration(unittest.TestCase):
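To see what the updated test exercises end to end, a minimal driver for the collator might look like the sketch below; the checkpoint name and the conversation are illustrative assumptions, not values taken from the diff.

# Minimal sketch of driving DataCollatorForChatML the way the test does.
# The checkpoint and the conversation are illustrative assumptions.
from transformers import AutoTokenizer

from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
collator = DataCollatorForChatML(tokenizer=tokenizer)

examples = [
    {
        "messages": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
    }
]

batch = collator(examples)
# The batch exposes the keys the test asserts on: input_ids, attention_mask,
# labels, prompts and prompt_attention_mask. Prompt positions in labels carry
# ignore_index (-100), so only the assistant's reply contributes to the loss.
print(sorted(batch.keys()))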
2 changes: 1 addition & 1 deletion trl/trainer/utils.py
@@ -276,7 +276,7 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
             if "input_ids" not in example:
                 message = example[self.messages_key]
                 formatted_message = self.tokenizer.apply_chat_template(
-                    message, tokenize=False, add_generation_prompt=True
+                    message, tokenize=False, add_generation_prompt=False
                 )
                 tokenized_message = self.tokenizer(
                     formatted_message,
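The one-line change above is the whole fix: a training example already ends with the assistant's reply, so the chat template must not append a fresh assistant header. A standalone sketch of what the flag changes follows; the checkpoint is an illustrative assumption, and any ChatML-style template behaves the same way.

# Sketch of the add_generation_prompt behaviour the one-line fix relies on.
# The checkpoint is an illustrative assumption.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
messages = [
    {"role": "user", "content": "What color is the sky?"},
    {"role": "assistant", "content": "It is blue."},
]

training_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
inference_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# With add_generation_prompt=False the rendered text ends at the assistant's
# <|im_end|>, which is what the collator should tokenize for training.
# With add_generation_prompt=True an empty "<|im_start|>assistant" header is
# appended, which is only wanted when prompting a model to generate.
print(training_text.rstrip().endswith("<|im_end|>"))                # expected: True
print(inference_prompt.rstrip().endswith("<|im_start|>assistant"))  # expected: True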
