Reproduction
Similar to #4150 and #4189
`token_type_ids`, which is required by Gemma-3 models (for newer versions of huggingface/transformers), is not forwarded by the `DPOTrainer`.
This can easily be reproduced by adding trl-internal-testing/tiny-Gemma3ForConditionalGeneration to the models in `test_vdpo_trainer` in `tests/test_dpo_trainer.py`.
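For context, the Gemma-3 processor emits `token_type_ids` alongside `input_ids` (they mark which positions are image tokens), so anything that rebuilds the model inputs has to carry them through. A minimal sketch to see this, using the same tiny test model (the `<start_of_image>` placeholder and prompt format are assumptions on my side):

```python
import numpy as np
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Gemma3ForConditionalGeneration")
image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

# "<start_of_image>" is the Gemma-3 image placeholder; the processor expands it
# into image soft tokens and marks those positions in token_type_ids.
inputs = processor(text="<start_of_image> Describe the image.", images=image, return_tensors="pt")
print(sorted(inputs.keys()))  # expected to include "token_type_ids"
```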
Example:

```python
import tempfile

import numpy as np
from datasets import Dataset, features
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)

from trl import DPOConfig, DPOTrainer

if __name__ == "__main__":
    model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"

    # fmt: off
    dataset_dict = {
        "prompt": [
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in great detail."}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Is this bus in the USA?"}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Give a thorough description of the image."}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Who are the people in the image?"}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is written?"}]}],
        ],
        "chosen": [
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, multi-colored train."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "Yes, it can be assumed that this bus is in the USA."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "There are two individuals, possibly girls or women."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
        ],
        "rejected": [
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, colorful train."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "No, it's not in the USA."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path surrounded by trees."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "In the image, there are two individuals."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
        ],
        "images": [
            [Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))],
        ],
    }
    # fmt: on

    dataset = Dataset.from_dict(dataset_dict)
    dataset = dataset.cast_column("images", features.Sequence(features.Image()))

    # Instantiate the model and processor
    model = AutoModelForImageTextToText.from_pretrained(model_id)
    ref_model = AutoModelForImageTextToText.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    with tempfile.TemporaryDirectory() as tmpdirname:
        training_args = DPOConfig(
            output_dir=tmpdirname,
            per_device_train_batch_size=2,
            remove_unused_columns=False,
            learning_rate=0.01,  # increase learning rate to speed up the test
            max_prompt_length=None,  # don't truncate, to avoid issues with patch tokens
            max_length=None,
            report_to="none",
        )
        trainer = DPOTrainer(
            model=model,
            ref_model=ref_model,
            args=training_args,
            processing_class=processor,
            train_dataset=dataset,
            eval_dataset=dataset,
        )
        trainer.train()
```

outputs:

```text
0%| | 0/9 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/mnt/m2/git/trl/tests/dpo_fail.py", line 78, in <module>
    trainer.train()
    ~~~~~~~~~~~~~^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 2152, in train
    return inner_training_loop(
        args=args,
        ...<2 lines>...
        ignore_keys_for_eval=ignore_keys_for_eval,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 2483, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 3762, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1787, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1703, in get_batch_loss_metrics
    model_output = self.concatenated_forward(model, batch)
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1577, in concatenated_forward
    outputs = model(input_ids, **model_kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 818, in forward
    return model_forward(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 806, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1119, in forward
    outputs = self.model(
        input_ids=input_ids,
        ...<12 lines>...
        **lm_kwargs,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/utils/generic.py", line 757, in wrapper
    output = func(self, *args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 964, in forward
    causal_mask_mapping = create_causal_mask_mapping(
        self.config,
        ...<7 lines>...
        is_training=self.training,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 777, in create_causal_mask_mapping
    raise ValueError("`token_type_ids` is required as a model input when training")
ValueError: `token_type_ids` is required as a model input when training
```
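The raise comes from Gemma-3's `create_causal_mask_mapping`, which needs `token_type_ids` in training mode (presumably to build the bidirectional attention mask over image tokens), while `DPOTrainer.concatenated_forward` calls `model(input_ids, **model_kwargs)` without ever putting `token_type_ids` into `model_kwargs`. The failure can be isolated from the trainer with a bare forward pass; a sketch (the prompt format, and the guard firing without labels, are assumptions on my side):

```python
import numpy as np
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
inputs = processor(text="<start_of_image> Describe the image.", images=image, return_tensors="pt")

model.train()  # the guard only fires when self.training is True
inputs.pop("token_type_ids")  # mimic what DPOTrainer currently forwards
outputs = model(**inputs)  # expected to raise the same ValueError as above
```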
This is my first issue, so please feel free to point out anything that is incomplete.
I will also open a pull request that fixes the problem.
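The fix will likely amount to treating `token_type_ids` like the other per-token tensors the trainer already pads and concatenates. A rough sketch of the direction (a hypothetical helper, not the actual TRL code):

```python
import torch

def build_completion_token_type_ids(prompt_tti: torch.Tensor, completion_len: int) -> torch.Tensor:
    # Hypothetical helper for a single 1-D sequence: completion tokens are plain
    # text, so they get token type 0; only the image positions in the prompt
    # carry type 1.
    completion_tti = torch.zeros(completion_len, dtype=prompt_tti.dtype, device=prompt_tti.device)
    return torch.cat([prompt_tti, completion_tti], dim=0)

# Inside concatenated_forward, the padded tensor would then travel with the batch:
#   model_kwargs["token_type_ids"] = token_type_ids
#   outputs = model(input_ids, **model_kwargs)
```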
System Info
- Platform: Linux-5.15.0-140-generic-x86_64-with-glibc2.35
- Python version: 3.13.2
- TRL version: 0.24.0.dev0+36e8e0a
- PyTorch version: 2.8.0
- accelerator(s): NVIDIA GeForce RTX 3080
- Transformers version: 5.0.0.dev0
- Accelerate version: 1.10.1
- Accelerate config: not found
- Datasets version: 4.2.0
- HF Hub version: 1.0.0rc5
- bitsandbytes version: 0.48.1
- DeepSpeed version: 0.18.0
- Liger-Kernel version: 0.6.2
- LLM-Blender version: 0.0.2
- OpenAI version: 2.3.0
- PEFT version: 0.17.1
- vLLM version: not installed
Checklist
- I have checked that my issue isn't already filed (see open issues)
- I have included my system information
- Any code provided is minimal, complete, and reproducible (more on MREs)
- Any code provided is properly formatted in code blocks (no screenshot, more on code blocks)
- Any traceback provided is complete