From 6ed80fbd561cae16434d9fc5dd95b576ef512256 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 29 Oct 2025 17:30:53 +0100 Subject: [PATCH 1/5] Fix trainer in GRPO with replay buffer test --- tests/experimental/test_grpo_with_replay_buffer_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/experimental/test_grpo_with_replay_buffer_trainer.py b/tests/experimental/test_grpo_with_replay_buffer_trainer.py index cad66f8034c..c73f41a0d4a 100644 --- a/tests/experimental/test_grpo_with_replay_buffer_trainer.py +++ b/tests/experimental/test_grpo_with_replay_buffer_trainer.py @@ -2,7 +2,6 @@ import torch from datasets import load_dataset -from trl import GRPOTrainer from trl.experimental.grpo_with_replay_buffer import ( GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer, @@ -257,7 +256,7 @@ def custom_reward_func(completions, **kwargs): replay_buffer_size=8, report_to="none", ) - trainer = GRPOTrainer( + trainer = GRPOWithReplayBufferTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=[custom_reward_func], args=training_args, From d0324230761e7860646f4d15d7ff8beb433103ac Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 30 Oct 2025 13:25:21 +0100 Subject: [PATCH 2/5] Insert images in the prompt for GRPOWithReplayBufferTrainer --- .../grpo_with_replay_buffer_trainer.py | 93 +++++++++++++++---- 1 file changed, 75 insertions(+), 18 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 10d5948cae7..867a66733b2 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -18,7 +18,11 @@ import torch from accelerate.utils import gather_object -from trl.data_utils import is_conversational +from trl.data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, +) from trl.trainer.grpo_trainer import GRPOTrainer from trl.trainer.utils import nanmax, nanmin, nanstd, pad @@ -80,19 +84,48 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) = self._generate(prompts, images) + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] + + # ( + # prompt_ids, + # completion_ids, + # prompt_mask, + # completion_mask, + # num_items_in_batch, + # sampling_per_token_logps, + # forward_kwargs, + # ) = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( + self._generate(prompts) + ) + + # # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # # to re-tokenize completions if the reward is computed from tokens. 
+ # completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) @@ -103,6 +136,25 @@ def _generate_and_score_completions( num_images = [len(img_list) for img_list in images] if images is not None else None + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] + for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -171,6 +223,15 @@ def _generate_and_score_completions( else: completions = completions_text + # Merge extra_fields from rollout_func into inputs for reward functions + if extra_fields: + for i, inp in enumerate(inputs): + for key, values in extra_fields.items(): + if isinstance(values, list) and i < len(values): + inp[key] = values[i] + elif not isinstance(values, list): + inp[key] = values + # 
Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. @@ -185,7 +246,7 @@ def _generate_and_score_completions( # Normalize the rewards to compute the advantages mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) advantages = rewards - mean_grouped_rewards - std_rewards = None + if self.scale_rewards in ["group", "none"]: # If self.scale_rewards = "none", we'll still log group level std std_rewards = rewards.view(-1, self.num_generations).std(dim=1) @@ -209,10 +270,6 @@ def _generate_and_score_completions( ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - if std_rewards is None: - std_rewards = rewards.view(-1, self.num_generations).std(dim=1) - std_rewards = std_rewards.repeat_interleave(self.num_generations, dim=0) - std_rewards = std_rewards[process_slice] if std_rewards is not None else None # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): From 4266550534ac69feb0729ea0c0f3bb0ba4666ab2 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Thu, 30 Oct 2025 19:50:51 +0100 Subject: [PATCH 3/5] Slice std_rewards --- .../grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 867a66733b2..ae34ae4badc 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -270,6 +270,7 @@ def _generate_and_score_completions( ) all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] + std_rewards = std_rewards[process_slice] # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): From d71921b701b335b4ea7ecbafaa34293c75ac6c5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 31 Oct 2025 16:02:03 +0000 Subject: [PATCH 4/5] remove commented code + fix num_images output --- .../grpo_with_replay_buffer_trainer.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index ae34ae4badc..47037a104a2 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -90,22 +90,10 @@ def _generate_and_score_completions( if images is not None: prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] - # ( - # prompt_ids, - # completion_ids, - # prompt_mask, - # completion_mask, - # num_items_in_batch, - # sampling_per_token_logps, - # forward_kwargs, - # ) = self._generate(prompts, images) prompt_ids_list, completion_ids_list, 
num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( self._generate(prompts) ) - # # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # # to re-tokenize completions if the reward is computed from tokens. - # completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] @@ -364,7 +352,7 @@ def _generate_and_score_completions( if "token_type_ids" in forward_kwargs: output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: - output["images"] = images + output["num_images"] = num_images return output def slice_group_data( From bd2ca7f9e998e221e2531630212e3ee485664434 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 31 Oct 2025 16:15:01 +0000 Subject: [PATCH 5/5] fix keyword argument --- tests/experimental/test_grpo_with_replay_buffer_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/experimental/test_grpo_with_replay_buffer_trainer.py b/tests/experimental/test_grpo_with_replay_buffer_trainer.py index 0587c3589bf..ff2e0835406 100644 --- a/tests/experimental/test_grpo_with_replay_buffer_trainer.py +++ b/tests/experimental/test_grpo_with_replay_buffer_trainer.py @@ -139,7 +139,7 @@ def _make_inputs(self, group_advantages, with_pixels=False, with_logprobs=False) "prompt_mask": torch.ones(4, 2, dtype=torch.long), "completion_ids": torch.tensor([[9, 10], [11, 12], [13, 14], [15, 16]]), "completion_mask": torch.ones(4, 2, dtype=torch.long), - "prompt_inputs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {}, + "forward_kwargs": {"pixel_values": torch.randn(4, 3, 224, 224)} if with_pixels else {}, "old_per_token_logps": torch.randn(4, 2) if with_logprobs else None, } inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages) @@ -216,7 +216,7 @@ def test_update_with_inputs_different_seq_len(self): ] ), "completion_mask": torch.tensor([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.long), - "prompt_inputs": {}, + "forward_kwargs": {}, } inputs["group_std_rewards"] = group_advantages.std(dim=1).expand_as(group_advantages)
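
Note on the padding logic introduced in PATCH 2/5: the trainer now receives variable-length token-ID lists from self._generate and rebuilds padded tensors itself, left-padding prompts and right-padding completions before concatenating them for the logit computation. The following is a minimal, self-contained sketch of that pattern; the local pad helper is a simplified stand-in for trl.trainer.utils.pad, and the token-ID lists are made up for illustration.

import torch


def pad(tensors, padding_value=0, padding_side="right"):
    # Simplified stand-in for trl.trainer.utils.pad: stack 1-D tensors into a
    # single (batch, max_len) tensor, padding each row on the requested side.
    max_len = max(t.size(0) for t in tensors)
    rows = []
    for t in tensors:
        filler = t.new_full((max_len - t.size(0),), padding_value)
        rows.append(torch.cat([filler, t]) if padding_side == "left" else torch.cat([t, filler]))
    return torch.stack(rows)


# Hypothetical generation output: variable-length token-ID lists per sample.
prompt_ids_list = [[101, 7592], [101, 2129, 2024, 2017]]
completion_ids_list = [[2088, 102], [2307, 1010, 4067, 2017, 102]]
pad_token_id = 0

# Prompts are left-padded so the last prompt token sits next to the first
# completion token; completions are right-padded, mirroring the patch.
prompt_ids = pad([torch.tensor(ids) for ids in prompt_ids_list], padding_value=pad_token_id, padding_side="left")
prompt_mask = pad([torch.ones(len(ids), dtype=torch.long) for ids in prompt_ids_list], padding_value=0, padding_side="left")
completion_ids = pad([torch.tensor(ids) for ids in completion_ids_list], padding_value=pad_token_id, padding_side="right")
completion_mask = pad([torch.ones(len(ids), dtype=torch.long) for ids in completion_ids_list], padding_value=0, padding_side="right")

# Concatenation used for the logit computation, as in the patch.
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
print(prompt_completion_ids.shape)  # torch.Size([2, 9])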
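
Note on the std_rewards fix in PATCH 3/5: the group standard deviation is now sliced per process exactly like the advantages, so both tensors stay aligned with the local batch. A rough sketch of that normalization, with made-up reward values and an assumed two-process layout:

import torch

num_generations = 4  # completions per prompt (group size)

# Hypothetical rewards gathered across all processes: 2 prompts x 4 generations.
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5, 2.0, 2.0, 1.0, 3.0])

# Group-wise mean/std, broadcast back to one value per completion.
mean_grouped_rewards = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_rewards = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = rewards - mean_grouped_rewards

# Each process keeps only its own slice; std_rewards must be sliced the same
# way as advantages (the point of PATCH 3/5), otherwise their shapes diverge.
process_index, local_batch_size = 1, 4  # assumed 2-process run, 4 samples each
process_slice = slice(process_index * local_batch_size, (process_index + 1) * local_batch_size)
advantages = advantages[process_slice]
std_rewards = std_rewards[process_slice]
print(advantages.shape, std_rewards.shape)  # torch.Size([4]) torch.Size([4])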