Commit 20e65b6

Update
[ghstack-poisoned]
1 parent 5bd198c commit 20e65b6

File tree

1 file changed: +122 -38 lines
torchrl/objectives/llm/grpo.py

Lines changed: 122 additions & 38 deletions
@@ -4,6 +4,8 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+import contextlib
+
 from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Literal
@@ -15,19 +17,19 @@
     TensorClass,
     TensorDict,
     TensorDictBase,
-    TensorDictParams,
 )
 from tensordict.nn import (
+    CompositeDistribution,
     ProbabilisticTensorDictModule,
     ProbabilisticTensorDictSequential,
-    TensorDictModule,
+    set_composite_lp_aggregate,
 )
 from tensordict.utils import expand_as_right
 from torch import distributions as d
-from torchrl._utils import logger as torchrl_logger
+from torchrl._utils import logger as torchrl_logger, VERBOSE
 from torchrl.envs.transforms.transforms import Transform
 from torchrl.modules.llm import LLMWrapperBase
-from torchrl.objectives.ppo import ClipPPOLoss
+from torchrl.objectives.common import LossModule
 from torchrl.objectives.utils import _reduce, _sum_td_features
 
 
@@ -46,7 +48,7 @@ class GRPOLossOutput(TensorClass["nocast"]):
     kl_to_inference: torch.Tensor | None = None
 
 
-class GRPOLoss(ClipPPOLoss):
+class GRPOLoss(LossModule):
     """GRPO loss.
 
     The clipped importance weighted loss is computed as follows:
@@ -116,20 +118,18 @@ class GRPOLoss(ClipPPOLoss):
     """
 
     actor_network: LLMWrapperBase
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams
 
     @dataclass
-    class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
+    class _AcceptedKeys(LossModule._AcceptedKeys):
         """Maintains default values for all configurable tensordict keys.
 
         This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
         default values
         """
 
+        advantage: NestedKey = "advantage"
+        action: NestedKey = ("tokens", "full")
+        sample_log_prob: NestedKey = ("log_probs", "full")
         ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")
 
     def __init__(
@@ -149,32 +149,85 @@ def __init__(
         masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
         **kwargs,
     ):
-        # Define clipping of the value loss
-        if isinstance(clip_value, bool):
-            clip_value = clip_epsilon if clip_value else None
-
-        super().__init__(
-            actor_network,
-            critic_network=None,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
-            entropy_coeff=entropy_coeff,
-            gamma=gamma,
-            separate_losses=False,
-            reduction=reduction,
-            clip_value=clip_value,
-            functional=False,
-            device=device,
-            **kwargs,
-        )
-        # We don't want to use the string action but the tokens
-        self._set_in_keys()
+        super().__init__()
+        # Core modules and hyper-parameters
+        self.actor_network = actor_network
+        self.entropy_bonus = entropy_bonus
+        self.samples_mc_entropy = samples_mc_entropy
+        self.entropy_coeff = entropy_coeff
+        self.reduction = reduction
+
+        # Determine device and register clip epsilon as buffer
+        if device is None:
+            try:
+                device = next(self.parameters()).device
+            except (AttributeError, StopIteration):
+                device = getattr(
+                    torch, "get_default_device", lambda: torch.device("cpu")
+                )()
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+
         self.masking_strategy = masking_strategy
-        # Always use the full tokens for the action
+        # Defaults for keys
         self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
-        # TODO: make this a buffer
+        # KL coefficients
         self.kl_to_ref_coeff = kl_to_ref_coeff
         self.kl_to_inference_coeff = kl_to_inference_coeff
+        # Prepare IO keys
+        self._set_in_keys()
+
+    @property
+    def _clip_bounds(self):
+        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+
+    def _set_in_keys(self):
+        keys = []
+        if getattr(self, "actor_network", None) is not None and hasattr(
+            self.actor_network, "in_keys"
+        ):
+            in_keys = self.actor_network.in_keys
+            if isinstance(in_keys, (list, tuple)):
+                keys.extend(in_keys)
+        keys.append(self.tensor_keys.action)
+        keys.append(self.tensor_keys.sample_log_prob)
+        keys.append(self.tensor_keys.advantage)
+        keys.append(self.tensor_keys.ref_log_probs)
+        self._in_keys = list(dict.fromkeys(keys))
+
+    @property
+    def in_keys(self):
+        if getattr(self, "_in_keys", None) is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    def out_keys(self):
+        if getattr(self, "_out_keys", None) is None:
+            keys = ["loss_objective", "clip_fraction", "ESS", "kl_approx"]
+            if self.entropy_bonus:
+                keys.extend(["entropy", "loss_entropy"])
+            keys.extend(
+                [
+                    "loss_kl_to_ref",
+                    "kl_to_ref",
+                    "loss_kl_to_inference",
+                    "kl_to_inference",
+                ]
+            )
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        # No value estimator in GRPO; simply refresh input keys
+        self._set_in_keys()
 
     def _get_cur_log_prob(self, tensordict):
         """Override to use LLM-specific distribution with explicit masking strategy.
@@ -281,11 +334,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
                 entropy = _sum_td_features(entropy)
             td_out.set("entropy", entropy.detach().mean())  # for logging
             td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
-            if value_clip_fraction is not None:
-                td_out.set("value_clip_fraction", value_clip_fraction)
 
         td_out.set("ESS", _reduce(ess / batch, self.reduction))
         td_out = td_out.named_apply(
@@ -323,6 +371,42 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
         del tensordict["_cur_log_prob"]
         return GRPOLossOutput.from_tensordict(td_out)
 
+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
+        try:
+            entropy = dist.entropy()
+            if not entropy.isfinite().all():
+                del entropy
+                if VERBOSE:
+                    torchrl_logger.info(
+                        "Entropy is not finite. Using Monte Carlo sampling."
+                    )
+                raise NotImplementedError
+        except NotImplementedError:
+            if VERBOSE:
+                torchrl_logger.warning(
+                    f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                )
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
+            with set_composite_lp_aggregate(False) if isinstance(
+                dist, CompositeDistribution
+            ) else contextlib.nullcontext():
+                log_prob = dist.log_prob(x)
+                if is_tensor_collection(log_prob):
+                    if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+                        log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+                    else:
+                        log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
+            entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
+        return entropy.unsqueeze(-1)
+
     def _kl_to_ref(
         self,
         tensordict: TensorDictBase,
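The new `_get_entropy` helper falls back to a Monte Carlo estimate when the analytic entropy is unavailable or non-finite: it draws samples from the distribution and averages the negative log-probabilities. A self-contained sketch of the same idea for a plain torch distribution (the LLM and composite-distribution handling from the commit is omitted):

    import torch
    from torch import distributions as d

    def mc_entropy(dist: d.Distribution, num_samples: int = 1000) -> torch.Tensor:
        # Prefer reparameterized samples when the distribution supports them
        x = dist.rsample((num_samples,)) if dist.has_rsample else dist.sample((num_samples,))
        # Entropy is approximated as -E[log p(x)] under samples from the distribution itself
        return -dist.log_prob(x).mean(0)

    dist = d.Normal(torch.zeros(3), torch.ones(3))
    print(mc_entropy(dist))  # approaches dist.entropy(), about 1.4189 per element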

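For orientation, a hypothetical construction sketch of the reworked loss; `policy` and `batch` are placeholders for an LLMWrapperBase policy and a tensordict that already carries the ("tokens", "full"), ("log_probs", "full"), "advantage" and ("next", "ref_log_probs", "full") entries named by the defaults above:

    from torchrl.objectives.llm.grpo import GRPOLoss

    loss_fn = GRPOLoss(
        actor_network=policy,   # placeholder: an LLMWrapperBase policy
        clip_epsilon=0.2,       # stored as a buffer, consumed by _clip_bounds
        entropy_coeff=0.01,
        kl_to_ref_coeff=0.1,
        masking_strategy="sft",
    )
    output = loss_fn(batch)     # placeholder batch -> GRPOLossOutput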