 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import contextlib
 import sys
 import time
 
@@ -34,7 +35,12 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)
 
 from tqdm import tqdm
 
@@ -53,13 +59,25 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
             ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
             DDP is currently not supported. Training on CPU is not supported.
 
-        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
             flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
             activations in memory and instead recompute them during the backward pass. This is especially
             helpful for larger batch sizes when you're memory constrained. But these savings in memory
             come at the cost of training performance. In most cases training can slow-down quite a bit as
             a result of this activation recomputation.
 
+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+            flag. Activation offloading is a technique similar to activation checkpointing that helps
+            reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
+            checkpointing drops the activation in the forward to recompute it later in the backward,
+            activation offloading will drop the activation in the forward to the CPU and bring it
+            back during the backward pass. As always, there is a tradeoff--these savings in memory can
+            come at the cost of training performance and CPU resources. To recover some runtime cost,
+            we've added an option to enable offloading on a different stream to permit overlapping with
+            the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
+            or later and will be enabled by default if an acceptable torch version is found. Activation
+            offloading can be used in conjunction with activation checkpointing (see the sketch after this excerpt).
+
         - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
             flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
             most cases this should halve the memory footprint of full precision (fp32) training, without
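For readers new to the ``torchtune.training`` utilities imported in this diff, here is a minimal usage sketch (not part of the PR) of how the offloading context manager is applied further below: only the forward pass runs inside it, and a CUDA device is assumed since the recipe rejects other devices. The toy model, shapes, and default constructor arguments are illustrative assumptions.

```python
import torch
from torch import nn
from torchtune.training import OffloadActivations

# Toy model and batch; activation offloading targets CUDA memory pressure,
# so both live on the GPU (the recipe raises an error on non-CUDA devices).
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
batch = torch.randn(8, 1024, device="cuda")

# Tensors saved for backward during this forward are moved to CPU and brought
# back to the GPU when the backward pass needs them.
with OffloadActivations():
    out = model(batch)

# The backward runs outside the context, as in the recipe's train loop.
out.sum().backward()
```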
@@ -110,6 +128,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface):
         ValueError: If world_size is 1
         RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
         RuntimeError: If ``left_pad_sequence`` is set as the data collator.
+        RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.
     """
 
     def __init__(self, cfg: DictConfig) -> None:
@@ -134,6 +153,13 @@ def __init__(self, cfg: DictConfig) -> None:
 
         # training attributes
         self._enable_activation_checkpointing = cfg.enable_activation_checkpointing
+        self._enable_activation_offloading = cfg.get(
+            "enable_activation_offloading", False
+        )
+        if self._enable_activation_offloading and self._device.type != "cuda":
+            raise RuntimeError(
+                "enable_activation_offloading should only be enabled for training on CUDA"
+            )
 
         # These attributes constitute the recipe state and are updated by ``load_checkpoint``
         # when ``resume_from_checkpoint`` is ``True``
@@ -230,6 +256,7 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_offloading=self._enable_activation_offloading,
             fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
             reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
@@ -377,6 +404,7 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
         fsdp_cpu_offload: bool,
         reshard_after_forward: bool,
         base_model_state_dict: Dict[str, Any],
@@ -496,6 +524,23 @@ def _is_layer_name(name: str, module: nn.Module) -> bool:
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)
 
+        self.activations_handling_ctx = contextlib.nullcontext()
+        if enable_activation_offloading:
+            self.activations_handling_ctx = OffloadActivations()
+
+            # Below is our hack to disable offloading the last output Linear in every
+            # step, as the cost for offloading the activation and then soon after bringing
+            # it back is expensive. Moreover, due to heuristics in our streaming API,
+            # we actually use more memory if we offload it as it interferes with chunked CE.
+            if hasattr(model, "output") and isinstance(model.output, nn.Module):
+                noop_ctx = NoOpManager()
+                model.output.register_forward_pre_hook(
+                    lambda *args: noop_ctx.__enter__()
+                )
+                model.output.register_forward_hook(
+                    lambda *args: noop_ctx.__exit__(), always_call=True
+                )
+
         if self._is_rank_zero:
             log.info(
                 f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
@@ -733,7 +778,8 @@ def train(self) -> None:
                 # Shape [b, s], needed for the loss not the model
                 labels = batch.pop("labels")
 
-                logits = self._model(**batch)
+                with self.activations_handling_ctx:
+                    logits = self._model(**batch)
 
                 # Shift labels to compute loss
                 # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
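Note that the diff wraps only the model's forward call; the backward pass, run later on the loss, fetches the offloaded activations back automatically. As a rough analogy using a stock PyTorch primitive, ``torch.autograd.graph.save_on_cpu`` behaves the same way. This is only an intuition aid, not a claim about how ``OffloadActivations`` is implemented.

```python
import torch
from torch import nn

model = nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

# Tensors saved for backward during this forward are kept on (pinned) CPU
# memory and copied back to the GPU when the backward pass needs them.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    y = model(x)

# The backward itself runs outside the context, like loss.backward() in the recipe.
y.sum().backward()
```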