From 65481f31c5f9e864f7da4675160d4c84a8819500 Mon Sep 17 00:00:00 2001
From: Venkata Satyanarayana Chivatam
Date: Mon, 13 Oct 2025 18:22:58 -0400
Subject: [PATCH] Fix distributed GRPO trajectory concatenation and reward handling

This commit addresses several issues in the distributed GRPO implementation:

1. **Fix trajectory concatenation**: Update generate_trajectory_batched to
   properly handle the 'answers' field, which is a list of strings, not a
   tensor.

2. **Update reward function signature**: Modify the batched_rewards call to
   include a device parameter and handle multiple reward functions properly.

3. **Add missing fields to GRPOTrajectory**: Add rewards and successes fields
   to the GRPOTrajectory NamedTuple and update the docstring.

4. **Fix GRPOStats concatenation**: Update the concatenation logic in the
   training loop to properly handle the metadata field.

5. **Fix reward indexing**: Correct the answer indexing in the batched_rewards
   function from answers[b][g] to answers[b].

These changes ensure that distributed GRPO training works correctly with
proper trajectory handling and reward computation.
---
 recipes/dev/grpo_full_finetune_distributed.py | 45 ++++++++++++++++---
 torchtune/dev/rl/rewards.py                   |  4 +-
 torchtune/dev/rl/types.py                     |  4 ++
 3 files changed, 46 insertions(+), 7 deletions(-)

diff --git a/recipes/dev/grpo_full_finetune_distributed.py b/recipes/dev/grpo_full_finetune_distributed.py
index 145f5cd661..591ddd5c36 100644
--- a/recipes/dev/grpo_full_finetune_distributed.py
+++ b/recipes/dev/grpo_full_finetune_distributed.py
@@ -646,9 +646,15 @@
         # Do some reward modelingggggggg
         # responses :: [B x G, L]
         responses = responses.reshape(batch_size, grpo_size, -1)  # [B, G, L]
-        rewards, successes = batched_rewards(self._tokenizer, responses, answers)
-        rewards = rewards.to(self._device)  # [B, G]
-        successes = successes.to(self._device)  # [B, G]
+        rewards, successes, metadata = batched_rewards(
+            self._tokenizer, responses, answers, device=self._device
+        )
+        rewards = rewards.to(self._device)  # [B, G, num_reward_funcs]
+        successes = successes.to(self._device)  # [B, G, num_reward_funcs]
+
+        # Aggregate rewards and successes across reward functions
+        rewards = rewards.sum(dim=-1)  # [B, G]
+        successes = successes.sum(dim=-1)  # [B, G]
 
         advantages = (rewards - rewards.mean(1, keepdim=True)) / (
             rewards.std(1, keepdim=True) + 1e-4
@@ -672,6 +678,7 @@
             position_ids=position_ids,
             response_padding_masks=response_padding_masks,
             seq_lens=training.get_unmasked_sequence_lengths(response_padding_masks),
+            answers=answers,
         )
 
     def generate_trajectory_batched(
@@ -703,7 +710,22 @@
                     self.generate_trajectory(batch_input_ids, batch_answers)
                 )
                 torch.cuda.empty_cache()
-        return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
+
+        # Concatenate all trajectory fields except answers (which is a list of strings)
+        concatenated_fields = {}
+        for field_name in trajectories[0]._fields:
+            if field_name == "answers":
+                # Concatenate lists of answers
+                concatenated_fields[field_name] = []
+                for traj in trajectories:
+                    concatenated_fields[field_name].extend(traj.answers)
+            else:
+                # Concatenate tensors
+                concatenated_fields[field_name] = torch.cat(
+                    [getattr(traj, field_name) for traj in trajectories]
+                )
+
+        return GRPOTrajectory(**concatenated_fields)
 
     def grpo_step(
         self,
@@ -771,6 +793,7 @@
             ratios,
             clipfrac,
             approx_policy_kls,
+            None,  # metadata
         )
 
     def train(self) -> None:
@@ -853,9 +876,21 @@ def train(self) -> None:
                     if grad_norm is not None:
                         extra_metrics["grad_norm"] = grad_norm
 
+                    # Concatenate GRPOStats fields properly
+                    concatenated_stats = {}
+                    for field_name in grpo_stats[0]._fields:
+                        if field_name == "metadata":
+                            # Handle metadata separately (it's None, so just use None)
+                            concatenated_stats[field_name] = None
+                        else:
+                            # Stack tensors
+                            concatenated_stats[field_name] = torch.stack(
+                                [getattr(stat, field_name) for stat in grpo_stats]
+                            )
+
                     self.log_metrics(
                         trajectory,
-                        GRPOStats(*map(torch.stack, zip(*grpo_stats))),
+                        GRPOStats(**concatenated_stats),
                         **extra_metrics,
                     )
 
diff --git a/torchtune/dev/rl/rewards.py b/torchtune/dev/rl/rewards.py
index 95c45ee9b0..42dc40a74f 100644
--- a/torchtune/dev/rl/rewards.py
+++ b/torchtune/dev/rl/rewards.py
@@ -298,8 +298,8 @@ def batched_rewards(
 
     for b in range(batch_size):
         for g in range(grpo_size):
-
-            answer = answers[b][g]
+            # print(answers)
+            answer = answers[b]
 
             text_completion = tokenizer.decode(completions[b, g].tolist())
 
diff --git a/torchtune/dev/rl/types.py b/torchtune/dev/rl/types.py
index b0aae365ff..3120d0b826 100644
--- a/torchtune/dev/rl/types.py
+++ b/torchtune/dev/rl/types.py
@@ -19,6 +19,8 @@ class GRPOTrajectory(NamedTuple):
         logprobs (torch.Tensor): Log probabilities of the generated responses with shape [B x G, L].
         ref_logprobs (torch.Tensor): Log probabilities of the generated responses using the reference policy with shape [B x G, L].
         advantages (torch.Tensor): Advantage estimates for the generated responses with shape [B x G].
+        rewards (torch.Tensor): Reward values for the generated responses with shape [B x G].
+        successes (torch.Tensor): Success indicators for the generated responses with shape [B x G].
         masks (torch.Tensor): Attention masks for input ids-generated responses pairs with shape [B x G, P+L, P+L].
         position_ids (torch.Tensor): Position IDs for input ids-generated responses pairs with shape [B x G, P+L].
         response_padding_masks (torch.Tensor): Padding masks for the truncated and padded generated responses with shape [B x G, L].
@@ -30,6 +32,8 @@
     logprobs: torch.Tensor = None  # [B x G, L]
    ref_logprobs: torch.Tensor = None  # [B x G, L]
    advantages: torch.Tensor = None  # [B x G]
+    rewards: torch.Tensor = None
+    successes: torch.Tensor = None
     masks: torch.Tensor = None  # [B x G, P+L, P+L]
     position_ids: torch.Tensor = None  # [B x G, P+L]
     response_padding_masks: torch.Tensor = None  # [B x G, L]
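
Illustrative sketch (not part of the diff above): the snippet below shows the
concatenation strategy that generate_trajectory_batched now follows, written
against a hypothetical reduced NamedTuple called TinyTrajectory rather than the
real GRPOTrajectory. Tensor fields are concatenated along the batch dimension
with torch.cat, while list-valued fields such as answers are flattened into a
single Python list.

# Sketch only: TinyTrajectory is a hypothetical stand-in for GRPOTrajectory.
from typing import List, NamedTuple, Optional

import torch


class TinyTrajectory(NamedTuple):
    query_responses: Optional[torch.Tensor] = None  # [B x G, P+L]
    advantages: Optional[torch.Tensor] = None  # [B x G]
    answers: Optional[List[str]] = None  # one answer string per prompt


def concat_trajectories(trajectories: List[TinyTrajectory]) -> TinyTrajectory:
    """Merge per-microbatch trajectories: torch.cat tensors, extend list fields."""
    merged = {}
    for name in TinyTrajectory._fields:
        values = [getattr(t, name) for t in trajectories]
        if isinstance(values[0], torch.Tensor):
            merged[name] = torch.cat(values)  # concatenate along dim 0
        else:
            merged[name] = [item for v in values for item in v]  # flatten lists
    return TinyTrajectory(**merged)


a = TinyTrajectory(torch.zeros(4, 8), torch.zeros(4), ["12", "7"])
b = TinyTrajectory(torch.ones(4, 8), torch.ones(4), ["3", "42"])
out = concat_trajectories([a, b])
assert out.query_responses.shape == (8, 8)
assert out.answers == ["12", "7", "3", "42"]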