
Commit 408e521

address comments, some optim in bwd cleanup

1 parent: 83cba27

File tree

3 files changed: +45 −33 lines

  recipes/full_finetune_distributed.py
  recipes/full_finetune_single_device.py
  tests/recipes/test_full_finetune_single_device.py

recipes/full_finetune_distributed.py

Lines changed: 17 additions & 15 deletions

@@ -151,12 +151,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
-        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
-            raise RuntimeError(
-                "Gradient accumulation is not supported with optimizer in bwd."
-                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
-            )
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
 
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
@@ -187,7 +195,6 @@ def __init__(self, cfg: DictConfig) -> None:
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -796,16 +803,11 @@ def train(self) -> None:
                     torch.distributed.all_reduce(running_loss)
                     # Manually scale the gradients from unnormalized loss by total # of tokens
                     training.scale_grads(self._model, 1 / num_tokens)
-                    if self._clip_grad_norm is not None:
-                        if self._optimizer_in_bwd:
-                            raise NotImplementedError(
-                                "Gradient clipping is not supported after optimizer-in-the-backward."
+                    if self._clip_grad_norm is not None:
+                        grad_norm = torch.nn.utils.clip_grad_norm_(
+                            self._model.parameters(),
+                            max_norm=float(self._clip_grad_norm),
                         )
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self._model.parameters(),
-                            max_norm=float(self._clip_grad_norm),
-                        )
-                    if not self._optimizer_in_bwd:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)
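
The new guard exists because gradient clipping cannot run after an in-backward optimizer step: each parameter's gradient is applied and freed during backward itself, so there is nothing left to clip afterwards (the test change below makes the same point). A minimal sketch of that failure mode using plain PyTorch hooks; this is illustrative only, not torchtune's implementation:

import torch

model = torch.nn.Linear(4, 4)
# One tiny optimizer per parameter, stepped as soon as that parameter's grad is ready.
optimizers = {p: torch.optim.SGD([p], lr=0.1) for p in model.parameters()}

def step_and_free(param: torch.Tensor) -> None:
    optimizers[param].step()
    optimizers[param].zero_grad(set_to_none=True)  # the grad is freed right here

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_and_free)

loss = model(torch.randn(2, 4)).sum()
loss.backward()

# By the time backward() returns, every .grad is already None, so a later call to
# torch.nn.utils.clip_grad_norm_ would have nothing to measure or clip.
print([p.grad for p in model.parameters()])  # [None, None]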

recipes/full_finetune_single_device.py

Lines changed: 20 additions & 15 deletions

@@ -141,6 +141,20 @@ def __init__(self, cfg: DictConfig) -> None:
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.optimizer_in_bwd
+        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+
+        # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
+        if self._optimizer_in_bwd:
+            if self._clip_grad_norm is not None:
+                raise RuntimeError(
+                    "Gradient clipping is not supported with optimizer in bwd."
+                    "Please set clip_grad_norm=None, or optimizer_in_bwd=False."
+                )
+            if self._gradient_accumulation_steps > 1:
+                raise RuntimeError(
+                    "Gradient accumulation is not supported with optimizer in bwd."
+                    "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
+                )
 
         # activation checkpointing/offloading
         self._enable_activation_checkpointing = cfg.get(
@@ -164,22 +178,13 @@ def __init__(self, cfg: DictConfig) -> None:
                 "Enabling activation offloading should reduce memory further."
             )
 
-        # TODO: find a better place / way to perform validation of args that don't yet
-        # compose with each other.
-        if self._gradient_accumulation_steps > 1 and self._optimizer_in_bwd:
-            raise RuntimeError(
-                "Gradient accumulation is not supported with optimizer in bwd."
-                "Please set gradient_accumulation_steps=1, or optimizer_in_bwd=False."
-            )
-
         # These are public properties which are updated by the checkpoint loader
         # when ``resume_from_checkpoint`` is `True` or validated in tests
         self.seed = training.set_seed(seed=cfg.seed)
         self.epochs_run = 0
         self.total_epochs = cfg.epochs
         self.max_steps_per_epoch = cfg.max_steps_per_epoch
         self.global_step = 0
-        self._clip_grad_norm = cfg.get("clip_grad_norm", None)
 
     def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
         """
@@ -692,13 +697,13 @@ def train(self) -> None:
 
                 # Step with optimizer
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
-                    training.scale_grads(self._model, 1 / num_tokens)
-                    if self._clip_grad_norm is not None:
-                        grad_norm = torch.nn.utils.clip_grad_norm_(
-                            self._model.parameters(),
-                            max_norm=float(self._clip_grad_norm),
-                        )
                     if not self._optimizer_in_bwd:
+                        training.scale_grads(self._model, 1 / num_tokens)
+                        if self._clip_grad_norm is not None:
+                            grad_norm = torch.nn.utils.clip_grad_norm_(
+                                self._model.parameters(),
+                                max_norm=float(self._clip_grad_norm),
+                            )
                         self._optimizer.step()
                         self._optimizer.zero_grad(set_to_none=True)
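
For reference, the non-fused path now follows a fixed ordering: scale the gradients of the unnormalized token-sum loss by 1/num_tokens, optionally clip, then step and reset. A small self-contained sketch of that ordering; `scale_grads` below is a hypothetical stand-in for `training.scale_grads`, and the model and values are placeholders:

import torch

def scale_grads(model: torch.nn.Module, scaler: float) -> None:
    # Hypothetical stand-in: multiply every parameter's grad in place.
    for p in model.parameters():
        if p.grad is not None:
            p.grad.mul_(scaler)

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
clip_grad_norm = 1.0
num_tokens = 512  # total tokens behind the accumulated, unnormalized loss

loss = model(torch.randn(4, 8)).sum()  # token-sum style loss, not yet normalized
loss.backward()

scale_grads(model, 1 / num_tokens)  # 1) normalize grads by total token count
grad_norm = torch.nn.utils.clip_grad_norm_(  # 2) clip, recording the pre-clip norm
    model.parameters(), max_norm=float(clip_grad_norm)
)
optimizer.step()  # 3) step and reset
optimizer.zero_grad(set_to_none=True)
print(f"grad norm before clipping: {grad_norm.item():.4f}")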

tests/recipes/test_full_finetune_single_device.py

Lines changed: 8 additions & 3 deletions

@@ -46,7 +46,6 @@ def _get_test_config_overrides(self):
             "lr_scheduler.num_warmup_steps=0",
             "lr_scheduler.num_cycles=0",
             "log_every_n_steps=1",
-            "clip_grad_norm=100",
         ] + dummy_alpaca_dataset_config()
 
     def _fetch_expected_loss_values(self, model_type):
@@ -94,7 +93,6 @@ def test_loss(
             --config {config} \
             batch_size={micro_batch_size} \
             gradient_accumulation_steps={gradient_accumulation_steps} \
-            optimizer_in_bwd={optimizer_in_bwd} \
             output_dir={tmpdir} \
             checkpointer._component_={ckpt_component} \
             checkpointer.checkpoint_dir='{ckpt_dir}' \
@@ -109,7 +107,14 @@ def test_loss(
 
         model_config = MODEL_TEST_CONFIGS[model_type]
         cmd = cmd + self._get_test_config_overrides() + model_config
-
+        # "optimizer_in_bwd=True" would free gradient info before clip_grad, causing
+        # wrong grad_norm, so we only test one of them each time. But loss values
+        # should be the same.
+        if not optimizer_in_bwd:
+            cmd.append("clip_grad_norm=100")
+            cmd.append("optimizer_in_bwd=False")
+        else:
+            cmd.append("optimizer_in_bwd=True")
         monkeypatch.setattr(sys, "argv", cmd)
         with pytest.raises(SystemExit, match=""):
             runpy.run_path(TUNE_PATH, run_name="__main__")
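
The comment in the hunk above explains why the two settings are exercised separately: with the optimizer fused into backward, gradients are gone before clip_grad_norm_ could see them, so clipping is only enabled on the non-fused branch while both branches are expected to reach the same losses. A hedged sketch of that parametrization shape; `run_recipe` and the listed losses are placeholder stand-ins, not the real harness or expected values:

import pytest

EXPECTED_LOSSES = [10.5, 10.4, 10.3, 10.2]  # placeholder values only

def run_recipe(overrides):
    # Placeholder stand-in for launching the recipe with CLI-style overrides.
    return EXPECTED_LOSSES

@pytest.mark.parametrize("optimizer_in_bwd", [False, True])
def test_loss_matches_on_both_paths(optimizer_in_bwd):
    overrides = []
    # Clipping is only exercised when the optimizer step is not fused into backward.
    if not optimizer_in_bwd:
        overrides += ["clip_grad_norm=100", "optimizer_in_bwd=False"]
    else:
        overrides += ["optimizer_in_bwd=True"]
    losses = run_recipe(overrides)
    # Both paths should optimize identically, so they check against the same losses.
    assert losses == pytest.approx(EXPECTED_LOSSES)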
