
token_type_ids missing in DPOTrainer #4284

@aweers

Description

Reproduction

This is similar to #4150 and #4189.
The token_type_ids input, which Gemma-3 models require with newer versions of huggingface/transformers, is not forwarded to the model by the DPOTrainer.
This can easily be reproduced by adding trl-internal-testing/tiny-Gemma3ForConditionalGeneration to the list of models in test_vdpo_trainer in tests/test_dpo_trainer.py.
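The following is a minimal sketch of what forwarding token_type_ids involves. It rests on three assumptions. First, that the chosen and rejected completions are stacked along the batch dimension behind a repeated prompt, as the name concatenated_forward suggests. Second, that the processor marks image tokens with type 1 and text tokens with type 0, which is what Gemma-3's processor does. Third, that prompts are already padded to a common length.

The helper names concat_input_ids and concat_token_type_ids are hypothetical, and plain lists stand in for tensors. This is not TRL's actual implementation. The point it illustrates is narrow: token_type_ids has to be built with exactly the same layout as the concatenated input_ids, with the completion tokens typed as text.

```python
from typing import List

Rows = List[List[int]]


def concat_input_ids(prompt_input_ids: Rows, chosen_input_ids: Rows, rejected_input_ids: Rows,
                     pad_token_id: int = 0) -> Rows:
    """Hypothetical: stack chosen then rejected rows, each as prompt + right-padded completion."""
    completions = chosen_input_ids + rejected_input_ids
    completion_len = max(len(c) for c in completions)
    prompts = prompt_input_ids + prompt_input_ids  # each prompt appears once per completion
    return [p + c + [pad_token_id] * (completion_len - len(c)) for p, c in zip(prompts, completions)]


def concat_token_type_ids(prompt_token_type_ids: Rows, chosen_input_ids: Rows,
                          rejected_input_ids: Rows) -> Rows:
    """Hypothetical: build token_type_ids with the same layout as concat_input_ids."""
    completion_len = max(len(c) for c in chosen_input_ids + rejected_input_ids)
    prompts = prompt_token_type_ids + prompt_token_type_ids
    # Completion tokens and their padding are text, so they all get token type 0.
    return [p + [0] * completion_len for p in prompts]


# Toy batch of two prompts; 99 stands in for an image token.
prompt_ids = [[2, 99, 99, 10], [2, 99, 99, 11]]
prompt_types = [[0, 1, 1, 0], [0, 1, 1, 0]]
chosen = [[5, 6], [7]]
rejected = [[8], [9, 10, 11]]

ids = concat_input_ids(prompt_ids, chosen, rejected)
types = concat_token_type_ids(prompt_types, chosen, rejected)
```

Each row of types lines up position by position with the corresponding row of ids. That alignment is what a Gemma-3-style model relies on to tell image positions from text positions.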

Example:

import tempfile

import numpy as np
from datasets import Dataset, features
from PIL import Image
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
)

from trl import DPOConfig, DPOTrainer


if __name__ == "__main__":
    model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
    # fmt: off
    dataset_dict = {
        "prompt": [
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in great detail."}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Is this bus in the USA?"}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Give a thorough description of the image."}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Who are the people in the image?"}]}],
            [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is written?"}]}],
        ],
        "chosen": [
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, multi-colored train."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "Yes, it can be assumed that this bus is in the USA."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "There are two individuals, possibly girls or women."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
        ],
        "rejected": [
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a modern, colorful train."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "No, it's not in the USA."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "The image features a forest path surrounded by trees."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": "In the image, there are two individuals."}]}],
            [{"role": "assistant", "content": [{"type": "text", "text": '"ccpb".'}]}],
        ],
        "images": [
            [Image.fromarray(np.random.randint(0, 255, (92, 33, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (64, 48, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (80, 152, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (57, 24, 3), dtype=np.uint8))],
            [Image.fromarray(np.random.randint(0, 255, (102, 48, 3), dtype=np.uint8))],
        ],
    }
    # fmt: on
    dataset = Dataset.from_dict(dataset_dict)
    dataset = dataset.cast_column("images", features.Sequence(features.Image()))

    # Instantiate the model and processor
    model = AutoModelForImageTextToText.from_pretrained(model_id)
    ref_model = AutoModelForImageTextToText.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    with tempfile.TemporaryDirectory() as tmpdirname:
        training_args = DPOConfig(
            output_dir=tmpdirname,
            per_device_train_batch_size=2,
            remove_unused_columns=False,
            learning_rate=0.01,  # increase learning rate to speed up test
            max_prompt_length=None,  # don't truncate to avoid issues with patch tokens
            max_length=None,
            report_to="none",
        )
        trainer = DPOTrainer(
            model=model,
            ref_model=ref_model,
            args=training_args,
            processing_class=processor,
            train_dataset=dataset,
            eval_dataset=dataset,
        )

        trainer.train()

Output:

  0%|          | 0/9 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/mnt/m2/git/trl/tests/dpo_fail.py", line 78, in <module>
    trainer.train()
    ~~~~~~~~~~~~~^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 2152, in train
    return inner_training_loop(
        args=args,
    ...<2 lines>...
        ignore_keys_for_eval=ignore_keys_for_eval,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 2483, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/trainer.py", line 3762, in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1787, in compute_loss
    loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1703, in get_batch_loss_metrics
    model_output = self.concatenated_forward(model, batch)
  File "/mnt/m2/git/trl/trl/trainer/dpo_trainer.py", line 1577, in concatenated_forward
    outputs = model(input_ids, **model_kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 818, in forward
    return model_forward(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/accelerate/utils/operations.py", line 806, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 1119, in forward
    outputs = self.model(
        input_ids=input_ids,
    ...<12 lines>...
        **lm_kwargs,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/utils/generic.py", line 757, in wrapper
    output = func(self, *args, **kwargs)
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 964, in forward
    causal_mask_mapping = create_causal_mask_mapping(
        self.config,
    ...<7 lines>...
        is_training=self.training,
    )
  File "/mnt/m2/git/trl/venv/lib/python3.13/site-packages/transformers/models/gemma3/modeling_gemma3.py", line 777, in create_causal_mask_mapping
    raise ValueError("`token_type_ids` is required as a model input when training")
ValueError: `token_type_ids` is required as a model input when training
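The failure mode can be simplified as follows. If the trainer builds model_kwargs from a fixed set of batch keys that does not include token_type_ids, the Gemma-3 check fires as soon as the model runs in training mode.

FORWARDED_KEYS, build_model_kwargs and check_gemma3_like_inputs are hypothetical names chosen for illustration. The check only imitates the condition behind the ValueError above. It is not the actual transformers code.

```python
from typing import Any, Dict, Iterable, Optional

# Hypothetical: the optional inputs a trainer copies from the batch into model_kwargs.
FORWARDED_KEYS = ("attention_mask", "pixel_values", "token_type_ids")


def build_model_kwargs(batch: Dict[str, Any], keys: Iterable[str] = FORWARDED_KEYS) -> Dict[str, Any]:
    """Copy only the listed keys that are actually present in the batch."""
    return {key: batch[key] for key in keys if key in batch}


def check_gemma3_like_inputs(model_kwargs: Dict[str, Any], is_training: bool) -> None:
    """Simplified imitation of the training-time check that raises in the traceback."""
    if is_training and model_kwargs.get("token_type_ids") is None:
        raise ValueError("`token_type_ids` is required as a model input when training")


batch = {
    "input_ids": [[2, 99, 99, 10, 5, 6]],
    "attention_mask": [[1, 1, 1, 1, 1, 1]],
    "pixel_values": "<image tensor>",  # placeholder; a real batch holds a tensor here
    "token_type_ids": [[0, 1, 1, 0, 0, 0]],
}

# Buggy variant: token_type_ids is not among the forwarded keys, so it is silently dropped.
buggy_kwargs = build_model_kwargs(batch, keys=("attention_mask", "pixel_values"))
error: Optional[ValueError] = None
try:
    check_gemma3_like_inputs(buggy_kwargs, is_training=True)
except ValueError as exc:
    error = exc

# Fixed variant: token_type_ids is forwarded, so the check passes.
fixed_kwargs = build_model_kwargs(batch)
check_gemma3_like_inputs(fixed_kwargs, is_training=True)
```

Copying only the keys that are present keeps models that do not use token_type_ids unaffected. They simply never receive the key.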

This is my first issue, so please feel free to point out anything that is missing.
I will also open a pull request that fixes the problem.

System Info

  • Platform: Linux-5.15.0-140-generic-x86_64-with-glibc2.35
  • Python version: 3.13.2
  • TRL version: 0.24.0.dev0+36e8e0a
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA GeForce RTX 3080
  • Transformers version: 5.0.0.dev0
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 4.2.0
  • HF Hub version: 1.0.0rc5
  • bitsandbytes version: 0.48.1
  • DeepSpeed version: 0.18.0
  • Liger-Kernel version: 0.6.2
  • LLM-Blender version: 0.0.2
  • OpenAI version: 2.3.0
  • PEFT version: 0.17.1
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks, (no screenshot, more on code blocks)
  • Any traceback provided is complete

Metadata

Assignees: No one assigned
Labels: 🏋 DPO (Related to DPO), 🐛 bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
