Fix GRPO with replay buffer by inserting images in the prompt #4391
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
if std_rewards is None:
    std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
std_rewards = std_rewards[process_slice] if std_rewards is not None else None
This last line is not in GRPOTrainer: is it necessary? If so, shouldn't we implement it in GRPOTrainer as well?
std_rewards = std_rewards[process_slice]
@pramodith if you've some time to check
Will take a look a bit later this evening!
We need the sliced std_rewards in this trainer because we decide whether a specific example should be added to the replay buffer, or sampled from it, based on the std_reward of that specific rollout. Since each GPU sees a unique batch of data, we need to perform the buffer lookup and update based only on the slice residing on that GPU.
GRPOTrainer doesn't need the std after the advantage scores are computed, so it can be discarded in GRPOTrainer.
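For illustration, here's a minimal sketch of that slicing with toy numbers. The process_slice construction is an assumption mirroring how GRPOTrainer splits a batch across processes; names and values are illustrative only:

import torch

# Toy setup: 2 groups x 4 generations = 8 rollouts, split across 2 processes.
num_generations = 4
rewards = torch.tensor([0.1, 0.9, 0.4, 0.6, 0.5, 0.5, 0.5, 0.5])

# Group-level std, repeated so every rollout carries its group's std.
std_rewards = rewards.view(-1, num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)  # shape (8,)

# Hypothetical per-process slice: process 0 keeps rows 0..3, process 1 keeps rows 4..7.
process_index, per_process_size = 0, 4
process_slice = slice(process_index * per_process_size, (process_index + 1) * per_process_size)

# Only the local slice is used to decide whether to add to / sample from the replay buffer.
local_std = std_rewards[process_slice]  # std of the first group, repeated 4 times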
This shows how std_rewards is used for updating the replay buffer
trl/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
Lines 407 to 411 in 1c2322e
if groups_with_variance.any():
    # Calculate replay buffer scores for groups with variance
    replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
    # Add all groups to replay buffer at once (batch operation)
    self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)
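To make the scoring concrete, here is a toy computation of those scores with made-up tensors; shapes and values are illustrative only, not taken from the trainer:

import torch

# 3 groups x 4 generations; the middle group has zero reward variance.
group_advantages = torch.tensor([[ 1.0, -1.0,  0.5, -0.5],
                                 [ 0.0,  0.0,  0.0,  0.0],
                                 [ 2.0, -2.0,  1.0, -1.0]])
group_std_rewards = torch.tensor([[0.5], [0.0], [1.0]]).expand(-1, 4)

# Groups with zero std carry no learning signal, so they are skipped.
groups_with_variance = group_std_rewards[:, 0] > 0

# Score per group: sum over the group of |advantage| * group std.
replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
# -> tensor([1.5000, 6.0000]): higher-variance, higher-|advantage| groups score higher.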
This entire block of removed code should remain in grpo_with_replay_buffer_trainer.py; we always need the group-level std to determine what goes into the replay buffer.
Thanks for the explanation about std_rewards = std_rewards[process_slice], @pramodith. 🤗
Regarding the block of lines just before, I removed the condition if std_rewards is None because I think it is always False. Just a few lines above, we have this code: https://github.com/albertvillanova/trl/blob/d0324230761e7860646f4d15d7ff8beb433103ac/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py#L250-L260
if self.scale_rewards in ["group", "none"]:
    # If self.scale_rewards = "none", we'll still log group level std
    std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
elif self.scale_rewards == "batch":
    # Compute global std
    std_rewards = rewards.std().expand_as(rewards)
else:
    raise ValueError(
        f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
    )

Therefore, std_rewards can't be None, if I understand correctly. It should always be a torch.Tensor.
Hmmm, yeah, you're right, but there's a bug here: the replay buffer requires the std to be computed over the group. I'll fix that in a subsequent PR; getting rid of that block in this PR is fine.
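One possible shape of that fix (a sketch only, not the actual follow-up PR): always compute a dedicated group-level std for the replay buffer, independent of the std used to scale the advantages. Function name and signature are hypothetical.

import torch

def compute_stds(rewards: torch.Tensor, num_generations: int, scale_rewards: str):
    # Illustrative sketch: the replay buffer always gets a per-group std,
    # even when advantages are scaled with a batch-level std.
    group_std = rewards.view(-1, num_generations).std(dim=1)
    group_std = group_std.repeat_interleave(num_generations, dim=0)

    # Std used for advantage scaling follows the configured mode.
    if scale_rewards in ("group", "none"):
        scaling_std = group_std
    elif scale_rewards == "batch":
        scaling_std = rewards.std().expand_as(rewards)
    else:
        raise ValueError(f"Invalid value for scale_rewards: {scale_rewards}.")
    return scaling_std, group_std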
I re-added just that line: 4266550
😉
The failing test will be fixed after the merge of:
qgallouedec left a comment:
ok lgtm!
Mmh wait, we've
That is fixed in my previous PR. See my comment above, @qgallouedec.
@qgallouedec, there were 2 different bugs:
oh my bad, thanks for fixing it
Fix GRPO with replay buffer by inserting images in the prompt. Additionally, fix the CI test test_training_with_replay_buffer.

Follow-up to:
_generate in GRPO/RLOO: Insert images in the prompt #4155

Currently, GRPO with Replay Buffer raises an error: https://github.com/huggingface/trl/actions/runs/18940392458/job/54077463859

Stacktrace: