⏩ Train on completion only #3329

Merged
merged 10 commits on Apr 23, 2025
2 changes: 1 addition & 1 deletion docs/source/dataset_formats.md
@@ -279,7 +279,7 @@ Choosing the right dataset type depends on the task you are working on and the s
| [`PPOTrainer`] | Tokenized language modeling |
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |

<Tip>
100 changes: 2 additions & 98 deletions docs/source/sft_trainer.md
@@ -59,105 +59,9 @@ The above snippets will use the default training arguments from the [`SFTConfig`

### Train on completions only

You can use the `DataCollatorForCompletionOnlyLM` to train your model on the completions only. Note that this works only when `packing=False`.
To instantiate that collator for instruction data, pass a response template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on completions only on the CodeAlpaca dataset:
To train on completions only, simply use a [prompt-completion](#prompt-completion) dataset. In this mode, loss is computed solely on the completion part.
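
For illustration, here is a minimal sketch of this mode. The tiny test dataset and model are stand-ins borrowed from the test suite elsewhere in this PR; substitute your own prompt-completion data and model:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Any dataset with "prompt" and "completion" columns enables completion-only loss by default;
# the tiny "zen" test dataset below is only a stand-in.
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=SFTConfig(output_dir="/tmp"),
    train_dataset=dataset,
)
trainer.train()
```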

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("lucasmccabe-lmi/CodeAlpaca-20k", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

def formatting_prompt_func(example):
    return f"### Question: {example['instruction']}\n ### Answer: {example['output']}"


response_template = " ### Answer:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    args=SFTConfig(output_dir="/tmp"),
    formatting_func=formatting_prompt_func,
    data_collator=collator,
)

trainer.train()
```

To instantiate that collator for assistant style conversation data, pass a response template, an instruction template and the tokenizer. Here is an example of how it would work to fine-tune `opt-350m` on assistant completions only on the Open Assistant Guanaco dataset:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

instruction_template = "### Human:"
response_template = "### Assistant:"
collator = DataCollatorForCompletionOnlyLM(instruction_template=instruction_template, response_template=response_template, tokenizer=tokenizer, mlm=False)

trainer = SFTTrainer(
    model,
    args=SFTConfig(output_dir="/tmp"),
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()
```

Make sure to use a `pad_token_id` that is different from `eos_token_id`; otherwise the model may not properly learn to predict EOS (End of Sentence) tokens during generation.

#### Using token_ids directly for `response_template`

Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:

```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

def print_tokens_with_ids(txt):
    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
    token_ids = tokenizer.encode(txt, add_special_tokens=False)
    print(list(zip(tokens, token_ids)))

prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]

response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
```

In this case, because `response_template` lacks the preceding context, the same string ("### Assistant:") is tokenized differently:

- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
- `response_template` (without context): `[835, 4007, 22137, 29901]`

This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:

```
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
```


To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:

```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`

data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
If you’d like to compute loss on both the prompt **and** the completion while still using a prompt-completion dataset, set `completion_only_loss=False` in the [`SFTConfig`]. This is equivalent to [converting the dataset to a language modeling](#from-prompt-completion-to-language-modeling-dataset) format.
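
As a minimal sketch, using the same stand-in model and dataset as above, this would look like:

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")

# With completion_only_loss=False, loss is computed over the full prompt+completion sequence,
# equivalent to converting the dataset to the language modeling format first.
training_args = SFTConfig(output_dir="/tmp", completion_only_loss=False)

trainer = SFTTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```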

### Add Special Tokens for Chat Format

40 changes: 1 addition & 39 deletions tests/test_sft_trainer.py
@@ -960,27 +960,6 @@ def test_torch_dtype(self):
)
self.assertEqual(trainer.model.config.torch_dtype, torch.float16)

# Now test when `torch_dtype` is provided but is wrong
with tempfile.TemporaryDirectory() as tmp_dir:
    training_args = SFTConfig(
        output_dir=tmp_dir,
        per_device_train_batch_size=2,
        model_init_kwargs={"torch_dtype": -1},
        report_to="none",
    )
    with self.assertRaises(ValueError) as context:
        _ = SFTTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            args=training_args,
            train_dataset=self.train_dataset,
        )

    self.assertIn(
        "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
        "a `torch.dtype` (e.g., 'float32'), but got -1.",
        str(context.exception),
    )


# This new tester aims to replace the first one at some point
class SFTTrainerTester2(unittest.TestCase):
@@ -1064,23 +1043,6 @@ def test_train_model_torch_dtype(self):
self.assertEqual(new_param.dtype, torch.float16)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

def test_train_model_wrong_torch_dtype(self):
    # Get the dataset
    dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Initialize the trainer
        training_args = SFTConfig(output_dir=tmp_dir, model_init_kwargs={"torch_dtype": -1}, report_to="none")
        with self.assertRaises(ValueError) as context:
            SFTTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
            )
        self.assertIn(
            "Invalid `torch_dtype` passed to `SFTConfig`. Expected either 'auto' or a string representing "
            "a `torch.dtype` (e.g., 'float32'), but got -1.",
            str(context.exception),
        )

Comment on lines -1067 to -1083 (Member Author):
This is not related to the core change of this PR.
With the new serialisation logic of `TrainingArguments`, passing a wrong dtype now fails when the `TrainingArguments` are instantiated, so there is no need for such a test anymore.

@require_peft
def test_train_peft_model(self):
# Get the base model
@@ -1211,7 +1173,7 @@ def test_train_with_iterable_dataset(self):
def test_train_with_data_collator_for_completion_only_and_padding_free(self):
# Get the dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_completion", split="train")
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

tokenizer = AutoTokenizer.from_pretrained(model_id)
response_template = "<|im_start|>assistant\n"
18 changes: 18 additions & 0 deletions trl/trainer/sft_config.py
@@ -70,6 +70,12 @@ class SFTConfig(TrainingArguments):
learning_rate (`float`, *optional*, defaults to `2e-5`):
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
[`~transformers.TrainingArguments`].
completion_only_loss (`bool` or `None`, *optional*, defaults to `None`):
Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is computed
only on the completion, which is supported only for [prompt-completion](#prompt-completion) datasets. If
`False`, loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset:
loss is computed on the completion for [prompt-completion](#prompt-completion) datasets, and on
the full sequence for [language modeling](#language-modeling) datasets.
"""

# Parameters that control the model
@@ -147,6 +153,18 @@ class SFTConfig(TrainingArguments):
"`TrainingArguments`."
},
)
completion_only_loss: Optional[bool] = field(
    default=None,
    metadata={
        "help": (
            "Whether to compute loss only on the completion part of the sequence. If set to `True`, loss is "
            "computed only on the completion, which is supported only for prompt-completion datasets. If `False`, "
            "loss is computed on the entire sequence. If `None` (default), the behavior depends on the dataset: "
            "loss is computed on the completion for prompt-completion datasets, and on the full sequence for "
            "language modeling datasets."
        )
    },
)

# Deprecated parameters
dataset_batch_size: Optional[int] = field(
86 changes: 67 additions & 19 deletions trl/trainer/sft_trainer.py
@@ -78,6 +78,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
Args:
pad_token_id (`int`):
Token ID to use for padding.
completion_only_loss (`bool`, *optional*, defaults to `True`):
When the input contains a completion mask (`completion_mask`), the labels are set to -100 for the tokens
that are not in the completion.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.

@@ -90,29 +93,47 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
... {"input_ids": [4, 5]}
... ]
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
[ 4, 5, -100]])}
>>> # With completion mask
>>> examples = [
... {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
... {"input_ids": [4, 5], "completion_mask": [0, 1]}
... ]
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'labels': tensor([[-100, 2, 3],
[-100, 5, -100]])}
```
"""

pad_token_id: int
completion_only_loss: bool = True
return_tensors: str = "pt"

def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
    # Convert to tensor
    input_ids = [torch.tensor(example["input_ids"]) for example in examples]
    attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
    labels = [torch.tensor(example["input_ids"]) for example in examples]
    if self.completion_only_loss and "completion_mask" in examples[0]:
        completion_mask = [torch.tensor(example["completion_mask"]) for example in examples]

    # Pad
    output = {}
    output["input_ids"] = pad(input_ids, padding_value=self.pad_token_id, padding_side="right")
    output["attention_mask"] = pad(attention_mask, padding_value=0, padding_side="right")
    output["labels"] = pad(labels, padding_value=-100, padding_side="right")
    if self.completion_only_loss and "completion_mask" in examples[0]:
        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
        output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion

    return output

@@ -278,6 +299,11 @@ def __init__(
)
data_collator = DataCollatorWithFlattening()

if args.completion_only_loss is None:
    first_example = next(iter(train_dataset))
    self.completion_only_loss = "prompt" in first_example
else:
    self.completion_only_loss = args.completion_only_loss
if data_collator is None:
# Get the pad token: if not provided, use the one from the processing class or the eos token
# if the processing class does not have a pad token.
@@ -289,7 +315,7 @@ def __init__(
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
data_collator = DataCollatorForLanguageModeling(pad_token_id)
data_collator = DataCollatorForLanguageModeling(pad_token_id, self.completion_only_loss)

# Dataset
preprocess_dataset = args.dataset_kwargs is None or not args.dataset_kwargs.get("skip_prepare_dataset", False)
@@ -500,16 +526,6 @@ def _func(example):
)
dataset = dataset.map(_func, batched=True, **map_kwargs)

# If the dataset is prompt-completion, convert it to language modeling type
first_example = next(iter(dataset))
if "prompt" in first_example.keys() and "completion" in first_example.keys():
    key = "messages" if is_conversational(first_example) else "text"

    def concat_prompt_completion(example):
        return {key: example["prompt"] + example["completion"]}

    dataset = dataset.map(concat_prompt_completion, remove_columns=["prompt", "completion"])

Comment on lines -503 to -512 (Member Author):
This concatenation needs to be removed, as it loses the information about where the completion starts. This is now handled in `tokenize`.

if not is_processed:
# Convert the dataset to ChatML if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
@@ -560,14 +576,38 @@ def add_eos(example, eos_token):
# See https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#7-chat-template-and-tokenization-dont-compose-due-to-special-tokens
add_special_tokens = True

# Tokenize the dataset if needed
# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

def tokenize(example, processing_class, dataset_text_field, add_special_tokens):
    processed = processing_class(
        text=example[dataset_text_field], add_special_tokens=add_special_tokens
    )
    if "prompt" in example:  # prompt-completion case
        processed_prompt = processing_class(
            text=example["prompt"],
            add_special_tokens=add_special_tokens,
        )
        processed = processing_class(
            text=example["prompt"] + example["completion"], add_special_tokens=add_special_tokens
        )

        # Check whether the tokenized prompt+completion starts with the tokenized prompt
        prompt_ids = processed_prompt["input_ids"]
        prompt_completion_ids = processed["input_ids"]
        if not prompt_completion_ids[: len(prompt_ids)] == prompt_ids:
            warnings.warn(
                "Mismatch between tokenized prompt and the start of tokenized prompt+completion. "
                "This may be due to unexpected tokenizer behavior, whitespace issues, or special "
                "token handling. Verify that the tokenizer is processing text consistently."
            )

        # Create a completion mask
        completion_mask = [0] * len(prompt_ids) + [1] * (len(prompt_completion_ids) - len(prompt_ids))
        processed = {**processed, "completion_mask": completion_mask}

    else:  # language modeling case
        processed = processing_class(
            text=example[dataset_text_field], add_special_tokens=add_special_tokens
        )
    return processed

dataset = dataset.map(
@@ -598,6 +638,14 @@ def tokenize(example, processing_class, dataset_text_field, add_special_tokens):

return dataset

def _set_signature_columns_if_needed(self):
    # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
    # By default, this method sets `self._signature_columns` to the model's expected inputs (usually "input_ids"
    # and "attention_mask"). When training on completions only, we add a "completion_mask" column to the
    # dataset, so we need to override the default signature columns to include "completion_mask" as well.
    if self._signature_columns is None:
        self._signature_columns = ["input_ids", "attention_mask", "completion_mask"]

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
"""
Compute training loss and additionally compute token accuracies
5 changes: 5 additions & 0 deletions trl/trainer/utils.py
@@ -97,6 +97,11 @@ def __init__(
**kwargs,
):
super().__init__(*args, mlm=mlm, **kwargs)
warnings.warn(
    "This class is deprecated and will be removed in version 0.20.0. To train on completion only, please use "
    "the parameter `completion_only_loss` of `SFTConfig` instead.",
    DeprecationWarning,
)
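# Migration sketch (illustrative, not part of this diff): instead of building this collator with a
# response_template and passing it through `data_collator=...`, provide a dataset with "prompt" and
# "completion" columns (or set `completion_only_loss=True` in `SFTConfig`) and let `SFTTrainer` create
# the default `DataCollatorForLanguageModeling`, which masks prompt tokens via the "completion_mask" column.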

self.instruction_template = instruction_template
if isinstance(instruction_template, str):