
Conversation

@albertvillanova
Member

@albertvillanova albertvillanova commented Oct 30, 2025

Fix GRPO with replay buffer by inserting images in the prompt. Additionally, fix the CI test test_training_with_replay_buffer.

Follow-up to:

Currently, GRPO with Replay Buffer raises an error: https://github.com/huggingface/trl/actions/runs/18940392458/job/54077463859

TypeError: GRPOTrainer._generate() takes 2 positional arguments but 3 were given

Stacktrace:

>       trainer.train()

tests/experimental/test_grpo_with_replay_buffer_trainer.py:282: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.venv/lib/python3.13/site-packages/transformers/trainer.py:2325: in train
    return inner_training_loop(
.venv/lib/python3.13/site-packages/transformers/trainer.py:2674: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.13/site-packages/transformers/trainer.py:4014: in training_step
    inputs = self._prepare_inputs(inputs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1037: in _prepare_inputs
    generation_batch = self._generate_and_score_completions(generation_batch)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <trl.experimental.grpo_with_replay_buffer.grpo_with_replay_buffer_trainer.GRPOWithReplayBufferTrainer object at 0x7fb2382151d0>
inputs = [{'prompt': "Although that way may not be obvious at first unless you're"}, {'prompt': "Although that way may not be o...may not be obvious at first unless you're"}, {'prompt': "Although that way may not be obvious at first unless you're"}]

    def _generate_and_score_completions(
        self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
    ) -> dict[str, Union[torch.Tensor, Any]]:
        device = self.accelerator.device
        mode = "train" if self.model.training else "eval"
    
        prompts = [x["prompt"] for x in inputs]
    
        if "images" in inputs[0]:
            images = [example.get("images") for example in inputs]
        elif "image" in inputs[0]:
            images = [[example.get("image")] if example.get("image") is not None else None for example in inputs]
        else:
            images = None
        # Transformers requires at least one image in the batch, otherwise it throws an error
        if images is not None and all(img_list == [] for img_list in images):
            images = None
    
        (
            prompt_ids,
            completion_ids,
            prompt_mask,
            completion_mask,
            num_items_in_batch,
            sampling_per_token_logps,
            forward_kwargs,
>       ) = self._generate(prompts, images)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E       TypeError: GRPOTrainer._generate() takes 2 positional arguments but 3 were given

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py:91: TypeError
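For context, the mismatch is that this subclass still passes images as a second positional argument, while the base method now expects a single inputs argument with any images already inserted into the prompt entries. Below is a minimal sketch of the shape of the fix, with hypothetical class and method bodies (not the actual TRL code):

from typing import Any

class BaseTrainer:
    def _generate(self, inputs: list[dict[str, Any]]):
        # New-style signature: each input dict carries its own images.
        ...

class ReplayBufferTrainer(BaseTrainer):
    def _generate_and_score_completions(self, inputs: list[dict[str, Any]]):
        # Old call site raised: self._generate(prompts, images) -> TypeError,
        # because _generate takes 2 positional arguments (self + inputs).
        # Fixed call site passes a single argument:
        return self._generate(inputs)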

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

if std_rewards is None:
    std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
std_rewards = std_rewards[process_slice] if std_rewards is not None else None
Member Author

This last line is not in GRPOTrainer: is it necessary? If so, shouldn't we implement it in GRPOTrainer as well?

std_rewards = std_rewards[process_slice]

Member

@pramodith, if you have some time to check.

Collaborator

Will take a look a bit later this evening!

Collaborator

@pramodith pramodith Oct 30, 2025

We need the sliced std_rewards in this trainer because we decide whether a specific example should be added to the replay buffer, or sampled from it, based on the std_reward of that specific rollout. Since each GPU sees a unique batch of data, we need to perform the buffer lookup and update based only on the slice residing on that GPU.

GRPOTrainer doesn't need the std after the advantage scores are computed, so it can be discarded in GRPOTrainer.
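To make the slicing concrete, here is a minimal runnable sketch of the pattern described above, with illustrative values standing in for self.accelerator.process_index and the gathered rewards (not the trainer's actual code):

import torch

# Illustrative values: 2 processes, 4 prompts per process, 2 generations each.
num_generations = 2
prompts_per_process = 4
process_index = 0  # stands in for self.accelerator.process_index

# Gathered rewards from all processes, one entry per rollout.
rewards = torch.randn(2 * prompts_per_process * num_generations)

# Group-level std, repeated so there is one entry per rollout.
std_rewards = rewards.view(-1, num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)

# Each process keeps only the slice of rollouts it generated, since the
# replay-buffer lookup/update must match the local batch.
process_slice = slice(
    process_index * prompts_per_process * num_generations,
    (process_index + 1) * prompts_per_process * num_generations,
)
std_rewards = std_rewards[process_slice]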

Collaborator

This shows how std_rewards is used for updating the replay buffer

if groups_with_variance.any():
    # Calculate replay buffer scores for groups with variance
    replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
    # Add all groups to replay buffer at once (batch operation)
    self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)
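For intuition only, here is a minimal sketch of a score-keyed buffer exposing the same add(scores, items) shape as the call above; this is not TRL's implementation, just one way to keep the highest-scoring groups:

import heapq
from typing import Any

class MinimalReplayBuffer:
    """Illustrative only: retain the top-`capacity` groups by score."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._heap: list[tuple[float, int, Any]] = []  # min-heap keyed on score
        self._tiebreak = 0  # unique counter so equal scores never compare items

    def add(self, scores: list[float], items: list[Any]) -> None:
        for score, item in zip(scores, items):
            self._tiebreak += 1
            entry = (score, self._tiebreak, item)
            if len(self._heap) < self.capacity:
                heapq.heappush(self._heap, entry)
            elif score > self._heap[0][0]:
                # Evict the lowest-scoring group to make room.
                heapq.heapreplace(self._heap, entry)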

Collaborator

This entire removed block of code should remain in grpo_with_replay_buffer_trainer.py; we always need the group-level std to determine what goes into the replay buffer.

Member Author

@albertvillanova albertvillanova Oct 30, 2025

Thanks for the explanation about std_rewards = std_rewards[process_slice], @pramodith. 🤗

Regarding the block of lines just before: I removed the condition if std_rewards is None because I think it is always False. Just a few lines above, we have this code: https://github.com/albertvillanova/trl/blob/d0324230761e7860646f4d15d7ff8beb433103ac/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py#L250-L260

        if self.scale_rewards in ["group", "none"]:
            # If self.scale_rewards = "none", we'll still log group level std
            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
        elif self.scale_rewards == "batch":
            # Compute global std
            std_rewards = rewards.std().expand_as(rewards)
        else:
            raise ValueError(
                f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
            )

Therefore, std_rewards can't be None if I understand correctly. It should always be a torch.Tensor.

Collaborator

Hmmm, yeah, you're right, but there's a bug here: the replay buffer requires the std to be computed over the group. I'll fix that in a subsequent PR; getting rid of that block in this PR is fine.
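In other words, even when scale_rewards == "batch" scales advantages by the global std, the buffer scoring should still be driven by the per-group std. A small sketch of the distinction, reusing the shapes from the snippet above:

import torch

num_generations = 4
rewards = torch.randn(3 * num_generations)  # 3 groups of 4 rollouts

# Global std (what scale_rewards == "batch" uses for advantage scaling).
batch_std = rewards.std().expand_as(rewards)

# Per-group std (what the replay-buffer scoring needs).
group_std = rewards.view(-1, num_generations).std(dim=1)
group_std = group_std.repeat_interleave(num_generations, dim=0)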

Member Author

I re-added just that line: 4266550
😉

@albertvillanova
Member Author

The failing test will be fixed after the merge of:

Member

@qgallouedec qgallouedec left a comment

ok lgtm!

@qgallouedec
Member

Mmh wait, we have:

tests/experimental/test_grpo_with_replay_buffer_trainer.py::TestUpdateWithReplayBuffer::test_update_with_inputs_different_seq_len - TypeError: GRPOWithReplayBufferTrainer.update_with_replay_buffer() got an unexpected keyword argument 'prompt_inputs'. Did you mean 'prompt_ids'?

@albertvillanova
Member Author

albertvillanova commented Oct 31, 2025

That is fixed in my previous PR. See my comment above, @qgallouedec.

@albertvillanova
Member Author

@qgallouedec, there were 2 different bugs:

@qgallouedec
Member

oh my bad, thanks for fixing it

@qgallouedec
Member

qgallouedec commented Oct 31, 2025

Let's merge #4366 to keep things clean

EDIT: Let's merge #4366 first to keep things clean

@albertvillanova albertvillanova merged commit 5cefb39 into huggingface:main Oct 31, 2025
3 checks passed