
Commit 7af77c7

[Feat] Activation offloading for distributed lora recipe (#1645)
1 parent 0a3762d commit 7af77c7

File tree: 1 file changed (+49, -3)


recipes/lora_finetune_distributed.py

Lines changed: 49 additions & 3 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import contextlib
 import sys
 import time

@@ -34,7 +35,12 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)

 from tqdm import tqdm

@@ -53,13 +59,25 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
5359
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
5460
DDP is currently not supported. Training on CPU is not supported.
5561
56-
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
62+
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
5763
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
5864
activations in memory and instead recompute them during the backward pass. This is especially
5965
helpful for larger batch sizes when you're memory constrained. But these savings in memory
6066
come at the cost of training performance. In most cases training can slow-down quite a bit as
6167
a result of this activation recomputation.
6268
69+
- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
70+
flag. Activation offloading is a technique similar to activations checkpointing that helps
71+
reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
72+
checkpointing drops the activation in the forward to recompute it later in the backward,
73+
activations offloading will drop the activation in the forward to the CPU and bring it
74+
back during the backward pass. As always, there is a tradeoff--these savings in memory can
75+
come at the cost of training performance and CPU resources. To recover some runtime cost,
76+
we've added an option to enable offloading on a different stream to permit overlapping with
77+
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
78+
or later and will be enabled by default if an acceptable torch version is found. Activation
79+
offloading can be used in conjunction with activation checkpointing.
80+
6381
- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
6482
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
6583
most cases this should halve the memory footprint of full precision (fp32) training, without
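
For intuition, here is a minimal sketch of the save-to-CPU / restore-in-backward idea described in the docstring above, using PyTorch's built-in saved-tensor hook ``torch.autograd.graph.save_on_cpu``. This is an analogy only: the recipe uses torchtune's ``OffloadActivations``, which adds extras such as the optional side stream, and the sketch assumes a CUDA device is available.

```python
import torch
import torch.nn as nn

assert torch.cuda.is_available(), "offloading only helps when activations live on a GPU"
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with torch.autograd.graph.save_on_cpu(pin_memory=True):
    # Tensors saved for the backward pass inside this block are parked in
    # (pinned) CPU memory instead of staying resident on the GPU.
    out = model(x)

out.sum().backward()  # saved activations are copied back to the GPU as needed
```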
@@ -110,6 +128,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
         ValueError: If world_size is 1
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
     """

     def __init__(self, cfg: DictConfig) -> None:
@@ -134,6 +153,13 @@ def __init__(self, cfg: DictConfig) -> None:

         # training attributes
         self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading and self._device.type != "cuda":
+            raise RuntimeError(
+                "enable_activation_offloading should only be enabled for training on CUDA"
+            )

         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
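
A note on the config handling above: the flag is read with ``cfg.get(..., False)`` so that existing configs which do not define ``enable_activation_offloading`` keep their current behavior. A tiny illustration with a hypothetical, stripped-down config (only the two flags shown are real recipe fields):

```python
from omegaconf import OmegaConf

# Hypothetical, stripped-down recipe config that predates this commit and
# therefore has no enable_activation_offloading entry.
cfg = OmegaConf.create({"enable_activation_checkpointing": True})

# .get() with a default of False means older configs silently keep offloading off.
enable_activation_offloading = cfg.get("enable_activation_offloading", False)
print(enable_activation_offloading)  # False
```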
@@ -230,6 +256,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
             fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
             reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
@@ -377,6 +404,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
@@ -496,6 +524,23 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)

+        self.activations_handling_ctx = contextlib.nullcontext()
+        if enable_activation_offloading:
+            self.activations_handling_ctx = OffloadActivations()
+
+            # Below is our hack to disable offloading the last output Linear in every
+            # step, as the cost for offloading the activation and then soon after bringing
+            # it back is expensive. Moreover, due to heuristics in our streaming API,
+            # we actually use more memory if we offload it as it interferes with chunkedCE.
+            if hasattr(model, "output") and isinstance(model.output, nn.Module):
+                noop_ctx = NoOpManager()
+                model.output.register_forward_pre_hook(
+                    lambda *args: noop_ctx.__enter__()
+                )
+                model.output.register_forward_hook(
+                    lambda *args: noop_ctx.__exit__(), always_call=True
+                )
+
         if self._is_rank_zero:
             log.info(
                 f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
@@ -733,7 +778,8 @@ def train(self) -> None:
                 # Shape [b, s], needed for the loss not the model
                 labels = batch.pop("labels")

-                logits = self._model(**batch)
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)

                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
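
Finally, defaulting ``activations_handling_ctx`` to ``contextlib.nullcontext()`` is what lets this forward call stay branch-free: the same ``with`` block is a no-op when offloading is disabled. A small self-contained sketch of that pattern (assumes torchtune is installed; ``OffloadActivations`` is only exercised when CUDA is available):

```python
import contextlib

import torch
import torch.nn as nn
from torchtune.training import OffloadActivations

enable_activation_offloading = torch.cuda.is_available()
device = "cuda" if enable_activation_offloading else "cpu"

ctx = (
    OffloadActivations() if enable_activation_offloading else contextlib.nullcontext()
)

model = nn.Linear(16, 16, device=device)
batch = torch.randn(2, 16, device=device)

with ctx:  # identical call site whether offloading is on or off
    logits = model(batch)
logits.sum().backward()
```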
