3 changes: 1 addition & 2 deletions tests/experimental/test_grpo_with_replay_buffer_trainer.py
@@ -16,7 +16,6 @@
import torch
from datasets import load_dataset

from trl import GRPOTrainer
from trl.experimental.grpo_with_replay_buffer import (
GRPOWithReplayBufferConfig,
GRPOWithReplayBufferTrainer,
@@ -271,7 +270,7 @@ def custom_reward_func(completions, **kwargs):
replay_buffer_size=8,
report_to="none",
)
trainer = GRPOTrainer(
trainer = GRPOWithReplayBufferTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs=[custom_reward_func],
args=training_args,
trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py
@@ -18,7 +18,11 @@
import torch
from accelerate.utils import gather_object

from trl.data_utils import is_conversational
from trl.data_utils import (
apply_chat_template,
is_conversational,
prepare_multimodal_messages,
)
from trl.trainer.grpo_trainer import GRPOTrainer
from trl.trainer.utils import nanmax, nanmin, nanstd, pad

@@ -80,19 +84,36 @@ def _generate_and_score_completions(
if images is not None and all(img_list == [] for img_list in images):
    images = None

(
    prompt_ids,
    completion_ids,
    prompt_mask,
    completion_mask,
    num_items_in_batch,
    sampling_per_token_logps,
    forward_kwargs,
) = self._generate(prompts, images)
# If the prompts are conversational and the inputs contain images, we need to convert the prompts from
# [{"role": "user", "content": "What color is the sky?"}] to
# [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
if images is not None:
    prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]

prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = (
    self._generate(prompts)
)

# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
if sampling_per_token_logps_list is not None:
    sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
    sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
else:
    sampling_per_token_logps = None

# Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
# to re-tokenize completions if the reward is computed from tokens.
completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
    eos_and_pad = [self.eos_token_id, self.pad_token_id]
    is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
    completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)
@@ -103,6 +124,25 @@

num_images = [len(img_list) for img_list in images] if images is not None else None

# Get forward_kwargs for models with multimodal inputs
if images is not None:
    prompts_text = [
        apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
        for prompt in prompts
    ]
    prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
    prompt_inputs = super()._prepare_inputs(prompt_inputs)
    forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
else:
    forward_kwargs = {}

# If token_type_ids are used, extend them with zeros for the completion part
if "token_type_ids" in forward_kwargs:
token_type_ids = forward_kwargs["token_type_ids"]
forward_kwargs["token_type_ids"] = torch.cat(
[token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)

with torch.no_grad():
    # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
    # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
@@ -171,6 +211,15 @@ def _generate_and_score_completions(
else:
    completions = completions_text

# Merge extra_fields from rollout_func into inputs for reward functions
if extra_fields:
    for i, inp in enumerate(inputs):
        for key, values in extra_fields.items():
            if isinstance(values, list) and i < len(values):
                inp[key] = values[i]
            elif not isinstance(values, list):
                inp[key] = values

# Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
# important because rewards will be normalized per group, and completions are distributed. We will later slice
# rewards_per_func to extract each process's subset.
@@ -185,7 +234,7 @@
# Normalize the rewards to compute the advantages
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
advantages = rewards - mean_grouped_rewards
std_rewards = None

if self.scale_rewards in ["group", "none"]:
    # If self.scale_rewards = "none", we'll still log group level std
    std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
@@ -209,10 +258,7 @@
)
all_process_advantages = advantages.clone() # keep the aggregated advantages for logging
advantages = advantages[process_slice]
if std_rewards is None:
    std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
    std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
std_rewards = std_rewards[process_slice] if std_rewards is not None else None

Member Author commented:

This last line is not in GRPOTrainer: is it necessary? If so, shouldn't we implement it in GRPOTrainer as well?

std_rewards = std_rewards[process_slice]

Member commented:

@pramodith if you've some time to check

Collaborator commented:

Will take a look a bit later this evening!

@pramodith (Collaborator) commented on Oct 30, 2025:

We need the sliced std_rewards in this trainer because we decide if a specific example should be added to the replay buffer or sampled from the buffer based on std_reward of that specific rollout. Since each gpu sees a unique batch of data we need to only perform the buffer lookup and update based on the slice residing in the gpu.

GRPOTrainer doesn't need the std after advantage scores are computed so it can be discarded in GRPOTrainer.
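
As a rough sketch of the per-process slicing pattern described above (the shapes, the rank variable, and the slice arithmetic are illustrative assumptions, not code taken from the trainer):

import torch

# Toy setup: 2 processes, each owning 2 prompts with 4 generations per prompt.
num_generations = 4
rewards = torch.randn(2 * 2 * num_generations)  # rewards gathered across all processes

# Group-level statistics are computed over the full gathered batch...
std_rewards = rewards.view(-1, num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)

# ...but each GPU only keeps the rows it generated, because replay-buffer
# add/sample decisions are made per process on its local slice.
process_index, local_rows = 0, 2 * num_generations  # hypothetical rank and per-process row count
process_slice = slice(process_index * local_rows, (process_index + 1) * local_rows)
std_rewards = std_rewards[process_slice]  # shape: (8,)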

Collaborator commented:

This shows how std_rewards is used for updating the replay buffer:

if groups_with_variance.any():
    # Calculate replay buffer scores for groups with variance
    replay_buffer_scores = (group_advantages.abs() * group_std_rewards).sum(dim=-1)[groups_with_variance]
    # Add all groups to replay buffer at once (batch operation)
    self.replay_buffer.add(replay_buffer_scores.tolist(), buffered_outputs)

Collaborator commented:

This entire removed block of code should remain in grpo_with_replay_buffer_trainer.py; we always need the group-level std to determine what goes into the replay buffer.

@albertvillanova (Member, Author) commented on Oct 30, 2025:

Thanks for the explanation about std_rewards = std_rewards[process_slice], @pramodith. 🤗

Regarding the block of lines just before: I removed the condition if std_rewards is None because I think it is always False. Just a few lines above, we have this code: https://github.com/albertvillanova/trl/blob/d0324230761e7860646f4d15d7ff8beb433103ac/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py#L250-L260

        if self.scale_rewards in ["group", "none"]:
            # If self.scale_rewards = "none", we'll still log group level std
            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
        elif self.scale_rewards == "batch":
            # Compute global std
            std_rewards = rewards.std().expand_as(rewards)
        else:
            raise ValueError(
                f"Invalid value for scale_rewards: {self.scale_rewards}. Must be one of 'batch', 'group', or 'none'."
            )

Therefore, std_rewards can't be None if I understand correctly. It should always be a torch.Tensor.

Collaborator commented:

Hmmm, yeah, you're right, but there's a bug here: the replay buffer requires the std to be computed over the group. I'll fix that in a subsequent PR; getting rid of that block in this PR is fine.
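
To illustrate the distinction with a toy example (the variable names and numbers below are made up, not taken from the trainer):

import torch

num_generations = 4
rewards = torch.tensor([1.0, 2.0, 3.0, 4.0,         # group 0: has variance
                        10.0, 10.0, 10.0, 10.0])    # group 1: no variance

# Group-level std: one value per group, broadcast over that group's completions.
# This is what the replay-buffer scoring relies on, since it decides per group.
group_std = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
# -> tensor([1.2910, 1.2910, 1.2910, 1.2910, 0.0000, 0.0000, 0.0000, 0.0000])

# Batch-level std (what scale_rewards="batch" produces): a single global value
# expanded over every completion, which no longer reflects per-group variance.
batch_std = rewards.std().expand_as(rewards)
# -> the same value (~4.10) repeated for all 8 completions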

Member Author commented:

I re-added just that line: 4266550
😉

std_rewards = std_rewards[process_slice]

# Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
for i, reward_func_name in enumerate(self.reward_func_names):
@@ -306,7 +352,7 @@ def _generate_and_score_completions(
if "token_type_ids" in forward_kwargs:
output["token_type_ids"] = forward_kwargs["token_type_ids"]
if images is not None:
output["images"] = images
output["num_images"] = num_images
return output

def slice_group_data(