
Conversation

@albertvillanova (Member) commented Sep 26, 2025:

Pass required token_type_ids.

Follow-up to transformers PR:

🚨 BC-breaking: paligemma processor now returns token_type_ids by default. This is required to disambiguate forward passes, due to the bidirectional attention mask in the prompt. Advanced generation methods may run forward passes with prompt + generated tokens, so they will fail without token_type_ids.

Fix #4142, fix #4150.

This PR extends support for the token_type_ids input to the GRPO and RLOO trainers, so that models relying on token type information (such as Gemma3/PaliGemma) correctly receive these inputs during training, evaluation, and loss computation.

The changes are applied consistently to:

  • GRPO
  • RLOO

Changes

Token type IDs support:

  • Added token_type_ids as an optional argument to the _get_per_token_logps_and_entropies method in both grpo_trainer.py and rloo_trainer.py, allowing the trainers to process token type information.
  • Updated the batching logic to slice and pass token_type_ids, when present, during batched forward passes (a minimal sketch follows this list).
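
For context, here is a minimal sketch (not the actual TRL implementation, which handles more cases) of how an optional token_type_ids tensor can be threaded through the batched forward passes:

import torch

# Minimal sketch only: slice the optional token_type_ids with the same window as
# input_ids / attention_mask, and forward it to the model only when it is provided.
def get_per_token_logps_and_entropies(model, input_ids, attention_mask, logits_to_keep,
                                       batch_size=None, token_type_ids=None):
    batch_size = batch_size or input_ids.size(0)
    all_logps, all_entropies = [], []
    for start in range(0, input_ids.size(0), batch_size):
        model_inputs = {
            "input_ids": input_ids[start : start + batch_size],
            "attention_mask": attention_mask[start : start + batch_size],
        }
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
        # Logits at positions -K-1 .. -2 predict the last K tokens
        logits = model(**model_inputs).logits[:, -logits_to_keep - 1 : -1]
        log_probs = torch.log_softmax(logits.float(), dim=-1)
        targets = model_inputs["input_ids"][:, -logits_to_keep:]
        all_logps.append(log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1))
        all_entropies.append(-(log_probs.exp() * log_probs).sum(-1))
    return torch.cat(all_logps, dim=0), torch.cat(all_entropies, dim=0)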

Integration with completions and output:

  • In the _generate_and_score_completions method, extended token_type_ids with zeros for the completion tokens and ensured they are passed through the forward arguments and included in the output dictionary.

Loss computation:

  • Updated the _compute_loss method in both trainers to pass token_type_ids when available, so that loss computation takes token type information into account (see the sketch below).
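
As a rough, hypothetical illustration of that pattern (build_forward_kwargs is an invented helper, not a TRL function): token_type_ids is only added to the model kwargs when the batch actually carries it, so text-only models are unaffected.

import torch

def build_forward_kwargs(batch):
    # Hypothetical helper: include token_type_ids only when present in the batch
    kwargs = {"input_ids": batch["input_ids"], "attention_mask": batch["attention_mask"]}
    if batch.get("token_type_ids") is not None:
        kwargs["token_type_ids"] = batch["token_type_ids"]
    return kwargs

batch = {
    "input_ids": torch.tensor([[5, 6, 7, 8]]),
    "attention_mask": torch.ones(1, 4, dtype=torch.long),
    "token_type_ids": torch.tensor([[1, 1, 0, 0]]),  # 1 = image token, 0 = text token
}
print(sorted(build_forward_kwargs(batch)))  # ['attention_mask', 'input_ids', 'token_type_ids']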

CC: @gante

@albertvillanova (Member, Author) commented:

The current state of the PR seems to fix the issue. The only remaining failure is a TypeError:

FAILED tests/test_rloo_trainer.py::RLOOTrainerTester::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration - TypeError: RLOOTrainer._get_per_token_logps_and_entropies() got an unexpected keyword argument 'token_type_ids'
= 1 failed, 920 passed, 49 skipped, 3 xfailed, 219 warnings, 5 rerun in 906.89s (0:15:06) =

@albertvillanova (Member, Author) commented:

Everything is green now! 🚀

token_type_ids = forward_kwargs["token_type_ids"]
forward_kwargs["token_type_ids"] = torch.cat(
    [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

@albertvillanova (Member, Author) opened a review thread on the snippet above, Sep 26, 2025:

If you validate this approach, do you think this should be implemented in other trainers as well?

# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
# If token_type_ids are used, extend them with zeros for the completion part
A project member commented on the snippet above:

1 is for image and 0 for text, right?

@albertvillanova (Member, Author) replied:

Yes, completion tokens are text.
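
To make the convention concrete, a tiny made-up example: image positions carry 1, text positions carry 0, and the completion block appended by the snippet above is all zeros because generated tokens are always text.

import torch

# Hypothetical prompt: 3 image tokens followed by 2 text tokens
prompt_token_type_ids = torch.tensor([[1, 1, 1, 0, 0]])
completion_ids = torch.tensor([[42, 43, 44]])  # 3 generated (text) tokens

# Same extension as in the quoted snippet: append zeros for the completion part
token_type_ids = torch.cat(
    [prompt_token_type_ids, prompt_token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
print(token_type_ids)  # tensor([[1, 1, 1, 0, 0, 0, 0, 0]])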

@albertvillanova changed the title from "WIP: Pass required token_type_ids" to "Pass required token_type_ids" on Sep 29, 2025.
@albertvillanova (Member, Author) commented:

@qgallouedec could you please validate this PR so we can finally have the CI green?

@qgallouedec (Member) left a review:

lgtm, thanks!

@albertvillanova merged commit 910aeeb into huggingface:main on Sep 29, 2025 (10 checks passed).
kashif pushed a commit that referenced this pull request on Sep 30, 2025.
@gante commented Sep 30, 2025:

Thank you for the fix 🙏

(I wouldn't be surprised if we see more models of this kind in the future -- token_type_ids is used essentially to tag blocks of inputs that need bidirectional attention)
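
For readers unfamiliar with the mechanism, here is a conceptual sketch (not the transformers implementation) of how token_type_ids can be turned into an attention pattern that stays causal for text but is bidirectional inside the image block:

import torch

def build_attention_pattern(token_type_ids):
    # Conceptual sketch: allow causal attention everywhere, plus bidirectional
    # attention between positions that both belong to an image block (type 1).
    seq_len = token_type_ids.size(-1)
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    is_image = token_type_ids.bool()
    bidirectional_image = is_image.unsqueeze(-1) & is_image.unsqueeze(-2)
    return causal | bidirectional_image

# 2 image tokens followed by 3 text tokens
print(build_attention_pattern(torch.tensor([1, 1, 0, 0, 0])).int())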
