Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

do NOT review - offloading for other recipes #1578

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions docs/source/tutorials/memory_optimizations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,35 @@ and in most cases training can slow-down quite a bit as a result of this activat
To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
in any of our recipes, e.g. ``enable_activation_checkpointing=True``.

.. _glossary_act_off:

Activation Offloading
---------------------

*What's going on here?*

You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing
them back when needed in the backward pass.

See `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
for more details about how this is implemented through saved_tensors_hooks.

This setting is especially helpful for larger batch sizes, or longer context lengths when you're memory constrained.
However, these savings in memory can come at the cost of training speed (i.e. tokens per-second), as it takes runtime
and resources to move Tensors from GPU to CPU and back. The implementation in torchtune has the ``offload_with_streams``
option to use multiple CUDA streams in order to overlap the extra communication with the computation to hide the extra
runtime. As the communication workload is variable depending on the number and size of tensors being offloaded, it is
common to not offload every single activation. In fact, once can use offloading in conjunction with activations
checkpointing, where all activations will either be recomputed later in the backward or brought back from the CPU.

*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``. To allow
usage of streams, make sure you are on a torch version later than PyTorch 2.5.0.dev20240907 and
specify ``offload_with_streams=True``.

.. _glossary_grad_accm:

Gradient Accumulation
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
"tqdm",
"omegaconf",
"psutil",

]
dynamic = ["version"]
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
2 changes: 2 additions & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True

# Profiler (disabled)
profiler:
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/0.5B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/1.5B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/7B_lora_single_device.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
Expand Down
49 changes: 46 additions & 3 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time

Expand All @@ -23,7 +24,13 @@
from torchtune.data import padded_collate_packed, padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY

from torchtune.training import (
DummyProfiler,
NoOpManager,
OffloadActivations,
PROFILER_KEY,
)
from torchtune.training.activations import apply_selective_activation_checkpointing

from tqdm import tqdm
Expand All @@ -43,13 +50,25 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.

- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.

- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
flag. Activation offloading is a technique similar to activations checkpointing that helps
reduce the memory footprint to prevent OOMs and enable bigger batches. Where activations
checkpointing drops the activation in the forward to recompute it later in the backward,
activations offloading will drop the activation in the forward to the CPU and bring it
back during the backward pass. As always, there is a tradeoff--these savings in memory can
come at the cost of training performance and CPU resources. To recover some runtime cost,
specify ``offload_with_streams: True`` to enable offloading on a different stream to permit
overlapping with the computation. This option is currently only available on PyTorch nightly
version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with
activation checkpointing.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
Expand Down Expand Up @@ -204,6 +223,8 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=cfg.get("enable_activation_offloading", False),
offload_with_streams=cfg.get("offload_with_streams", False),
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
Expand Down Expand Up @@ -338,6 +359,8 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
offload_with_streams: bool,
custom_sharded_layers: Optional[List[str]],
fsdp_cpu_offload: bool,
reshard_after_forward: bool,
Expand Down Expand Up @@ -434,6 +457,25 @@ def _is_layer_fqn(s: str) -> bool:
# Ensure no params and buffers are on meta device
training.validate_no_params_on_meta_device(model)

self.activations_handling_ctx = contextlib.nullcontext()
if enable_activation_offloading:
self.activations_handling_ctx = OffloadActivations(
use_streams=offload_with_streams
)

# Below is our hack to disable offloading the last output Linear in every
# step, as the cost for offloading the activation and then soon after bringing
# it back is expensive. Moreover, due to heuristics in our streaming API,
# we actually use more memory if we offload it as it interferes with chunkedCE.
if hasattr(model, "output") and isinstance(model.output, nn.Module):
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)

if self._is_rank_zero:
log.info(
f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
Expand Down Expand Up @@ -626,7 +668,8 @@ def train(self) -> None:
input_pos.to(self._device) if input_pos is not None else None
)

logits = self._model(tokens, mask=mask, input_pos=input_pos)
with self.activations_handling_ctx:
logits = self._model(tokens, mask=mask, input_pos=input_pos)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Expand Down
Loading
Loading