Tasks
- An officially supported task in the examples folder
- My own task or dataset (give details below)
Reproduction
from trl.trainer.utils import DataCollatorForChatML
from transformers import AutoTokenizer

# base model settings
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
tokenizer.pad_token = tokenizer.bos_token if tokenizer.pad_token is None else tokenizer.pad_token

# Initialize the data collator for processing the chat model inputs
max_length = 1024
ignore_index = -100
messages_key = "messages"
examples = [
    {
        messages_key:
        [
            {"role": "user", "content": "Does the following code contain any security vulnerabilities? Return true or false.\nchar buffer[10];\nchar input[50];\nstrcpy(buffer, input);\n"},
            {"role": "assistant", "content": "true"}
        ]
    }
]
collator = DataCollatorForChatML(tokenizer=tokenizer, max_length=max_length, ignore_index=ignore_index, messages_key=messages_key)

# Process the data
data = collator(examples)
print(data["input_ids"])
print(data["labels"])
Found Bugs
Preliminary: the expected output from the assistant is the string "true", whose token id is 1565. The bos token id is 1 and the eos token id is 2.
Bug1: extra bos tokens and no eos token in data["input_ids"]. data["input_ids"] has an extra bos token at the beginning (i.e., [1, 1, ...]), another extra bos token right before the completion (i.e., [..., 1, 1565]), and no eos token at the end.
Bug2: no target string in data["labels"]. data["labels"] does not contain the expected assistant output (the "true" token 1565); after the ignore_index prefix, only a bos and an eos token remain (i.e., [..., -100, 1, 2]).
Expected Output of Demo Script
data["input_ids"] and data["labels"] should not contain any extra bos tokens, and data["labels"] should preserve the target string (i.e., end with [..., -100, 1565, 29871, 2]).
I believe this expected result is consistent with the loss calculation in LlamaForCausalLM, as shown below:
loss = None
if labels is not None:
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    shift_logits = shift_logits.view(-1, self.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)
    loss = loss_fct(shift_logits, shift_labels)
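For intuition, here is a tiny sketch of how the one-token shift in this loss pairs each position's logits with the next label; the prompt token ids below are made-up placeholders, and only 1, 1565, 29871, and 2 come from the demo above:

# Hypothetical ids: the prompt ids are placeholders, the tail matches the expected labels above.
input_ids = [1, 518, 25580, 29962, 1565, 29871, 2]
labels    = [-100, -100, -100, -100, 1565, 29871, 2]

# After the shift, the logits at position i are trained to predict labels[i + 1].
pairs = list(zip(input_ids[:-1], labels[1:]))
print([p for p in pairs if p[1] != -100])
# Only the completion tokens (1565, 29871, 2) contribute to the loss.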
Location and Solution to the Bugs
The bugs are located at the beginning of the __call__ method of DataCollatorForChatML:
prompts = []
completions = []

for example in examples:
    messages = example[self.messages_key]
    formatted_chat = self.tokenizer.apply_chat_template(messages, tokenize=False)

    # Split the formatted chat into prompt and completion
    assistant_messages = [msg for msg in messages if msg["role"] == "assistant"]
    last_assistant_message = assistant_messages[-1]["content"]
    prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
    completion = last_assistant_message

    prompts.append(prompt)
    completions.append(completion)

# Tokenize prompts and completions
tokenized_prompts = self.tokenizer(
    prompts, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)
tokenized_completions = self.tokenizer(
    completions, truncation=True, max_length=self.max_length, padding=False, return_tensors=None
)

# Combine prompts and completions
input_ids = []
attention_mask = []
labels = []

for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
    combined_input_ids = prompt + completion
    combined_attention_mask = [1] * len(combined_input_ids)

    # Create labels for one-token ahead task, masking the prompt
    combined_labels = [self.ignore_index] * len(prompt) + completion[:-1]
    combined_labels.append(self.tokenizer.eos_token_id)  # Add EOS token as final target

    input_ids.append(combined_input_ids)
    attention_mask.append(combined_attention_mask)
    labels.append(combined_labels)
System Info
Python 3.11.9
trl 0.11.0
transformers 4.45.1
Reason for Bug1
Bug1 comes from the tokenization step shown above. The prompts are built from formatted_chat, and apply_chat_template already prepends a bos token to formatted_chat by default. So when self.tokenizer is then used to tokenize the prompts and completions with its default add_special_tokens=True, it adds an extra bos token to each of them.
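A minimal sketch of the effect; the prompt string is an illustrative stand-in for the real formatted_chat:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")

# formatted_chat already starts with the literal "<s>" inserted by apply_chat_template.
prompt = "<s>[INST] hello [/INST]"

# With the default add_special_tokens=True the tokenizer prepends another bos id,
# which is where the leading [1, 1, ...] in data["input_ids"] comes from.
print(tokenizer(prompt)["input_ids"][:2])

# With add_special_tokens=False only the bos that is already in the text remains.
print(tokenizer(prompt, add_special_tokens=False)["input_ids"][:2])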
Fix to Bug1
Pass add_special_tokens=False (instead of the default True) when tokenizing the prompts and completions, so the tokenizer does not insert another bos token; a sketch of the change follows.
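A sketch of what that change could look like in the tokenization step quoted above (illustrative, not necessarily the exact upstream patch):

# Tokenize prompts and completions without adding special tokens again,
# since apply_chat_template has already put bos/eos into the text.
tokenized_prompts = self.tokenizer(
    prompts,
    truncation=True,
    max_length=self.max_length,
    padding=False,
    return_tensors=None,
    add_special_tokens=False,
)
tokenized_completions = self.tokenizer(
    completions,
    truncation=True,
    max_length=self.max_length,
    padding=False,
    return_tensors=None,
    add_special_tokens=False,
)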
Reason for Bug2
Bug2 comes from how the completion and the labels are built. In the code above, the completion is taken directly from the last assistant message. Since apply_chat_template appends the eos token to formatted_chat but not to the raw message content, the completion never contains the eos token. On top of that, tokenizing the completion with the default add_special_tokens=True prepends an extra bos token to tokenized_completions, so when the labels are built as [self.ignore_index] * len(prompt) + completion[:-1] with eos_token_id appended afterwards, the slice completion[:-1] drops the actual assistant token and only the bos token is retained.
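A tiny illustration of that slicing problem, using the token ids reported above (bos = 1, eos = 2, "true" = 1565):

# With add_special_tokens=True, the tokenized completion for "true" is [1, 1565] (bos + "true").
completion = [1, 1565]

# What the collator currently builds for the completion part of the labels:
labels_tail = completion[:-1] + [2]   # drop the last token, then append eos
print(labels_tail)                    # -> [1, 2]: bos followed by eos, the 1565 target is gone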
Fix to Bug2
The completion should keep the eos token that the chat template appends (i.e., it should be taken from formatted_chat rather than from the raw assistant message), and the labels should keep every completion token instead of dropping the last one and appending eos_token_id manually. One possible implementation is sketched below.
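A minimal sketch of that change inside the collator, assuming the Fix to Bug1 (add_special_tokens=False) is also applied; this is an illustrative reconstruction, not necessarily the exact patch:

# Build prompt/completion so the completion keeps the trailing " </s>" from the template.
prompt = formatted_chat.rsplit(last_assistant_message, 1)[0]
completion = formatted_chat[len(prompt):]
prompts.append(prompt)
completions.append(completion)

# ... tokenize prompts/completions with add_special_tokens=False (see Fix to Bug1) ...

for prompt, completion in zip(tokenized_prompts["input_ids"], tokenized_completions["input_ids"]):
    combined_input_ids = prompt + completion
    combined_attention_mask = [1] * len(combined_input_ids)

    # Mask the prompt and supervise the full completion, which already ends with eos,
    # giving labels like [..., -100, 1565, 29871, 2] for the demo example.
    combined_labels = [self.ignore_index] * len(prompt) + completion

    input_ids.append(combined_input_ids)
    attention_mask.append(combined_attention_mask)
    labels.append(combined_labels)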