@@ -1816,19 +1816,25 @@ def _compute_loss(self, model, inputs):
18161816 f"Unknown importance sampling level: { self .importance_sampling_level } . Possible values are 'token' "
18171817 "and 'sequence'."
18181818 )
1819+
1820+ coef_1 = torch .exp (log_importance_weights )
1821+
18191822 # From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
18201823 # importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
1824+ if self .loss_type in ["grpo" , "bnpo" , "dr_grpo" , "dapo" ]:
1825+ clamped_ratios = torch .clamp (coef_1 , max = self .epsilon_high ).detach ()
1826+ per_token_loss = - clamped_ratios * advantages .unsqueeze (1 ) * per_token_logps
18211827
1822- coef_1 = torch .exp (log_importance_weights )
1823- coef_2 = torch .clamp (coef_1 , 1 - self .epsilon_low , 1 + self .epsilon_high )
1828+ else :
1829+ coef_2 = torch .clamp (coef_1 , 1 - self .epsilon_low , 1 + self .epsilon_high )
1830+ # Two-sided clipping
1831+ if self .args .delta is not None :
1832+ coef_1 = torch .clamp (coef_1 , max = self .args .delta )
18241833
1825- # Two-sided clipping
1826- if self . args . delta is not None :
1827- coef_1 = torch .clamp ( coef_1 , max = self . args . delta )
1834+ per_token_loss1 = coef_1 * advantages . unsqueeze ( 1 )
1835+ per_token_loss2 = coef_2 * advantages . unsqueeze ( 1 )
1836+ per_token_loss = - torch .min ( per_token_loss1 , per_token_loss2 )
18281837
1829- per_token_loss1 = coef_1 * advantages .unsqueeze (1 )
1830- per_token_loss2 = coef_2 * advantages .unsqueeze (1 )
1831- per_token_loss = - torch .min (per_token_loss1 , per_token_loss2 )
18321838 if entropy_mask is not None :
18331839 per_token_loss = per_token_loss * entropy_mask
18341840
@@ -1847,7 +1853,7 @@ def _compute_loss(self, model, inputs):
18471853 elif self .loss_type == "dr_grpo" :
18481854 loss = (per_token_loss * completion_mask ).sum () / (per_token_loss .size (0 ) * self .max_completion_length )
18491855 loss = loss / self .current_gradient_accumulation_steps
1850- elif self .loss_type == " dapo" :
1856+ elif self .loss_type in [ "cispo" , " dapo"] :
18511857 normalizer = inputs ["num_items_in_batch" ] / self .accelerator .num_processes
18521858 loss = (per_token_loss * completion_mask ).sum () / normalizer
18531859 else :
@@ -1871,23 +1877,30 @@ def masked_batch_mean(x):
18711877 mean_entropy = masked_batch_mean (entropies )
18721878 self ._metrics [mode ]["entropy" ].append (self .accelerator .gather (mean_entropy ).nanmean ().item ())
18731879
1874- # Compute the clipped probability ratios
1875- is_low_clipped = (coef_1 < 1 - self .epsilon_low ) & (advantages .unsqueeze (1 ) < 0 )
1876- is_high_clipped = (coef_1 > 1 + self .epsilon_high ) & (advantages .unsqueeze (1 ) > 0 )
1877- is_region_clipped = is_low_clipped | is_high_clipped
1878-
1879- low_clip = masked_batch_mean (is_low_clipped .float ())
1880- high_clip = masked_batch_mean (is_high_clipped .float ())
1881- clip_ratio = masked_batch_mean (is_region_clipped .float ())
1882-
1883- gathered_low_clip = self .accelerator .gather (low_clip )
1884- self ._metrics [mode ]["clip_ratio/low_mean" ].append (gathered_low_clip .nanmean ().item ())
1885- self ._metrics [mode ]["clip_ratio/low_min" ].append (nanmin (gathered_low_clip ).item ())
1886- gathered_high_clip = self .accelerator .gather (high_clip )
1887- self ._metrics [mode ]["clip_ratio/high_mean" ].append (gathered_high_clip .nanmean ().item ())
1888- self ._metrics [mode ]["clip_ratio/high_max" ].append (nanmax (gathered_high_clip ).item ())
1889- gathered_clip_ratio = self .accelerator .gather (clip_ratio )
1890- self ._metrics [mode ]["clip_ratio/region_mean" ].append (gathered_clip_ratio .nanmean ().item ())
1880+ if self .loss_type != "cispo" :
1881+ # Compute the clipped probability ratios
1882+ is_low_clipped = (coef_1 < 1 - self .epsilon_low ) & (advantages .unsqueeze (1 ) < 0 )
1883+ is_high_clipped = (coef_1 > 1 + self .epsilon_high ) & (advantages .unsqueeze (1 ) > 0 )
1884+ is_region_clipped = is_low_clipped | is_high_clipped
1885+
1886+ low_clip = masked_batch_mean (is_low_clipped .float ())
1887+ high_clip = masked_batch_mean (is_high_clipped .float ())
1888+ clip_ratio = masked_batch_mean (is_region_clipped .float ())
1889+
1890+ gathered_low_clip = self .accelerator .gather (low_clip )
1891+ self ._metrics [mode ]["clip_ratio/low_mean" ].append (gathered_low_clip .nanmean ().item ())
1892+ self ._metrics [mode ]["clip_ratio/low_min" ].append (nanmin (gathered_low_clip ).item ())
1893+ gathered_high_clip = self .accelerator .gather (high_clip )
1894+ self ._metrics [mode ]["clip_ratio/high_mean" ].append (gathered_high_clip .nanmean ().item ())
1895+ self ._metrics [mode ]["clip_ratio/high_max" ].append (nanmax (gathered_high_clip ).item ())
1896+ gathered_clip_ratio = self .accelerator .gather (clip_ratio )
1897+ self ._metrics [mode ]["clip_ratio/region_mean" ].append (gathered_clip_ratio .nanmean ().item ())
1898+ elif self .loss_type == "cispo" :
1899+ is_cispo_clipped = (coef_1 > self .epsilon_high ) & (advantages .unsqueeze (1 ) > 0 )
1900+ cispo_clip_ratio = masked_batch_mean (is_cispo_clipped .float ())
1901+ gathered_cispo_clip_ratio = self .accelerator .gather (cispo_clip_ratio )
1902+ self ._metrics [mode ]["cispo_clip_ratio" ].append (gathered_cispo_clip_ratio .nanmean ().item ())
1903+
18911904 return loss
18921905
18931906 def prediction_step (self , model , inputs , prediction_loss_only , ignore_keys : list [str ] | None = None ):
0 commit comments