@@ -1824,11 +1824,10 @@ def _compute_loss(self, model, inputs):
 
         # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
         # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
-        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
+        if self.loss_type == "cispo":
             clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
             per_token_loss = -clamped_ratios * advantages.unsqueeze(1) * per_token_logps
-
-        else:
+        elif self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
             coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
             # Two-sided clipping
             if self.args.delta is not None:
@@ -1837,6 +1836,8 @@ def _compute_loss(self, model, inputs):
             per_token_loss1 = coef_1 * advantages.unsqueeze(1)
             per_token_loss2 = coef_2 * advantages.unsqueeze(1)
             per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+        else:
+            raise ValueError(f"Unknown loss type: {self.loss_type}")
 
         if entropy_mask is not None:
             per_token_loss = per_token_loss * entropy_mask
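For context, a minimal standalone sketch of the branching these two hunks establish: CISPO weights a REINFORCE-style term by a detached, upper-clamped importance ratio, while the GRPO-family losses use the two-sided clipped surrogate. The free-function packaging is an assumption for illustration; `coef_1`, `per_token_logps`, `epsilon_low`/`epsilon_high`, and the loss names follow the diff, and the optional `delta` clamp is omitted.

```python
import torch

def sketch_per_token_loss(loss_type, coef_1, per_token_logps, advantages,
                          epsilon_low, epsilon_high):
    # coef_1: importance ratios, shape (B, T) at "token" level; advantages: (B,)
    adv = advantages.unsqueeze(1)
    if loss_type == "cispo":
        # CISPO: clamp the ratio from above and detach it, so the gradient
        # flows only through the current log-probs (REINFORCE-style).
        clamped_ratios = torch.clamp(coef_1, max=epsilon_high).detach()
        return -clamped_ratios * adv * per_token_logps
    elif loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
        # PPO-style two-sided clipped surrogate (the delta clamp is omitted here).
        coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
        return -torch.min(coef_1 * adv, coef_2 * adv)
    raise ValueError(f"Unknown loss type: {loss_type}")
```

The trailing `ValueError` mirrors the new `else` branch in the patch: before this change, any unrecognized `loss_type` silently fell through to the clipped-surrogate path; now it fails loudly.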
@@ -1880,7 +1881,7 @@ def masked_batch_mean(x):
             mean_entropy = masked_batch_mean(entropies)
             self._metrics[mode]["entropy"].append(self.accelerator.gather(mean_entropy).nanmean().item())
 
-        if self.loss_type != "cispo":
+        if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:
             # Compute the clipped probability ratios
             is_low_clipped = (coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)
             is_high_clipped = (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
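And a hedged sketch of the clip-fraction bookkeeping that this last hunk restricts to the clipped-surrogate losses (CISPO has no two-sided clip, so its ratios are never "clipped" in this sense). The `completion_mask` argument and the free-function form are assumptions standing in for the trainer's `masked_batch_mean`; the clipping conditions follow the diff.

```python
import torch

def sketch_clip_fractions(coef_1, advantages, completion_mask,
                          epsilon_low, epsilon_high):
    adv = advantages.unsqueeze(1)
    # A token counts as clipped only when clipping actually binds on the loss:
    # ratio below the floor with a negative advantage, or above the ceiling
    # with a positive advantage.
    is_low_clipped = (coef_1 < 1 - epsilon_low) & (adv < 0)
    is_high_clipped = (coef_1 > 1 + epsilon_high) & (adv > 0)
    denom = completion_mask.sum().clamp(min=1)
    low_frac = (is_low_clipped & completion_mask.bool()).sum() / denom
    high_frac = (is_high_clipped & completion_mask.bool()).sum() / denom
    return low_frac.item(), high_frac.item()
```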