-
Notifications
You must be signed in to change notification settings - Fork 2.3k
🧺 [1/N] Refactor _generate in GRPO/RLOO: list of ints instead of tensors
#4146
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
trl/trainer/grpo_trainer.py
Outdated
| **kwargs, | ||
| ) | ||
| prompt_inputs = super()._prepare_inputs(prompt_inputs) | ||
| prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the function must now return a list of ints, so we must remove padding
| prompt_mask, | ||
| completion_mask, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
prompt and completion masks are later inferred from the sequence lengths
_generate in GRPO/RLOO_generate in GRPO/RLOO
_generate in GRPO/RLOO_generate in GRPO/RLOO: list of ints instead of tensors
albertvillanova
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks.
trl/trainer/utils.py
Outdated
| sequences (`list[int]`): | ||
| Input sequence of token IDs. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The sequences name in the docstring is not aligned with the ids name in the signature.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additionally, before it accepted batch_size sequences (within the tensor) and now it accepts a single sequence (list[int]). Isn't this breaking something? Some tests should be failing because of the new behavior?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some tests should be failing because of the new behavior?
yes, tests have been updated as well, see TruncateWithProtectedTokensTester
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
sequencesname in the docstring is not aligned with theidsname in the signature.
thanks! fixed in c570fb0
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
This PR belongs to a sequence of PR that aims to refactor the generation part of GRPO/RLOO to allow for easier customization and ultimately tool calling
Previous:
image_split_sizesin favour ofimage_grid_thw#4111_generate#4114Next:
_generatein GRPO/RLOO: Useprompt_idsfrom generation #4152_generatein GRPO/RLOO: Rely on generator for prompt truncation #4153_generatein GRPO/RLOO: Moveforward_kwargsoutside generation method #4154_generatein GRPO/RLOO: Insert images in the prompt #4155The idea with this PR is to make
_generatereturn list of ints instead of tensors. This will help a lots when implementing tool calling.Several modifications:
truncate_with_protected_tokens: instead of operating on 2D tensors (ids and mask), it will operate on sequence ids directly:before
after
_generatenow returns list of ids, instead of tensor + maskconversion to tensor is handle in
_generate_and_score_completionsThe generation part is moved to a function
_generate_single_turn, which is called by_generate