Pass required token_type_ids #4148
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Current state of the PR seems to fix the issue:

```
FAILED tests/test_rloo_trainer.py::RLOOTrainerTester::test_training_vlm_0_trl_internal_testing_tiny_Gemma3ForConditionalGeneration - TypeError: RLOOTrainer._get_per_token_logps_and_entropies() got an unexpected keyword argument 'token_type_ids'
= 1 failed, 920 passed, 49 skipped, 3 xfailed, 219 warnings, 5 rerun in 906.89s (0:15:06) =
```

The only remaining issue is now the …
Everything is green now! 🚀
```python
token_type_ids = forward_kwargs["token_type_ids"]
forward_kwargs["token_type_ids"] = torch.cat(
    [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
```
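For readers skimming the hunk, a minimal runnable sketch of the pattern (tensor values and shapes are made up for illustration; `new_zeros` inherits dtype and device from the source tensor, which is why no explicit casting is needed):

```python
import torch

# Illustrative shapes only: batch of 1, prompt length 4, completion length 3.
token_type_ids = torch.tensor([[0, 1, 1, 0]])          # 1 marks image tokens
completion_ids = torch.zeros(1, 3, dtype=torch.long)   # stand-in completion ids

# new_zeros matches the dtype and device of token_type_ids, so the completion
# padding needs no explicit casting or .to(device).
pad = token_type_ids.new_zeros(completion_ids.shape)
extended = torch.cat([token_type_ids, pad], dim=1)
print(extended)  # tensor([[0, 1, 1, 0, 0, 0, 0]]) -> shape (B, P+C)
```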
If you validate this approach, do you think this should be implemented in other trainers as well?
```python
# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
# If token_type_ids are used, extend them with zeros for the completion part
```
1 is for image and 0 for text, right?
Yes, completion tokens are text.
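To make that convention concrete, a fabricated illustration (real ids come from the model's processor; these values are invented):

```python
import torch

# Fabricated token_type_ids for: text, image placeholders, text, completion.
tt = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 0, 0]])
prompt_len = 6
assert (tt[:, prompt_len:] == 0).all()  # completion tokens are text, hence 0
```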
@qgallouedec could you please validate this PR so we can finally have the CI green?
qgallouedec left a comment:
lgtm, thanks!
Thank you for the fix 🙏 (I wouldn't be surprised if we see more models of this kind in the future --
Pass required `token_type_ids`. Follow-up to a `transformers` PR. Fix #4142, fix #4150.

This PR extends support for the `token_type_ids` input across the GRPO and RLOO trainers, ensuring that models using token type information can correctly handle these inputs during training, evaluation, and loss computation. The changes are applied consistently to `grpo_trainer.py` and `rloo_trainer.py`.
Changes

Token type IDs support:
- Added `token_type_ids` as an optional argument to the `_get_per_token_logps_and_entropies` method in both `grpo_trainer.py` and `rloo_trainer.py`, allowing the trainers to process token type information (see the sketch after this list).
- Handled `token_type_ids` when present, ensuring correct slicing and passing of token type IDs during batched forward passes.
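A hedged sketch of the plumbing this list describes -- a simplified stand-in, not TRL's actual method body; the real `_get_per_token_logps_and_entropies` takes more parameters and gathers per-token log-probs of the target ids:

```python
from typing import Optional
import torch

def get_per_token_logps_and_entropies(  # simplified stand-in for the trainer method
    model,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    token_type_ids: Optional[torch.Tensor] = None,
    batch_size: int = 4,
) -> torch.Tensor:
    """Run the model in micro-batches, forwarding token_type_ids when given."""
    all_logps = []
    for start in range(0, input_ids.size(0), batch_size):
        model_inputs = {
            "input_ids": input_ids[start : start + batch_size],
            "attention_mask": attention_mask[start : start + batch_size],
        }
        # Slice token_type_ids with the same batch boundaries as the other inputs.
        if token_type_ids is not None:
            model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
        logits = model(**model_inputs).logits
        all_logps.append(torch.log_softmax(logits, dim=-1))
    return torch.cat(all_logps, dim=0)
```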
Integration with completions and output:
- In the `_generate_and_score_completions` method, extended `token_type_ids` with zeros for the completion tokens (as shown in the diff above) and ensured they are passed through the forward arguments and included in the output dictionary.
Loss computation:
- Updated the `_compute_loss` method in both trainers to pass `token_type_ids` when available, ensuring that loss calculations take token type information into account (see the sketch below).

CC: @gante
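And a correspondingly hedged sketch of the loss-side pass-through (the key names mirror this PR's diff; the helper itself is hypothetical, not a TRL function):

```python
def build_logps_kwargs(inputs: dict) -> dict:
    """Hypothetical helper: assemble kwargs for the per-token log-prob call,
    forwarding token_type_ids only when the generation step produced them."""
    kwargs = {
        "input_ids": inputs["prompt_completion_ids"],
        "attention_mask": inputs["attention_mask"],
    }
    if "token_type_ids" in inputs:  # present only for models that require them
        kwargs["token_type_ids"] = inputs["token_type_ids"]
    return kwargs
```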