|
23 | 23 | GRPOLossOutput, |
24 | 24 | MCAdvantage, |
25 | 25 | ) |
| 26 | +from torchrl._utils import logger |
26 | 27 | from torchrl.objectives.llm.sft import SFTLoss |
27 | 28 |
|
28 | 29 | _has_transformers = importlib.util.find_spec("transformers") is not None |
@@ -200,7 +201,7 @@ def test_grpo(self, mock_transformer_model, dapo): |
200 | 201 | ) |
201 | 202 |
|
202 | 203 | # Create loss module |
203 | | - loss_fn = GRPOLoss(actor_network, eps=eps) |
| 204 | + loss_fn = GRPOLoss(actor_network, clip_epsilon=eps) |
204 | 205 |
|
205 | 206 | # Create fake data |
206 | 207 | data = _mock_data_grpo(vocab_size=vocab_size, device=device) |
@@ -245,6 +246,124 @@ def test_grpo(self, mock_transformer_model, dapo): |
245 | 246 | 0 <= loss_vals.clip_fraction <= 1 |
246 | 247 | ), f"clip_fraction out of range: {loss_vals.clip_fraction}" |
247 | 248 |
|
| 249 | + def test_kl_mask_threshold(self, mock_transformer_model): |
| 250 | + """Test that kl_mask_threshold properly filters out high-KL tokens.""" |
| 251 | + torch.manual_seed(42) |
| 252 | + vocab_size = 1024 |
| 253 | + device = ( |
| 254 | + torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") |
| 255 | + ) |
| 256 | + |
| 257 | + # Create mock model and wrap it |
| 258 | + model = mock_transformer_model(vocab_size=vocab_size, device=device) |
| 259 | + actor_network = TransformersWrapper( |
| 260 | + model, |
| 261 | + generate=False, |
| 262 | + pad_output=True, |
| 263 | + input_mode="history", |
| 264 | + ) |
| 265 | + |
| 266 | + # Create fake data |
| 267 | + data = _mock_data_grpo(vocab_size=vocab_size, device=device) |
| 268 | + |
| 269 | + # First, test that the data works without any threshold |
| 270 | + loss_fn_baseline = GRPOLoss( |
| 271 | + actor_network, clip_epsilon=0.2, kl_mask_threshold=None |
| 272 | + ) |
| 273 | + |
| 274 | + data_baseline = data.clone() |
| 275 | + loss_baseline = loss_fn_baseline(data_baseline) |
| 276 | + logger.info(f"Baseline loss (no threshold): {loss_baseline.loss_objective}") |
| 277 | + logger.info(f"Baseline ESS: {loss_baseline.ESS}") |
| 278 | + |
| 279 | + # Check baseline is valid |
| 280 | + if not torch.isfinite(loss_baseline.loss_objective): |
| 281 | + raise ValueError( |
| 282 | + f"Baseline loss is not finite: {loss_baseline.loss_objective}, skipping test" |
| 283 | + ) |
| 284 | + |
| 285 | + # Now test with kl_mask_threshold enabled |
| 286 | + # Use a very high threshold that should not mask any tokens |
| 287 | + kl_threshold = 100.0 # Extremely high threshold to ensure no masking |
| 288 | + loss_fn_with_threshold = GRPOLoss( |
| 289 | + actor_network, clip_epsilon=0.2, kl_mask_threshold=kl_threshold |
| 290 | + ) |
| 291 | + |
| 292 | + data_with_threshold = data.clone() |
| 293 | + loss_with_threshold = loss_fn_with_threshold(data_with_threshold) |
| 294 | + |
| 295 | + # Should produce valid output |
| 296 | + assert isinstance(loss_with_threshold, GRPOLossOutput) |
| 297 | + |
| 298 | + # Check that the loss is finite (with such a high threshold, it should be) |
| 299 | + assert torch.isfinite( |
| 300 | + loss_with_threshold.loss_objective |
| 301 | + ), f"loss_with_threshold is not finite: {loss_with_threshold.loss_objective}" |
| 302 | + assert torch.isfinite( |
| 303 | + loss_with_threshold.ESS |
| 304 | + ), f"ESS with threshold is not finite: {loss_with_threshold.ESS}" |
| 305 | + |
| 306 | + logger.info( |
| 307 | + f"Loss with high threshold (100.0): {loss_with_threshold.loss_objective}" |
| 308 | + ) |
| 309 | + logger.info(f"ESS with high threshold: {loss_with_threshold.ESS}") |
| 310 | + |
| 311 | + # The losses should be identical or very similar since we're not masking anything |
| 312 | + # (the difference comes only from numerical precision) |
| 313 | + assert torch.isclose( |
| 314 | + loss_baseline.loss_objective, loss_with_threshold.loss_objective, rtol=1e-3 |
| 315 | + ), f"Losses differ too much with high threshold: {loss_baseline.loss_objective} vs {loss_with_threshold.loss_objective}" |
| 316 | + |
| 317 | + def test_failure_missing_entries(self, mock_transformer_model): |
| 318 | + """Test that GRPO fails when required keys are missing but works without optional keys.""" |
| 319 | + vocab_size = 1024 |
| 320 | + device = torch.device("cpu") |
| 321 | + |
| 322 | + # Create mock model and wrap it |
| 323 | + model = mock_transformer_model(vocab_size=vocab_size, device=device) |
| 324 | + actor_network = TransformersWrapper( |
| 325 | + model, |
| 326 | + generate=False, |
| 327 | + pad_output=True, |
| 328 | + input_mode="history", |
| 329 | + ) |
| 330 | + |
| 331 | + # Create loss module |
| 332 | + loss_fn = GRPOLoss(actor_network, clip_epsilon=0.2) |
| 333 | + |
| 334 | + # Create fake data |
| 335 | + data = _mock_data_grpo(vocab_size=vocab_size, device=device) |
| 336 | + |
| 337 | + # Test 1: Missing sample_log_prob (required) should fail |
| 338 | + data_missing_sample_log_prob = data.clone() |
| 339 | + data_missing_sample_log_prob.exclude(("log_probs", "full"), inplace=True) |
| 340 | + |
| 341 | + with pytest.raises(KeyError, match="Couldn't find the log-prob"): |
| 342 | + loss_fn(data_missing_sample_log_prob) |
| 343 | + |
| 344 | + # Test 2: Missing ref_log_probs (optional when kl_to_ref_coeff is None) should work |
| 345 | + data_missing_ref = data.clone() |
| 346 | + # Remove the ref_log_probs key if it exists |
| 347 | + if ("next", "ref_log_probs", "full") in data_missing_ref.keys(True): |
| 348 | + data_missing_ref.exclude(("next", "ref_log_probs", "full"), inplace=True) |
| 349 | + |
| 350 | + # Should work fine without ref_log_probs when kl_to_ref_coeff is None |
| 351 | + loss_vals = loss_fn(data_missing_ref) |
| 352 | + assert isinstance(loss_vals, GRPOLossOutput) |
| 353 | + assert torch.isfinite(loss_vals.loss_objective) |
| 354 | + |
| 355 | + # Test 3: Missing ref_log_probs when kl_to_ref_coeff is set should fail |
| 356 | + loss_fn_with_kl = GRPOLoss(actor_network, clip_epsilon=0.2, kl_to_ref_coeff=0.1) |
| 357 | + |
| 358 | + data_missing_ref_for_kl = data.clone() |
| 359 | + if ("next", "ref_log_probs", "full") in data_missing_ref_for_kl.keys(True): |
| 360 | + data_missing_ref_for_kl.exclude( |
| 361 | + ("next", "ref_log_probs", "full"), inplace=True |
| 362 | + ) |
| 363 | + |
| 364 | + with pytest.raises(KeyError, match="Couldn't find the ref log-prob"): |
| 365 | + loss_fn_with_kl(data_missing_ref_for_kl) |
| 366 | + |
248 | 367 | def test_cispo(self, mock_transformer_model): |
249 | 368 | """Test CISPO loss computation with mock models.""" |
250 | 369 | vocab_size = 1024 |
|
0 commit comments