Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] Activation offloading for distributed lora recipe #1645

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 49 additions & 3 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time

Expand Down Expand Up @@ -34,7 +35,12 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training import (
DummyProfiler,
NoOpManager,
OffloadActivations,
PROFILER_KEY,
)

from tqdm import tqdm

Expand All @@ -53,13 +59,25 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.

- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.

- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
flag. Activation offloading is a technique similar to activations checkpointing that helps
reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations
checkpointing drops the activation in the forward to recompute it later in the backward,
activations offloading will drop the activation in the forward to the CPU and bring it
back during the backward pass. As always, there is a tradeoff--these savings in memory can
come at the cost of training performance and CPU resources. To recover some runtime cost,
we've added an option to enable offloading on a different stream to permit overlapping with
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
or later and will be enabled by default if an acceptable torch version is found. Activation
offloading can be used in conjunction with activation checkpointing.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
Expand Down Expand Up @@ -110,6 +128,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
ValueError: If world_size is 1
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``left_pad_sequence`` is set as the data collator.
RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
"""

def __init__(self, cfg: DictConfig) -> None:
Expand All @@ -134,6 +153,13 @@ def __init__(self, cfg: DictConfig) -> None:

# training attributes
self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
self._enable_activation_offloading = cfg.get(
"enable_activation_offloading", False
)
if self._enable_activation_offloading and self._device.type != "cuda":
raise RuntimeError(
"enable_activation_offloading should only be enabled for training on CUDA"
)

# These attributes constitute the recipe state and are updated by ``load_checkpoint``
# when ``resume_from_checkpoint`` is ``True``
Expand Down Expand Up @@ -230,6 +256,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
Expand Down Expand Up @@ -377,6 +404,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
base_model_state_dict: Dict[str, Any],
Expand Down Expand Up @@ -496,6 +524,23 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

self.activations_handling_ctx = contextlib.nullcontext()
if enable_activation_offloading:
self.activations_handling_ctx = OffloadActivations()

# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive. Moreover, due to heuristics in our streaming API,
# we actually use more memory if we offload it as it interferes with chunkedCE.
if hasattr(model, "output") and isinstance(model.output, nn.Module):
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)

if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
Expand Down Expand Up @@ -733,7 +778,8 @@ def train(self) -> None:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")

logits = self._model(**batch)
with self.activations_handling_ctx:
logits = self._model(**batch)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Expand Down
Loading