# LICENSE file in the root directory of this source tree.
from __future__ import annotations

+import contextlib
+
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Literal
    TensorClass,
    TensorDict,
    TensorDictBase,
-    TensorDictParams,
)
from tensordict.nn import (
+    CompositeDistribution,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
-    TensorDictModule,
+    set_composite_lp_aggregate,
)
from tensordict.utils import expand_as_right
from torch import distributions as d
-from torchrl._utils import logger as torchrl_logger
+from torchrl._utils import logger as torchrl_logger, VERBOSE
from torchrl.envs.transforms.transforms import Transform
from torchrl.modules.llm import LLMWrapperBase
-from torchrl.objectives.ppo import ClipPPOLoss
+from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import _reduce, _sum_td_features


@@ -46,7 +48,7 @@ class GRPOLossOutput(TensorClass["nocast"]):
    kl_to_inference: torch.Tensor | None = None


-class GRPOLoss(ClipPPOLoss):
+class GRPOLoss(LossModule):
    """GRPO loss.

    The clipped importance weighted loss is computed as follows:
@@ -116,20 +118,18 @@ class GRPOLoss(ClipPPOLoss):
    """

    actor_network: LLMWrapperBase
-    critic_network: TensorDictModule
-    actor_network_params: TensorDictParams
-    critic_network_params: TensorDictParams
-    target_actor_network_params: TensorDictParams
-    target_critic_network_params: TensorDictParams

    @dataclass
-    class _AcceptedKeys(ClipPPOLoss._AcceptedKeys):
+    class _AcceptedKeys(LossModule._AcceptedKeys):
        """Maintains default values for all configurable tensordict keys.

        This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their
        default values
        """

+        advantage: NestedKey = "advantage"
+        action: NestedKey = ("tokens", "full")
+        sample_log_prob: NestedKey = ("log_probs", "full")
        ref_log_probs: NestedKey = ("next", "ref_log_probs", "full")

    def __init__(
@@ -149,32 +149,85 @@ def __init__(
        masking_strategy: Literal["sft", "rlhf", "generic"] = "sft",
        **kwargs,
    ):
-        # Define clipping of the value loss
-        if isinstance(clip_value, bool):
-            clip_value = clip_epsilon if clip_value else None
-
-        super().__init__(
-            actor_network,
-            critic_network=None,
-            entropy_bonus=entropy_bonus,
-            samples_mc_entropy=samples_mc_entropy,
-            entropy_coeff=entropy_coeff,
-            gamma=gamma,
-            separate_losses=False,
-            reduction=reduction,
-            clip_value=clip_value,
-            functional=False,
-            device=device,
-            **kwargs,
-        )
-        # We don't want to use the string action but the tokens
-        self._set_in_keys()
+        super().__init__()
+        # Core modules and hyper-parameters
+        self.actor_network = actor_network
+        self.entropy_bonus = entropy_bonus
+        self.samples_mc_entropy = samples_mc_entropy
+        self.entropy_coeff = entropy_coeff
+        self.reduction = reduction
+
+        # Determine device and register clip epsilon as buffer
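+        # (a buffer keeps clip_epsilon in the module's state_dict and moves with it across devices)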
+        if device is None:
+            try:
+                device = next(self.parameters()).device
+            except (AttributeError, StopIteration):
+                device = getattr(
+                    torch, "get_default_device", lambda: torch.device("cpu")
+                )()
+        self.register_buffer("clip_epsilon", torch.tensor(clip_epsilon, device=device))
+
        self.masking_strategy = masking_strategy
-        # Always use the full tokens for the action
+        # Defaults for keys
        self.set_keys(sample_log_prob=("log_probs", "full"), action=("tokens", "full"))
-        # TODO: make this a buffer
+        # KL coefficients
        self.kl_to_ref_coeff = kl_to_ref_coeff
        self.kl_to_inference_coeff = kl_to_inference_coeff
+        # Prepare IO keys
+        self._set_in_keys()
+
+    @property
+    def _clip_bounds(self):
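+        # Log-space clip bounds: (log(1 - eps), log(1 + eps)), the log-domain
+        # equivalent of PPO's (1 - eps, 1 + eps) ratio clipping.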
+        return ((-self.clip_epsilon).log1p(), self.clip_epsilon.log1p())
+
+    def _set_in_keys(self):
+        keys = []
+        if getattr(self, "actor_network", None) is not None and hasattr(
+            self.actor_network, "in_keys"
+        ):
+            in_keys = self.actor_network.in_keys
+            if isinstance(in_keys, (list, tuple)):
+                keys.extend(in_keys)
+        keys.append(self.tensor_keys.action)
+        keys.append(self.tensor_keys.sample_log_prob)
+        keys.append(self.tensor_keys.advantage)
+        keys.append(self.tensor_keys.ref_log_probs)
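+        # Deduplicate while preserving insertion order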
+        self._in_keys = list(dict.fromkeys(keys))
+
+    @property
+    def in_keys(self):
+        if getattr(self, "_in_keys", None) is None:
+            self._set_in_keys()
+        return self._in_keys
+
+    @in_keys.setter
+    def in_keys(self, values):
+        self._in_keys = values
+
+    @property
+    def out_keys(self):
+        if getattr(self, "_out_keys", None) is None:
+            keys = ["loss_objective", "clip_fraction", "ESS", "kl_approx"]
+            if self.entropy_bonus:
+                keys.extend(["entropy", "loss_entropy"])
+            keys.extend(
+                [
+                    "loss_kl_to_ref",
+                    "kl_to_ref",
+                    "loss_kl_to_inference",
+                    "kl_to_inference",
+                ]
+            )
+            self._out_keys = keys
+        return self._out_keys
+
+    @out_keys.setter
+    def out_keys(self, values):
+        self._out_keys = values
+
+    def _forward_value_estimator_keys(self, **kwargs) -> None:
+        # No value estimator in GRPO; simply refresh input keys
+        self._set_in_keys()

    def _get_cur_log_prob(self, tensordict):
        """Override to use LLM-specific distribution with explicit masking strategy.
@@ -281,11 +334,6 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
                entropy = _sum_td_features(entropy)
            td_out.set("entropy", entropy.detach().mean())  # for logging
            td_out.set("loss_entropy", -self.entropy_coeff * entropy)
-        if self._has_critic:
-            loss_critic, value_clip_fraction = self.loss_critic(tensordict)
-            td_out.set("loss_critic", loss_critic)
-            if value_clip_fraction is not None:
-                td_out.set("value_clip_fraction", value_clip_fraction)

        td_out.set("ESS", _reduce(ess / batch, self.reduction))
        td_out = td_out.named_apply(
@@ -323,6 +371,42 @@ def forward(self, tensordict: TensorDictBase) -> GRPOLossOutput:
        del tensordict["_cur_log_prob"]
        return GRPOLossOutput.from_tensordict(td_out)

+    def _get_entropy(
+        self, dist: d.Distribution, adv_shape: torch.Size
+    ) -> torch.Tensor | TensorDict:
+        try:
+            entropy = dist.entropy()
+            if not entropy.isfinite().all():
+                del entropy
+                if VERBOSE:
+                    torchrl_logger.info(
+                        "Entropy is not finite. Using Monte Carlo sampling."
+                    )
+                raise NotImplementedError
+        except NotImplementedError:
+            if VERBOSE:
+                torchrl_logger.warning(
+                    f"Entropy not implemented for {type(dist)} or is not finite. Using Monte Carlo sampling."
+                )
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
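+            # For CompositeDistribution, disable log-prob aggregation so log_prob() returns
+            # per-key entries that can be picked out via the sample_log_prob key below.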
+            with set_composite_lp_aggregate(False) if isinstance(
+                dist, CompositeDistribution
+            ) else contextlib.nullcontext():
+                log_prob = dist.log_prob(x)
+                if is_tensor_collection(log_prob):
+                    if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+                        log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+                    else:
+                        log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
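+            # Monte Carlo entropy estimate: average of -log p(x) over the drawn samples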
+            entropy = -log_prob.mean(0)
+        if is_tensor_collection(entropy) and entropy.batch_size != adv_shape:
+            entropy.batch_size = adv_shape
+        return entropy.unsqueeze(-1)
+
    def _kl_to_ref(
        self,
        tensordict: TensorDictBase,