From 7886accc8406ff3a44b4ef7e2e98265e286b3cbf Mon Sep 17 00:00:00 2001 From: Jackmin801 Date: Sat, 21 Sep 2024 22:15:31 +0000 Subject: [PATCH] port changes from single device case --- recipes/lora_finetune_distributed.py | 50 ++++++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 3d5b82f7d..9e4d413ac 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time @@ -33,7 +34,12 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm @@ -52,12 +58,24 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + we've added an option to enable offloading on a different stream to permit overlapping with + the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907 + or later and will be enabled by default if an acceptable torch version is found. Activation + offloading can be used in conjunction with activation checkpointing. - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In @@ -104,6 +122,7 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): ValueError: If ``dtype`` is set to fp16. ValueError: If world_size is 1 RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16. + RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA. """ def __init__(self, cfg: DictConfig) -> None: @@ -128,6 +147,11 @@ def __init__(self, cfg: DictConfig) -> None: # training attributes self._enable_activation_checkpointing = cfg.enable_activation_checkpointing + self._enable_activation_offloading = cfg.get("enable_activation_offloading", False) + if self._enable_activation_offloading and self._device.type != "cuda": + raise RuntimeError( + "enable_activation_offloading should only be enabled for training on CUDA" + ) # These attributes constitute the recipe state and are updated by ``load_checkpoint`` # when ``resume_from_checkpoint`` is ``True`` @@ -223,6 +247,7 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=self._enable_activation_offloading, fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -368,6 +393,7 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], @@ -487,6 +513,23 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations() + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" @@ -726,7 +769,8 @@ def train(self) -> None: input_pos = ( input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :]