@@ -141,6 +141,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.optimizer_in_bwd
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd. "
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd. "
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
 
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
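With this addition, the incompatible settings are rejected as soon as the recipe is constructed instead of surfacing later during training. A minimal sketch of the new behavior, assuming an OmegaConf config (matching the `DictConfig` in the signature) and omitting every field the recipe needs beyond the three keys shown in this hunk:

```python
# Sketch only: builds a config that the new validation rejects.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "optimizer_in_bwd": True,
        "clip_grad_norm": 1.0,  # incompatible with optimizer_in_bwd=True
        "gradient_accumulation_steps": 1,
    }
)

# Passing a cfg like this to the recipe's __init__ now raises:
#   RuntimeError: Gradient clipping is not supported with optimizer in bwd. ...
# Likewise, gradient_accumulation_steps > 1 with optimizer_in_bwd=True raises
# the second RuntimeError before any model or data setup happens.
assert cfg.get("clip_grad_norm", None) is not None
```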
@@ -164,22 +178,13 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further."
             )
 
-        # TODO: find a better place / way to perform validation of args that don't yet
-        # compose with each other.
-        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
-            raise RuntimeError(
-                "Gradient accumulation is not supported with optimizer in bwd."
-                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
-            )
-
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
         self.epochs_run = 0
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -692,13 +697,13 @@ def train(self) -> None:
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
-                    training.scale_grads(self._model, 1 / num_tokens)
-                    if self._clip_grad_norm is not None:
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self._model.parameters(),
-                            max_norm=float(self._clip_grad_norm),
-                        )
                     if not self._optimizer_in_bwd:
+                        training.scale_grads(self._model, 1 / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
                         self._optimizer.step()
                         self._optimizer.zero_grad(set_to_none=True)
 
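The net effect of this hunk is that gradient scaling and clipping only run on the branch that still owns gradients; when the optimizer has already stepped inside backward, there is nothing left to scale or clip. A self-contained sketch of the standard branch, with assumed model, optimizer, and hyperparameter values, and `training.scale_grads` approximated by an in-place multiply:

```python
import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)  # assumed values
clip_grad_norm = 1.0  # corresponds to cfg.clip_grad_norm; None disables clipping
gradient_accumulation_steps = 4

num_tokens = 0
for idx in range(gradient_accumulation_steps):
    batch = torch.randn(8, 16)
    num_tokens += batch.shape[0]  # stand-in for the real token count
    # Sum-reduced loss, so dividing the grads by num_tokens later yields a
    # per-token average over the whole accumulation window.
    loss = model(batch).sum()
    loss.backward()

    if (idx + 1) % gradient_accumulation_steps == 0:
        # Equivalent of training.scale_grads(self._model, 1 / num_tokens)
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(1 / num_tokens)
        if clip_grad_norm is not None:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                model.parameters(), max_norm=float(clip_grad_norm)
            )
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        num_tokens = 0
```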