
CI fails: ValueError: token_type_ids is required as a model input when training #4142

@albertvillanova

CI fails for the "Tests with dev dependencies" job when running tiny-Gemma3ForConditionalGeneration: https://github.com/huggingface/trl/actions/runs/18003157615/job/51217046062?pr=4128

ValueError: `token_type_ids` is required as a model input when training
FAILED tests/test_grpo_trainer.py::GRPOTrainerTester::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration - ValueError: `token_type_ids` is required as a model input when training
FAILED tests/test_rloo_trainer.py::RLOOTrainerTester::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration - ValueError: `token_type_ids` is required as a model input when training
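
For context, Gemma3's processor returns a `token_type_ids` tensor marking which positions are image tokens, and the recent mask-construction changes in transformers require it at training time. A hedged, untested sketch of what a minimal reproduction might look like (model id inferred from the failing test name; the chat-template and processor arguments are assumptions):

```python
from PIL import Image
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

# Model id inferred from the failing test name (assumption).
model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
processor = AutoProcessor.from_pretrained(model_id)
model = Gemma3ForConditionalGeneration.from_pretrained(model_id)

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image."}]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
image = Image.new("RGB", (224, 224), color=(128, 128, 128))

inputs = processor(text=prompt, images=image, return_tensors="pt")
print(inputs.keys())  # the processor output also contains token_type_ids for the image positions

model.train()
inputs.pop("token_type_ids")  # dropping it mimics what the trainer forward currently passes
model(**inputs)  # expected: ValueError: `token_type_ids` is required as a model input when training
```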

Traceback:

a = (<tests.test_rloo_trainer.RLOOTrainerTester testMethod=test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration>,)
kw = {}

    @wraps(func)
    def standalone_func(*a, **kw):
>       return func(*(a + p.args), **p.kwargs, **kw)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.venv/lib/python3.12/site-packages/parameterized/parameterized.py:620: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tests/test_rloo_trainer.py:1114: in test_training_vlm
    trainer.train()
.venv/lib/python3.12/site-packages/transformers/trainer.py:2325: in train
    return inner_training_loop(
.venv/lib/python3.12/site-packages/transformers/trainer.py:2674: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/transformers/trainer.py:4013: in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/rloo_trainer.py:997: in _prepare_inputs
    generation_batch = self._generate_and_score_completions(generation_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/rloo_trainer.py:1354: in _generate_and_score_completions
    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
trl/extras/profiling.py:98: in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/rloo_trainer.py:830: in _get_per_token_logps_and_entropies
    logits = model(**model_inputs).logits
             ^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/utils/operations.py:819: in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/accelerate/utils/operations.py:807: in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44: in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:1122: in forward
    outputs = self.model(
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/transformers/utils/generic.py:783: in wrapper
    output = func(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:967: in forward
    causal_mask_mapping = create_causal_mask_mapping(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

config = Gemma3Config {
  "architectures": [
    "Gemma3ForConditionalGeneration"
  ],
  "boi_token_index": 255999,
  "bos_toke... 3,
    "num_hidden_layers": 2,
    "num_key_value_heads": 2,
    "patch_size": 14,
    "vision_use_head": false
  }
}

input_embeds = tensor([[[-0.0613, -0.1484,  0.1104,  ...,  0.0640, -0.0977, -0.1719],
         [-0.0369,  0.0674, -0.0698,  ..., -0.0...69,  0.0703,  0.1133],
         [ 0.0388, -0.0200, -0.0781,  ..., -0.1240, -0.0474,  0.1187]]],
       device='cuda:0')
attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1... 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       device='cuda:0')
cache_position = tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  2...68, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279,
        280, 281, 282, 283, 284, 285, 286], device='cuda:0')
past_key_values = None, position_ids = None, token_type_ids = None
pixel_values = tensor([[[[-0.7412, -0.7412, -0.7412,  ...,  0.7020,  0.7020,  0.7020],
          [-0.7412, -0.7412, -0.7412,  ...,  0..., -0.5294, -0.5294],
          [ 0.2392,  0.2392,  0.2392,  ..., -0.5294, -0.5294, -0.5294]]]],
       device='cuda:0')
is_training = True, kwargs = {}

    def create_causal_mask_mapping(
        config: PretrainedConfig,
        input_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cache_position: torch.Tensor,
        past_key_values: Optional[Cache],
        position_ids: Optional[torch.Tensor],
        token_type_ids: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        is_training: bool = False,
        **kwargs,
    ) -> dict:
        """
        Overwrites the base `create_masks_for_generate` with `token_type_ids` masking to create the causal mask mapping
        for all kinds of forward passes. Gemma3 uses a bidirectional mask for images.
    
        Uses `pixel_values` as an optional input to disambiguate edge cases.
        """
        if is_training and token_type_ids is None:
>           raise ValueError("`token_type_ids` is required as a model input when training")
E           ValueError: `token_type_ids` is required as a model input when training

.venv/lib/python3.12/site-packages/transformers/models/gemma3/modeling_gemma3.py:780: ValueError

It seems the error is raised by the transformers function `create_causal_mask_mapping`, which was introduced in:

@gante, any hint about how we should handle this on our side?
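
One possible direction on our side (a minimal, hypothetical sketch, not the current implementation): keep the `token_type_ids` returned by the processing class and forward it with the other tensors in `_get_per_token_logps_and_entropies`, so that Gemma3's `create_causal_mask_mapping` can build the bidirectional image mask during training:

```python
# Hypothetical sketch of the input dict built before the forward pass in
# _get_per_token_logps_and_entropies (names like `processor_outputs` are
# illustrative, not actual TRL variables).
def build_model_inputs(input_ids, attention_mask, processor_outputs):
    model_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

    # Vision inputs, when present.
    if "pixel_values" in processor_outputs:
        model_inputs["pixel_values"] = processor_outputs["pixel_values"]

    # Gemma3 needs token_type_ids at training time to tell image tokens apart
    # from text tokens; forward it only when the processor produced it, so
    # text-only models are unaffected.
    if "token_type_ids" in processor_outputs:
        model_inputs["token_type_ids"] = processor_outputs["token_type_ids"]

    return model_inputs

# logits = model(**build_model_inputs(...)).logits  # as in rloo_trainer.py:830
```

Note that the forward pass in GRPO/RLOO runs over prompt + completion tokens, so the prompt-level `token_type_ids` would presumably also need to be padded with zeros for the generated completion part; whether that is the intended usage is part of the question above.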

Labels

🏋 GRPO (Related to GRPO) · 🏋 RLOO (Related to RLOO) · 🐛 bug (Something isn't working)
