Incorrect data processing in DataCollatorForChatML #2169

Closed · ruijunfeng opened this issue Oct 4, 2024 · 2 comments · Fixed by ruijunfeng/trl#1 or #2172
Labels
🐛 bug Something isn't working 🏋 GKD Related to GKD

Comments

@ruijunfeng
Contributor

ruijunfeng commented Oct 4, 2024

System Info

Python 3.11.9
trl 0.11.0
transformers 4.45.1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from trl.trainer.utils import DataCollatorForChatML
from transformers import AutoTokenizer
# base model settings
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
tokenizer.pad_token = tokenizer.bos_token if tokenizer.pad_token is None else tokenizer.pad_token

# Initialize the data collator for processing the chat model inputs
max_length = 1024
ignore_index = -100
messages_key = "messages"
examples = [
    {
        messages_key:
            [
                {"role": "user", "content": "Does the following code contain any security vulnerabilities? Return true or false.\nchar buffer[10];\nchar input[50];\nstrcpy(buffer, input);\n"}, 
                {"role": "assistant", "content": "true"}
            ]
    }
]
collator = DataCollatorForChatML(tokenizer=tokenizer, max_length=max_length, ignore_index=ignore_index, messages_key=messages_key)

# Process the data
data = collator(examples)
print(data["input_ids"])
print(data["labels"])

Found Bugs and Solutions

Output from Demo Script

>>> print(data["input_ids"])
tensor([[    1,     1,   518, 25580, 29962,  5538,   278,  1494,   775,  1712,
           738,  6993, 23180, 11614, 29973,  7106,  1565,   470,  2089, 29889,
            13,  3090,  6835, 29961, 29896, 29900,  1385,    13,  3090,  1881,
         29961, 29945, 29900,  1385,    13,   710, 23141, 29898,  9040, 29892,
          1881,   416,   518, 29914, 25580, 29962, 29871,     1,  1565]])
>>> print(data["labels"])
tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,    1,
            2]])

Found Bugs

Preliminary: the expected assistant output is the string "true", whose token id is 1565. The bos token id is 1 and the eos token id is 2.

  1. Extra bos tokens and no eos token in data["input_ids"]. data["input_ids"] has an extra bos token at the beginning (i.e., [1, 1, ...]) and another extra bos token right before the answer (i.e., [..., 1, 1565]), with no eos token at the end.
  2. No target string in data["labels"]. data["labels"] does not contain the expected assistant output (the "true" token); only a bos and an eos token are preserved (i.e., [..., -100, 1, 2]). Both problems can be confirmed with the quick check below.
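A minimal check, assuming the tokenizer and the data variable from the reproduction script above (token ids per this report: 1 = bos, 2 = eos, 1565 = "true"):

# Quick sanity check on the collated batch from the reproduction script above.
input_ids = data["input_ids"][0].tolist()
labels = data["labels"][0].tolist()
print(input_ids[:2])                      # [1, 1] -> duplicated bos at the start
print(input_ids[-2:])                     # [1, 1565] -> stray bos before the answer, no trailing eos
print([t for t in labels if t != -100])   # [1, 2] -> the "true" token (1565) never appears in the labels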

Expected Output of Demo Script

The expected output should look like this:

>>> print(data["input_ids"])
tensor([[    1,   518, 25580, 29962,  5538,   278,  1494,   775,  1712,   738,
          6993, 23180, 11614, 29973,  7106,  1565,   470,  2089, 29889,    13,
          3090,  6835, 29961, 29896, 29900,  1385,    13,  3090,  1881, 29961,
         29945, 29900,  1385,    13,   710, 23141, 29898,  9040, 29892,  1881,
           416,   518, 29914, 25580, 29962, 29871,  1565, 29871,     2]])
>>> print(data["labels"])
tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  1565, 29871,     2]])

In the expected output, data["input_ids"] and data["labels"] contain no extra bos tokens, and data["labels"] preserves the target string correctly (i.e., [..., -100, 1565, 29871, 2]).
I believe this result is more compatible with the loss calculation in models such as LlamaForCausalLM, as demonstrated below:

loss = None
if labels is not None:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
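In other words, LlamaForCausalLM performs the one-token shift itself (logits[..., :-1, :] against labels[..., 1:]), so the collator should emit labels that are aligned position-for-position with input_ids and only mask the prompt. A minimal sketch of that construction, with token ids borrowed from the demo output purely for illustration:

# Illustrative only: token ids borrowed from the demo output above, not produced by a tokenizer here.
ignore_index = -100
prompt_ids = [1, 518, 25580, 29962]      # bos + prompt tokens
completion_ids = [1565, 29871, 2]        # answer tokens ending with eos
input_ids = prompt_ids + completion_ids
labels = [ignore_index] * len(prompt_ids) + completion_ids
# No extra shift here: the model's own logits[..., :-1, :] / labels[..., 1:] handles the offset.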

Location and Solution to the Bugs

The bugs are located at the beginning of the __call__ method of DataCollatorForChatML:

prompts = []
completions = []

for example in examples:
    messages = example[self.messages_key]
    formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)

    # Split the formatted chat into prompt and completion
    assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
    last_assistant_message = assistant_messages[-1]["content"]
    prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
    completion = last_assistant_message

    prompts.append(prompt)
    completions.append(completion)

# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
    prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)

# Combine prompts and completions
input_ids = []
attention_mask = []
labels = []

for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
    combined_input_ids = prompt + completion
    combined_attention_mask = [1] * len(combined_input_ids)

    # Create labels for one-token ahead task, masking the prompt
    combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
    combined_labels.append(self.tokenizer.eos_token_id)  # Add EOS token as final target

    input_ids.append(combined_input_ids)
    attention_mask.append(combined_attention_mask)
    labels.append(combined_labels)

Reason for Bug1

Bug 1 is caused by the following code:

# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
    prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)

The prompts are derived from formatted_chat, which already contains the bos token by default:

formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)

So when self.tokenizer is used to tokenize the prompts and completions, it adds an extra bos token to each of them.
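This is easy to reproduce in isolation. A minimal sketch, assuming the tokenizer and examples from the reproduction script above (whose chat template already starts the string with the bos token):

# Assumes `tokenizer` and `examples` from the reproduction script above.
formatted_chat = tokenizer.apply_chat_template(examples[0]["messages"], tokenize=False)
print(tokenizer(formatted_chat)["input_ids"][:2])
# -> [1, 1]: the template's own bos plus the one added by the tokenizer
print(tokenizer(formatted_chat, add_special_tokens=False)["input_ids"][:2])
# -> starts with a single bos: only the one already present in the template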

Fix to Bug1

Pass add_special_tokens=False (overriding its default of True) in the following code:

# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
    prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None, add_special_tokens=False
)
tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None, add_special_tokens=False
)

Reason for Bug2

Bug 2 originates in this line:

completion = last_assistant_message

Since apply_chat_template appends the eos token to formatted_chat, the code above does not carry that eos token into the completion. In addition,

tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)

adds an extra bos token to tokenized_completions, so in the following line:

combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]

the expected assistant output (the "true" token) is dropped and only the extra bos token is retained.
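Walking through the token ids from the demo output makes this concrete:

# Token ids taken from the demo output above; illustrative walk-through only.
completion = [1, 1565]         # extra bos + "true", because add_special_tokens defaulted to True
labels_tail = completion[:-1]  # -> [1]; the "true" token (1565) is dropped
labels_tail.append(2)          # eos appended as the final target -> [1, 2], matching the buggy labels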

Fix to Bug2

The correct implementation should be:

completion = last_assistant_message + formatted_chat.rsplit(last_assistant_message, 1)[1]
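Putting both fixes together, the relevant part of __call__ would look roughly like this (a simplified sketch, not the exact code of the linked PR):

# Simplified sketch of both fixes combined; not the exact implementation of the PR.
for example in examples:
    messages = example[self.messages_key]
    formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)

    assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
    last_assistant_message = assistant_messages[-1]["content"]
    prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
    # Fix to bug 2: keep whatever the template appends after the answer (e.g. the eos token)
    completion = last_assistant_message + formatted_chat.rsplit(last_assistant_message, 1)[1]

    prompts.append(prompt)
    completions.append(completion)

# Fix to bug 1: the template already contains the bos token, so do not add special tokens again
tokenized_prompts = self.tokenizer(
    prompts, truncation=True, max_length=self.max_length, padding=False,
    return_tensors=None, add_special_tokens=False
)
tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False,
    return_tensors=None, add_special_tokens=False
)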
ruijunfeng added the 🐛 bug label on Oct 4, 2024
@qgallouedec
Member

Thanks a lot for this detailed report @ruijunfeng
This is indeed a critical issue. Are you willing to submit a PR to solve it?

@ruijunfeng
Contributor Author


Hi there, I have submitted a PR to fix this; hope it helps 😊
