Skip to content

Commit 97ca1a2

Browse files
authored
Fix bugs in CISPO conditions (#4499)
1 parent ffb3dd5 commit 97ca1a2

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1824,11 +1824,10 @@ def _compute_loss(self, model, inputs):
18241824

18251825
# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
18261826
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
1827-
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
1827+
if self.loss_type == "cispo":
18281828
clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
18291829
per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
1830-
1831-
else:
1830+
elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
18321831
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
18331832
# Two-sided clipping
18341833
if self.args.delta is not None:
@@ -1837,6 +1836,8 @@ def _compute_loss(self, model, inputs):
18371836
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
18381837
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
18391838
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
1839+
else:
1840+
raise ValueError(f"Unknown loss type: {self.loss_type}")
18401841

18411842
if entropy_mask is not None:
18421843
per_token_loss = per_token_loss * entropy_mask
@@ -1880,7 +1881,7 @@ def masked_batch_mean(x):
18801881
mean_entropy = masked_batch_mean(entropies)
18811882
self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
18821883

1883-
if self.loss_type != "cispo":
1884+
if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
18841885
# Compute the clipped probability ratios
18851886
is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
18861887
is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)

0 commit comments

Comments
 (0)