DataCollatorForChatML unexpected generation prompt #2449

Closed
NIL-zhuang (Contributor) opened this issue Dec 7, 2024 · 0 comments · Fixed by #2450

System Info

  • Platform: macOS-15.1.1-arm64-arm-64bit
  • Python version: 3.10.15
  • PyTorch version: 2.4.1
  • CUDA device(s): not available
  • Transformers version: 4.45.2
  • Accelerate version: 1.0.1
  • Accelerate config: not found
  • Datasets version: 3.0.1
  • HF Hub version: 0.25.2
  • TRL version: 0.12.2
  • bitsandbytes version: not installed
  • DeepSpeed version: 0.15.2
  • Diffusers version: 0.30.3
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.51.2
  • PEFT version: 0.13.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from transformers import AutoTokenizer
from trl.trainer.utils import DataCollatorForChatML

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
# Llama 3.2 has no dedicated pad token, so reuse EOS for padding.
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token

data_collator = DataCollatorForChatML(tokenizer)
examples = [
    {
        "messages": [
            {"role": "system", "content": "You are a professional translator."},
            {"role": "user", "content": "Hello!"},
            {"role": "assistant", "content": "Hi there! How can I help you today?"},
        ],
    },
]
batch = data_collator(examples)

# Decode the full model input produced by the collator.
print(tokenizer.decode(batch["input_ids"][0]))

# Replace the ignored positions (-100) so the labels can be decoded too.
label = batch["labels"][0]
label[label == -100] = tokenizer.eos_token_id
print(tokenizer.decode(label))

Output (first the decoded input_ids, then the decoded labels):

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 07 Dec 2024

You are a professional translator.<|eot_id|><|start_header_id|>user<|end_header_id|>

Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Hi there! How can I help you today?<|eot_id|><|start_header_id|>assistant<|end_header_id|>


!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!Hi there! How can I help you today?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Expected behavior

During instruction tuning, the model should not be trained to generate <|start_header_id|>assistant<|end_header_id|> after <|eot_id|>. The correct label for the assistant response is Hi there! How can I help you today?<|eot_id|>.
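
For context, the stray header comes from add_generation_prompt=True, which appends an empty assistant header meant for inference-time prompting, not for building training labels. A minimal sketch of the difference (same tokenizer as above; the exact trailing strings assume the Llama 3.2 chat template):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there! How can I help you today?"},
]

# add_generation_prompt=True appends an empty assistant header after the
# final <|eot_id|>; add_generation_prompt=False does not.
with_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
without_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

# with_prompt should end in "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
# without_prompt should end in "<|eot_id|>"
print(repr(with_prompt[-60:]))
print(repr(without_prompt[-60:]))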

We should change trl/trainer/utils.py L276-L277 so that no generation prompt is appended:

formatted_message = self.tokenizer.apply_chat_template(
    # message, tokenize=False, add_generation_prompt=True,
    message, tokenize=False, add_generation_prompt=False
)
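
With that change, re-running the reproduction above should produce labels that stop at <|eot_id|>. A quick check (a sketch reusing examples, data_collator, and tokenizer from the reproduction script; the expected strings assume the Llama 3.2 special tokens):

batch = data_collator(examples)

# Keep only the supervised label positions before decoding.
label = batch["labels"][0]
decoded = tokenizer.decode(label[label != -100])

# These should hold once add_generation_prompt=False is used:
assert decoded.endswith("Hi there! How can I help you today?<|eot_id|>")
assert "<|start_header_id|>assistant<|end_header_id|>" not in decoded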

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete