45 changes: 40 additions & 5 deletions recipes/dev/grpo_full_finetune_distributed.py
@@ -646,9 +646,15 @@ def generate_trajectory(
         # Do some reward modeling
         # responses :: [B x G, L]
         responses = responses.reshape(batch_size, grpo_size, -1)  # [B, G, L]
-        rewards, successes = batched_rewards(self._tokenizer, responses, answers)
-        rewards = rewards.to(self._device)  # [B, G]
-        successes = successes.to(self._device)  # [B, G]
+        rewards, successes, metadata = batched_rewards(
+            self._tokenizer, responses, answers, device=self._device
+        )
+        rewards = rewards.to(self._device)  # [B, G, num_reward_funcs]
+        successes = successes.to(self._device)  # [B, G, num_reward_funcs]
+
+        # Aggregate rewards and successes across reward functions
+        rewards = rewards.sum(dim=-1)  # [B, G]
+        successes = successes.sum(dim=-1)  # [B, G]
 
         advantages = (rewards - rewards.mean(1, keepdim=True)) / (
             rewards.std(1, keepdim=True) + 1e-4
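For readers following the shape changes: a minimal, self-contained sketch of what this hunk does. The sizes are illustrative, and `batched_rewards` returning a trailing per-function axis is the assumption this change introduces:

```python
import torch

B, G, F = 2, 4, 3  # illustrative: batch size, group size, number of reward functions
rewards = torch.rand(B, G, F)        # one reward per completion per reward function
successes = (rewards > 0.5).float()  # same shape, per-function success indicators

# Collapse the per-function axis so the GRPO math below sees one scalar per completion
rewards = rewards.sum(dim=-1)      # [B, G]
successes = successes.sum(dim=-1)  # [B, G]

# Group-relative advantages: normalize each prompt's G completions against each other
advantages = (rewards - rewards.mean(1, keepdim=True)) / (
    rewards.std(1, keepdim=True) + 1e-4
)
print(advantages.shape)  # torch.Size([2, 4])
```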
@@ -672,6 +678,7 @@ def generate_trajectory(
             position_ids=position_ids,
             response_padding_masks=response_padding_masks,
             seq_lens=training.get_unmasked_sequence_lengths(response_padding_masks),
+            answers=answers,
         )

@@ -703,7 +710,22 @@ def generate_trajectory_batched(
                 self.generate_trajectory(batch_input_ids, batch_answers)
             )
             torch.cuda.empty_cache()
-        return GRPOTrajectory(*map(torch.cat, zip(*trajectories)))
+
+        # Concatenate all trajectory fields except answers (which is a list of strings)
+        concatenated_fields = {}
+        for field_name in trajectories[0]._fields:
+            if field_name == "answers":
+                # Concatenate the per-batch answer lists
+                concatenated_fields[field_name] = []
+                for traj in trajectories:
+                    concatenated_fields[field_name].extend(traj.answers)
+            else:
+                # Concatenate tensors along the batch dimension
+                concatenated_fields[field_name] = torch.cat(
+                    [getattr(traj, field_name) for traj in trajectories]
+                )
+
+        return GRPOTrajectory(**concatenated_fields)
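The deleted one-liner relied on every field being a tensor; once `answers` (a list of strings) joins the NamedTuple, `torch.cat` would raise a `TypeError` on it. A toy sketch of the field-wise merge, using a stand-in NamedTuple rather than the real `GRPOTrajectory`:

```python
from typing import List, NamedTuple
import torch

class Traj(NamedTuple):  # stand-in for GRPOTrajectory, illustration only
    logprobs: torch.Tensor
    answers: List[str]

trajs = [
    Traj(torch.zeros(2, 5), ["a", "b"]),
    Traj(torch.ones(2, 5), ["c", "d"]),
]

fields = {}
for name in trajs[0]._fields:
    values = [getattr(t, name) for t in trajs]
    if name == "answers":
        # list fields are flattened by extending, not torch.cat
        fields[name] = [item for v in values for item in v]
    else:
        fields[name] = torch.cat(values)  # tensors concatenate along the batch dim

merged = Traj(**fields)
print(merged.logprobs.shape, merged.answers)  # torch.Size([4, 5]) ['a', 'b', 'c', 'd']
```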

@@ -771,6 +793,7 @@ def grpo_step(
             ratios,
             clipfrac,
             approx_policy_kls,
+            None,  # metadata
         )

@@ -853,9 +876,21 @@ def train(self) -> None:
                         if grad_norm is not None:
                             extra_metrics["grad_norm"] = grad_norm
 
+                        # Stack GRPOStats fields across PPO epochs, field by field
+                        concatenated_stats = {}
+                        for field_name in grpo_stats[0]._fields:
+                            if field_name == "metadata":
+                                # metadata is always None here, so pass it through
+                                concatenated_stats[field_name] = None
+                            else:
+                                # Stack tensors
+                                concatenated_stats[field_name] = torch.stack(
+                                    [getattr(stat, field_name) for stat in grpo_stats]
+                                )
+
                         self.log_metrics(
                             trajectory,
-                            GRPOStats(*map(torch.stack, zip(*grpo_stats))),
+                            GRPOStats(**concatenated_stats),
                             **extra_metrics,
                         )
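Same pattern as the trajectory merge above, but with `torch.stack` (the per-epoch stats share a shape) and a `None` passthrough for the non-tensor `metadata` slot. A toy version with a stand-in NamedTuple:

```python
from typing import NamedTuple, Optional
import torch

class Stats(NamedTuple):  # stand-in for GRPOStats, illustration only
    loss: torch.Tensor
    metadata: Optional[dict] = None

steps = [Stats(torch.tensor(0.5)), Stats(torch.tensor(0.3))]

stacked = {}
for name in steps[0]._fields:
    if name == "metadata":
        stacked[name] = None  # torch.stack would raise on a non-tensor field
    else:
        stacked[name] = torch.stack([getattr(s, name) for s in steps])

print(Stats(**stacked).loss)  # tensor([0.5000, 0.3000])
```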

4 changes: 2 additions & 2 deletions torchtune/dev/rl/rewards.py
@@ -298,8 +298,8 @@ def batched_rewards(
     for b in range(batch_size):
 
         for g in range(grpo_size):
-
-            answer = answers[b][g]
+            # answers holds one ground-truth answer per prompt
+            answer = answers[b]
 
             text_completion = tokenizer.decode(completions[b, g].tolist())
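The indexing fix reflects that `answers` carries one ground truth per prompt, while all G completions for that prompt are scored against it. Schematically, with toy data:

```python
batch_size, grpo_size = 2, 3
answers = ["4", "9"]  # len(answers) == batch_size: one ground-truth answer per prompt

for b in range(batch_size):
    for g in range(grpo_size):
        # every one of the G completions for prompt b shares the same reference;
        # answers[b][g] would instead index the g-th character of the string
        answer = answers[b]
```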

4 changes: 4 additions & 0 deletions torchtune/dev/rl/types.py
@@ -19,6 +19,8 @@ class GRPOTrajectory(NamedTuple):
         logprobs (torch.Tensor): Log probabilities of the generated responses with shape [B x G, L].
         ref_logprobs (torch.Tensor): Log probabilities of the generated responses using the reference policy with shape [B x G, L].
         advantages (torch.Tensor): Advantage estimates for the generated responses with shape [B x G].
+        rewards (torch.Tensor): Reward values for the generated responses with shape [B x G].
+        successes (torch.Tensor): Success indicators for the generated responses with shape [B x G].
         masks (torch.Tensor): Attention masks for input ids-generated responses pairs with shape [B x G, P+L, P+L].
         position_ids (torch.Tensor): Position IDs for input ids-generated responses pairs with shape [B x G, P+L].
         response_padding_masks (torch.Tensor): Padding masks for the truncated and padded generated responses with shape [B x G, L].
@@ -30,6 +32,8 @@ class GRPOTrajectory(NamedTuple):
     logprobs: torch.Tensor = None  # [B x G, L]
     ref_logprobs: torch.Tensor = None  # [B x G, L]
     advantages: torch.Tensor = None  # [B x G]
+    rewards: torch.Tensor = None  # [B x G]
+    successes: torch.Tensor = None  # [B x G]
     masks: torch.Tensor = None  # [B x G, P+L, P+L]
     position_ids: torch.Tensor = None  # [B x G, P+L]
     response_padding_masks: torch.Tensor = None  # [B x G, L]
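For a sense of how the new fields sit alongside the existing ones, a placeholder construction with made-up sizes. The module path is inferred from the file location above, and this assumes the remaining fields also default to None, as the visible ones do:

```python
import torch
from torchtune.dev.rl.types import GRPOTrajectory  # path inferred from the file above

B, G = 2, 4  # illustrative sizes
rewards = torch.rand(B * G)          # [B x G], matching the docstring
successes = (rewards > 0.5).float()  # [B x G]

# Fields default to None, so a trajectory can be populated partially
traj = GRPOTrajectory(rewards=rewards, successes=successes)
print(traj.rewards.shape, traj.advantages)  # torch.Size([8]) None
```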