Commit 6181945

Update
[ghstack-poisoned]
2 parents: d747f8e + 82511a1

3 files changed: +134, -13 lines

docs/source/reference/llms.rst

Lines changed: 7 additions & 5 deletions

@@ -1158,20 +1158,22 @@ Objectives
 
 LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models.
 
-GRPO
-~~~~
-
-The :class:`~torchrl.objectives.llm.GRPOLoss` class is a thin wrapper around the :class:`~torchrl.objectives.PPOLoss` class
-that codes the LLM-specific functionalities.
+GRPO, DAPO, CISPO
+^^^^^^^^^^^^^^^^^
 
 .. currentmodule:: torchrl.objectives.llm
 
 .. autosummary::
     :toctree: generated/
     :template: rl_template.rst
 
+    LLMLossOutput
     GRPOLoss
     GRPOLossOutput
+    CISPOLoss
+    CISPOLossOutput
+    DAPO
+    DAPOLossOutput
     MCAdvantage
 
 SFT
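The renamed section now documents `GRPOLoss` alongside the DAPO and CISPO variants. For orientation, a minimal sketch of how `GRPOLoss` is constructed with the keyword arguments this commit exercises; `actor_network` and `data` are assumed stand-ins for a wrapped policy and a GRPO-formatted batch (see the tests below):

```python
# Sketch only: `actor_network` (a wrapped LLM policy) and `data` (a
# GRPO-formatted TensorDict batch) are assumed, mirroring the test setup below.
from torchrl.objectives.llm import GRPOLoss

loss_fn = GRPOLoss(
    actor_network,          # e.g. TransformersWrapper(model, generate=False, ...)
    clip_epsilon=0.2,       # PPO-style clipping range
    kl_mask_threshold=1.0,  # optional per-token trust-region filter (see grpo.py below)
    kl_to_ref_coeff=0.1,    # optional KL-to-reference penalty; requires ref log-probs
)
loss_vals = loss_fn(data)   # a GRPOLossOutput with loss_objective, ESS, clip_fraction, ...
```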

test/llm/test_objectives.py

Lines changed: 120 additions & 1 deletion

@@ -23,6 +23,7 @@
     GRPOLossOutput,
     MCAdvantage,
 )
+from torchrl._utils import logger
 from torchrl.objectives.llm.sft import SFTLoss
 
 _has_transformers = importlib.util.find_spec("transformers") is not None
@@ -200,7 +201,7 @@ def test_grpo(self, mock_transformer_model, dapo):
         )
 
         # Create loss module
-        loss_fn = GRPOLoss(actor_network, eps=eps)
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=eps)
 
         # Create fake data
         data = _mock_data_grpo(vocab_size=vocab_size, device=device)
@@ -245,6 +246,124 @@ def test_grpo(self, mock_transformer_model, dapo):
             0 <= loss_vals.clip_fraction <= 1
         ), f"clip_fraction out of range: {loss_vals.clip_fraction}"
 
+    def test_kl_mask_threshold(self, mock_transformer_model):
+        """Test that kl_mask_threshold properly filters out high-KL tokens."""
+        torch.manual_seed(42)
+        vocab_size = 1024
+        device = (
+            torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+        )
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # First, test that the data works without any threshold
+        loss_fn_baseline = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=None
+        )
+
+        data_baseline = data.clone()
+        loss_baseline = loss_fn_baseline(data_baseline)
+        logger.info(f"Baseline loss (no threshold): {loss_baseline.loss_objective}")
+        logger.info(f"Baseline ESS: {loss_baseline.ESS}")
+
+        # Check baseline is valid
+        if not torch.isfinite(loss_baseline.loss_objective):
+            raise ValueError(
+                f"Baseline loss is not finite: {loss_baseline.loss_objective}, skipping test"
+            )
+
+        # Now test with kl_mask_threshold enabled
+        # Use a very high threshold that should not mask any tokens
+        kl_threshold = 100.0  # Extremely high threshold to ensure no masking
+        loss_fn_with_threshold = GRPOLoss(
+            actor_network, clip_epsilon=0.2, kl_mask_threshold=kl_threshold
+        )
+
+        data_with_threshold = data.clone()
+        loss_with_threshold = loss_fn_with_threshold(data_with_threshold)
+
+        # Should produce valid output
+        assert isinstance(loss_with_threshold, GRPOLossOutput)
+
+        # Check that the loss is finite (with such a high threshold, it should be)
+        assert torch.isfinite(
+            loss_with_threshold.loss_objective
+        ), f"loss_with_threshold is not finite: {loss_with_threshold.loss_objective}"
+        assert torch.isfinite(
+            loss_with_threshold.ESS
+        ), f"ESS with threshold is not finite: {loss_with_threshold.ESS}"
+
+        logger.info(
+            f"Loss with high threshold (100.0): {loss_with_threshold.loss_objective}"
+        )
+        logger.info(f"ESS with high threshold: {loss_with_threshold.ESS}")
+
+        # The losses should be identical or very similar since we're not masking anything
+        # (the difference comes only from numerical precision)
+        assert torch.isclose(
+            loss_baseline.loss_objective, loss_with_threshold.loss_objective, rtol=1e-3
+        ), f"Losses differ too much with high threshold: {loss_baseline.loss_objective} vs {loss_with_threshold.loss_objective}"
+
+    def test_failure_missing_entries(self, mock_transformer_model):
+        """Test that GRPO fails when required keys are missing but works without optional keys."""
+        vocab_size = 1024
+        device = torch.device("cpu")
+
+        # Create mock model and wrap it
+        model = mock_transformer_model(vocab_size=vocab_size, device=device)
+        actor_network = TransformersWrapper(
+            model,
+            generate=False,
+            pad_output=True,
+            input_mode="history",
+        )
+
+        # Create loss module
+        loss_fn = GRPOLoss(actor_network, clip_epsilon=0.2)
+
+        # Create fake data
+        data = _mock_data_grpo(vocab_size=vocab_size, device=device)
+
+        # Test 1: Missing sample_log_prob (required) should fail
+        data_missing_sample_log_prob = data.clone()
+        data_missing_sample_log_prob.exclude(("log_probs", "full"), inplace=True)
+
+        with pytest.raises(KeyError, match="Couldn't find the log-prob"):
+            loss_fn(data_missing_sample_log_prob)
+
+        # Test 2: Missing ref_log_probs (optional when kl_to_ref_coeff is None) should work
+        data_missing_ref = data.clone()
+        # Remove the ref_log_probs key if it exists
+        if ("next", "ref_log_probs", "full") in data_missing_ref.keys(True):
+            data_missing_ref.exclude(("next", "ref_log_probs", "full"), inplace=True)
+
+        # Should work fine without ref_log_probs when kl_to_ref_coeff is None
+        loss_vals = loss_fn(data_missing_ref)
+        assert isinstance(loss_vals, GRPOLossOutput)
+        assert torch.isfinite(loss_vals.loss_objective)
+
+        # Test 3: Missing ref_log_probs when kl_to_ref_coeff is set should fail
+        loss_fn_with_kl = GRPOLoss(actor_network, clip_epsilon=0.2, kl_to_ref_coeff=0.1)
+
+        data_missing_ref_for_kl = data.clone()
+        if ("next", "ref_log_probs", "full") in data_missing_ref_for_kl.keys(True):
+            data_missing_ref_for_kl.exclude(
+                ("next", "ref_log_probs", "full"), inplace=True
+            )
+
+        with pytest.raises(KeyError, match="Couldn't find the ref log-prob"):
+            loss_fn_with_kl(data_missing_ref_for_kl)
+
     def test_cispo(self, mock_transformer_model):
         """Test CISPO loss computation with mock models."""
         vocab_size = 1024
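The new `test_failure_missing_entries` leans on tensordict's nested-key `exclude`. A self-contained sketch of that pattern, using the same keys as the test; the tensor shapes are illustrative assumptions:

```python
# Minimal sketch of the nested-key exclusion used by the new test; shapes
# are illustrative, the keys mirror those in the test above.
import torch
from tensordict import TensorDict

td = TensorDict(
    {
        ("log_probs", "full"): torch.randn(2, 5),
        ("next", "ref_log_probs", "full"): torch.randn(2, 5),
    },
    batch_size=[2],
)
td.exclude(("log_probs", "full"), inplace=True)    # drop the sampling log-probs
assert ("log_probs", "full") not in td.keys(True)  # keys(True) iterates nested keys
```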

torchrl/objectives/llm/grpo.py

Lines changed: 7 additions & 7 deletions

@@ -398,24 +398,24 @@ def forward(self, tensordict: TensorDictBase) -> LLMOutputType:
         # Optional per-token trust-region filtering (KL-Mask) vs reference policy
         if self.kl_mask_threshold is not None and self.kl_mask_threshold > 0:
             try:
-                ref_log_prob = tensordict.get(
-                    self.tensor_keys.ref_log_probs,
+                inference_log_prob = tensordict.get(
+                    self.tensor_keys.sample_log_prob,
                     as_padded_tensor=True,
                     padding_side="left",
                     padding_value=0.0,
                 )
             except KeyError:
-                ref_log_prob = None
+                inference_log_prob = None
             cur_log_prob = tensordict.get("_cur_log_prob", None)
-            if (ref_log_prob is not None) and (cur_log_prob is not None):
+            if (inference_log_prob is not None) and (cur_log_prob is not None):
                 # Align to valid tokens only (safety)
                 cur_log_prob_masked = torch.where(
                     expand_as_right(mask, cur_log_prob), cur_log_prob, 0.0
                 )
-                ref_log_prob_masked = torch.where(
-                    expand_as_right(mask, ref_log_prob), ref_log_prob, 0.0
+                inference_log_prob_masked = torch.where(
+                    expand_as_right(mask, inference_log_prob), inference_log_prob, 0.0
                 )
-                log_is_ref = cur_log_prob_masked - ref_log_prob_masked
+                log_is_ref = cur_log_prob_masked - inference_log_prob_masked
                 kl_token = 0.5 * (log_is_ref**2)
                 tr_mask = kl_token <= self.kl_mask_threshold
                 # Combine with attention mask
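After the rename, the trust-region filter compares the current policy against the log-probs recorded at sampling time (`sample_log_prob`) rather than the reference policy; `0.5 * log_is**2` is the standard second-order per-token approximation of the KL divergence. A standalone sketch of the masking arithmetic, with illustrative tensor names and shapes:

```python
# Standalone sketch of the per-token KL-mask above; names and shapes are
# illustrative, not the module's actual tensors.
import torch

torch.manual_seed(0)
cur_log_prob = torch.randn(2, 6) * 0.1        # current-policy log-probs per token
inference_log_prob = torch.randn(2, 6) * 0.1  # log-probs recorded at sampling time
kl_mask_threshold = 0.01

log_is = cur_log_prob - inference_log_prob    # per-token log importance ratio
kl_token = 0.5 * log_is**2                    # second-order per-token KL estimate
tr_mask = kl_token <= kl_mask_threshold       # keep tokens inside the trust region
print(f"kept {tr_mask.float().mean():.0%} of tokens")
```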
