Commit 5cefb39

Fix GRPO with replay buffer by inserting images in the prompt (#4391)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
1 parent 50b96e2 commit 5cefb39
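In brief: the replay-buffer GRPO trainer previously handed the images to self._generate alongside the prompts; this commit instead inserts the images into the conversational prompts before generation, using prepare_multimodal_messages from trl.data_utils. A minimal sketch of that conversion, following the call pattern and the before/after structure shown in the diff below (the image object here is a stand-in):

from PIL import Image

from trl.data_utils import prepare_multimodal_messages

image = Image.new("RGB", (8, 8))  # stand-in for the dataset's actual image
prompt = [{"role": "user", "content": "What color is the sky?"}]

# The fix rewrites the prompt so the user turn carries the image next to the text, i.e. roughly:
# [{"role": "user", "content": [{"type": "image", "image": <Image>},
#                               {"type": "text", "text": "What color is the sky?"}]}]
prompt = prepare_multimodal_messages(prompt, [image])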

2 files changed: +66 −21 lines

tests/experimental/test_grpo_with_replay_buffer_trainer.py

Lines changed: 1 addition & 2 deletions
@@ -16,7 +16,6 @@
 import torch
 from datasets import load_dataset

-from trl import GRPOTrainer
 from trl.experimental.grpo_with_replay_buffer import (
     GRPOWithReplayBufferConfig,
     GRPOWithReplayBufferTrainer,
@@ -271,7 +270,7 @@ def custom_reward_func(completions, **kwargs):
             replay_buffer_size=8,
             report_to="none",
         )
-        trainer = GRPOTrainer(
+        trainer = GRPOWithReplayBufferTrainer(
             model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
             reward_funcs=[custom_reward_func],
             args=training_args,
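For context, a rough sketch of how a test like this wires the trainer together. The model name, reward-function name, replay_buffer_size=8, and report_to="none" come from the diff; the dataset, the reward body, and the remaining config values are purely illustrative:

from datasets import load_dataset

from trl.experimental.grpo_with_replay_buffer import (
    GRPOWithReplayBufferConfig,
    GRPOWithReplayBufferTrainer,
)

def custom_reward_func(completions, **kwargs):
    # Toy reward: longer completions score higher.
    return [float(len(completion)) for completion in completions]

dataset = load_dataset("trl-lib/tldr", split="train[:8]")  # illustrative dataset

training_args = GRPOWithReplayBufferConfig(
    output_dir="grpo-replay-buffer-test",
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)
trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()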

trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py

Lines changed: 65 additions & 19 deletions
@@ -18,7 +18,11 @@
 import torch
 from accelerate.utils import gather_object

-from trl.data_utils import is_conversational
+from trl.data_utils import (
+    apply_chat_template,
+    is_conversational,
+    prepare_multimodal_messages,
+)
 from trl.trainer.grpo_trainer import GRPOTrainer
 from trl.trainer.utils import nanmax, nanmin, nanstd, pad

@@ -80,19 +84,36 @@ def _generate_and_score_completions(
         if images is not None and all(img_list == [] for img_list in images):
             images = None

-        (
-            prompt_ids,
-            completion_ids,
-            prompt_mask,
-            completion_mask,
-            num_items_in_batch,
-            sampling_per_token_logps,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        # If the prompts are conversational and the inputs contain images, we need to convert the prompts from
+        # [{"role": "user", "content": "What color is the sky?"}] to
+        # [{"role": "user", "content": [{"type": "image", "image": <Image>}, {"type": "text", "text": "What color is the sky?"}]}]
+        if images is not None:
+            prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)]
+
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = (
+            self._generate(prompts)
+        )
+
+        # Convert lists of token IDs to padded tensors
+        prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
+        prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
+        prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+        prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+        completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]
+        completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
+        completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
+        completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
+        if sampling_per_token_logps_list is not None:
+            sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list]
+            sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right")
+        else:
+            sampling_per_token_logps = None

-        # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need
-        # to re-tokenize completions if the reward is computed from tokens.
-        completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())]
+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            eos_and_pad = [self.eos_token_id, self.pad_token_id]
+            is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+            completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()

         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
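The new code pads the ragged token-id lists itself instead of receiving ready-made tensors from self._generate: prompts are padded on the left, completions on the right, so the two halves meet at the prompt/completion boundary when concatenated. A small standalone illustration of that pattern with trl's pad helper (the token ids and pad_token_id=0 are arbitrary):

import torch

from trl.trainer.utils import pad

pad_token_id = 0  # arbitrary for the illustration

prompt_ids_list = [[5, 6, 7], [8, 9]]
completion_ids_list = [[10, 11], [12, 13, 14]]

# Prompts are padded on the left so each prompt ends where its completion begins ...
prompt_ids = pad([torch.tensor(ids) for ids in prompt_ids_list], padding_value=pad_token_id, padding_side="left")
# ... and completions are padded on the right.
completion_ids = pad([torch.tensor(ids) for ids in completion_ids_list], padding_value=pad_token_id, padding_side="right")

prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
# tensor([[ 5,  6,  7, 10, 11,  0],
#         [ 0,  8,  9, 12, 13, 14]])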
@@ -103,6 +124,25 @@ def _generate_and_score_completions(

         num_images = [len(img_list) for img_list in images] if images is not None else None

+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"]
+                for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
+        # If token_type_ids are used, extend them with zeros for the completion part
+        if "token_type_ids" in forward_kwargs:
+            token_type_ids = forward_kwargs["token_type_ids"]
+            forward_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
+            )
+
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
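A compact illustration of the forward_kwargs handling above: keep every processor output except input_ids and attention_mask, then pad token_type_ids with zeros so they also cover the completion tokens. The tensor shapes and the pixel_values key are made up for the example:

import torch

# Pretend processor output for a batch of 2 prompts (shapes are made up).
prompt_inputs = {
    "input_ids": torch.ones(2, 4, dtype=torch.long),
    "attention_mask": torch.ones(2, 4, dtype=torch.long),
    "pixel_values": torch.randn(2, 3, 8, 8),
    "token_type_ids": torch.zeros(2, 4, dtype=torch.long),
}
completion_ids = torch.ones(2, 5, dtype=torch.long)  # 5 completion tokens per sample

# Everything except input_ids/attention_mask is forwarded to the model as extra kwargs.
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}

# token_type_ids only cover the prompt, so extend them with zeros for the completion part.
if "token_type_ids" in forward_kwargs:
    token_type_ids = forward_kwargs["token_type_ids"]
    forward_kwargs["token_type_ids"] = torch.cat(
        [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
    )

print(forward_kwargs["token_type_ids"].shape)  # torch.Size([2, 9])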
@@ -171,6 +211,15 @@ def _generate_and_score_completions(
         else:
             completions = completions_text

+        # Merge extra_fields from rollout_func into inputs for reward functions
+        if extra_fields:
+            for i, inp in enumerate(inputs):
+                for key, values in extra_fields.items():
+                    if isinstance(values, list) and i < len(values):
+                        inp[key] = values[i]
+                    elif not isinstance(values, list):
+                        inp[key] = values
+
         # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is
         # important because rewards will be normalized per group, and completions are distributed. We will later slice
         # rewards_per_func to extract each process's subset.
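The merge rule above is simple: per-sample lists are indexed by position, anything else is broadcast to every sample. A self-contained sketch of that behavior (the field names are invented for the example):

inputs = [{"prompt": "p0"}, {"prompt": "p1"}]
extra_fields = {
    "rollout_score": [0.2, 0.9],   # per-sample list: indexed by position
    "rollout_source": "replay",    # scalar: broadcast to every sample
}

for i, inp in enumerate(inputs):
    for key, values in extra_fields.items():
        if isinstance(values, list) and i < len(values):
            inp[key] = values[i]
        elif not isinstance(values, list):
            inp[key] = values

print(inputs)
# [{'prompt': 'p0', 'rollout_score': 0.2, 'rollout_source': 'replay'},
#  {'prompt': 'p1', 'rollout_score': 0.9, 'rollout_source': 'replay'}]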
@@ -185,7 +234,7 @@ def _generate_and_score_completions(
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         advantages = rewards - mean_grouped_rewards
-        std_rewards = None
+
         if self.scale_rewards in ["group", "none"]:
             # If self.scale_rewards = "none", we'll still log group level std
             std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
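For reference, a tiny numeric sketch of the group statistics involved here, not the trainer's exact code: rewards are grouped per prompt (num_generations completions each), the group mean is subtracted to form advantages, and the group std is kept for scaling and logging:

import torch

num_generations = 2
rewards = torch.tensor([1.0, 3.0, 0.0, 4.0])  # 2 prompts x 2 completions

mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1)
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
advantages = rewards - mean_grouped_rewards  # tensor([-1.,  1., -2.,  2.])

std_rewards = rewards.view(-1, num_generations).std(dim=1)
std_rewards = std_rewards.repeat_interleave(num_generations, dim=0)  # tensor([1.4142, 1.4142, 2.8284, 2.8284])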
@@ -209,10 +258,7 @@ def _generate_and_score_completions(
         )
         all_process_advantages = advantages.clone()  # keep the aggregated advantages for logging
         advantages = advantages[process_slice]
-        if std_rewards is None:
-            std_rewards = rewards.view(-1, self.num_generations).std(dim=1)
-            std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0)
-        std_rewards = std_rewards[process_slice] if std_rewards is not None else None
+        std_rewards = std_rewards[process_slice]

         # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values)
         for i, reward_func_name in enumerate(self.reward_func_names):
@@ -306,7 +352,7 @@ def _generate_and_score_completions(
         if "token_type_ids" in forward_kwargs:
             output["token_type_ids"] = forward_kwargs["token_type_ids"]
         if images is not None:
-            output["images"] = images
+            output["num_images"] = num_images
         return output

     def slice_group_data(
