diff --git a/docs/source/tutorials/memory_optimizations.rst b/docs/source/tutorials/memory_optimizations.rst index aab23d2a0e..69295e7264 100644 --- a/docs/source/tutorials/memory_optimizations.rst +++ b/docs/source/tutorials/memory_optimizations.rst @@ -83,6 +83,35 @@ and in most cases training can slow-down quite a bit as a result of this activat To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag in any of our recipes, e.g. ``enable_activation_checkpointing=True``. +.. _glossary_act_off: + +Activation Offloading +--------------------- + +*What's going on here?* + +You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory +efficiency technique that allows saving GPU VRAM by temporarily moving activations to CPU and bringing +them back when needed in the backward pass. + +See `PyTorch autograd hook tutorial `_ +for more details about how this is implemented through saved_tensors_hooks. + +This setting is especially helpful for larger batch sizes, or longer context lengths when you're memory constrained. +However, these savings in memory can come at the cost of training speed (i.e. tokens per-second), as it takes runtime +and resources to move Tensors from GPU to CPU and back. The implementation in torchtune has the ``offload_with_streams`` +option to use multiple CUDA streams in order to overlap the extra communication with the computation to hide the extra +runtime. As the communication workload is variable depending on the number and size of tensors being offloaded, it is +common to not offload every single activation. In fact, once can use offloading in conjunction with activations +checkpointing, where all activations will either be recomputed later in the backward or brought back from the CPU. + +*Sounds great! How do I use it?* + +To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag +in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``. To allow +usage of streams, make sure you are on a torch version later than PyTorch 2.5.0.dev20240907 and +specify ``offload_with_streams=True``. + .. _glossary_grad_accm: Gradient Accumulation diff --git a/pyproject.toml b/pyproject.toml index 322f4238c9..58a0e4ce33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ "numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed "tqdm", "omegaconf", + "psutil", ] dynamic = ["version"] diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index ba5d43113d..9bd4ed5ff8 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -74,6 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index c2e990bf7b..8af0599e72 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -74,6 +74,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False dtype: bf16 # Logging diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 8c322495ce..dafb77de0d 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -71,6 +71,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 7ed60ce180..d534f22961 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -71,6 +71,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index aa69fa50f8..4b9333239e 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -73,6 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 8a08c49b5c..6202bef679 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -73,6 +73,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 74c51fd6bc..8fd988f869 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -81,7 +81,9 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index dba911a0d7..3c67df18a1 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -81,7 +81,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index 427cdc50be..de7b26866f 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -80,7 +80,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index c35d93b9b4..d4cddd1ffa 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -80,7 +80,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index b06523b4a0..8da48f04b2 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -79,7 +79,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: True # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index a445051a98..88e9766df3 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -83,7 +83,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index 6b8b3497c2..1fba4e371d 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -82,7 +82,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Offloading enable_activation_checkpointing: True +enable_activation_offloading: False # Profiler (disabled) profiler: diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index 1c461f9f46..c0d72b78f9 100644 --- a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -77,6 +77,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index 54d0906150..aad29ffb12 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -78,6 +78,7 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False # Reduced precision dtype: bf16 diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 15e2fc02a5..70f09a1cab 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -72,6 +72,9 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision dtype: bf16 # Logging diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index 5b3ae9e6a1..8361bc37fd 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -72,6 +72,9 @@ device: cuda # Memory management enable_activation_checkpointing: True +enable_activation_offloading: False + +# Reduced precision dtype: bf16 # Logging diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index 23bcb1742c..2958f7ff8a 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -80,7 +80,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index 6c77f84d16..1a1e19aa72 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -78,7 +78,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Memory enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index 188066e1e2..dacdf5508a 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -82,7 +82,10 @@ log_peak_memory_stats: False # Environment device: cuda dtype: bf16 + +# Activations Offloading enable_activation_checkpointing: True +enable_activation_offloading: False # Show case the usage of pytorch profiler # Set enabled to False as it's only needed for debugging training diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 2112cd886d..674c07e169 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time @@ -23,7 +24,13 @@ from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY + +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from torchtune.training.activations import apply_selective_activation_checkpointing from tqdm import tqdm @@ -43,13 +50,25 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface): ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + specify ``offload_with_streams: True`` to enable offloading on a different stream to permit + overlapping with the computation. This option is currently only available on PyTorch nightly + version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with + activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -204,6 +223,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=cfg.get("enable_activation_offloading", False), + offload_with_streams=cfg.get("offload_with_streams", False), custom_sharded_layers=cfg.get("custom_sharded_layers", None), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), @@ -338,6 +359,8 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + offload_with_streams: bool, custom_sharded_layers: Optional[List[str]], fsdp_cpu_offload: bool, reshard_after_forward: bool, @@ -434,6 +457,25 @@ def _is_layer_fqn(s: str) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations( + use_streams=offload_with_streams + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" @@ -626,7 +668,8 @@ def train(self) -> None: input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 94f804ef85..1e35a60c4b 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time from functools import partial @@ -21,7 +22,13 @@ from torchtune.data import padded_collate_packed, padded_collate_sft from torchtune.datasets import ConcatDataset from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY + +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm @@ -35,13 +42,25 @@ class FullFinetuneRecipeSingleDevice(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + specify ``offload_with_streams: True`` to enable offloading on a different stream to permit + overlapping with the computation. This option is currently only available on PyTorch nightly + version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with + activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -210,6 +229,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=cfg.get("enable_activation_offloading", False), + offload_with_streams=cfg.get("offload_with_streams", False), compile_model=self._compile, model_state_dict=ckpt_dict[training.MODEL_KEY], ) @@ -343,6 +364,8 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + offload_with_streams: bool, compile_model: bool, model_state_dict: Dict[str, Any], ) -> nn.Module: @@ -368,6 +391,25 @@ def _setup_model( ) log.info(f"Model is initialized with precision {self._dtype}.") + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations( + use_streams=offload_with_streams + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + if self._device.type == "cuda": memory_stats = training.get_memory_stats(device=self._device) training.log_memory_stats(memory_stats) @@ -500,7 +542,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: mask = batch.get("mask", None) # shape [b, s, s] input_pos = batch.get("input_pos", None) # shape [b, s] - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index 56d70bb2b2..4e4182f81d 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time @@ -33,7 +34,13 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY + +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm @@ -52,13 +59,25 @@ class LoRAFinetuneRecipeDistributed(FTRecipeInterface): ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy). DDP is currently not supported. Training on CPU is not supported. - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + specify ``offload_with_streams: True`` to enable offloading on a different stream to permit + overlapping with the computation. This option is currently only available on PyTorch nightly + version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with + activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -223,6 +242,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=cfg.get("enable_activation_offloading", False), + offload_with_streams=cfg.get("offload_with_streams", False), fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), base_model_state_dict=checkpoint_dict[training.MODEL_KEY], @@ -368,6 +389,8 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + offload_with_streams: bool, fsdp_cpu_offload: bool, reshard_after_forward: bool, base_model_state_dict: Dict[str, Any], @@ -487,6 +510,25 @@ def _is_layer_name(name: str, module: nn.Module) -> bool: # Ensure no params and buffers are on meta device training.validate_no_params_on_meta_device(model) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations( + use_streams=offload_with_streams + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + if self._is_rank_zero: log.info( f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs" @@ -723,7 +765,8 @@ def train(self) -> None: input_pos = ( input_pos.to(self._device) if input_pos is not None else None ) - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 0862675a77..a8322a05f9 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib import sys import time @@ -30,8 +31,12 @@ validate_missing_and_unexpected_for_lora, ) from torchtune.recipe_interfaces import FTRecipeInterface -from torchtune.training import DummyProfiler, PROFILER_KEY - +from torchtune.training import ( + DummyProfiler, + NoOpManager, + OffloadActivations, + PROFILER_KEY, +) from tqdm import tqdm log = utils.get_logger("DEBUG") @@ -43,13 +48,25 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface): for single GPU training. Training on CPU is not supported. Features: - - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing`` flag. Activation checkpointing helps reduce the memory footprint since we no longer keep activations in memory and instead recompute them during the backward pass. This is especially helpful for larger batch sizes when you're memory constrained. But these savings in memory come at the cost of training performance. In most cases training can slow-down quite a bit as a result of this activation recomputation. + - Activation Offloading. This can be controlled using the ``enable_activation_offloading`` + flag. Activation offloading is a technique similar to activations checkpointing that helps + reduce the memory footprint to prevent OOMs and enable bigger batches. Where activations + checkpointing drops the activation in the forward to recompute it later in the backward, + activations offloading will drop the activation in the forward to the CPU and bring it + back during the backward pass. As always, there is a tradeoff--these savings in memory can + come at the cost of training performance and CPU resources. To recover some runtime cost, + specify ``offload_with_streams: True`` to enable offloading on a different stream to permit + overlapping with the computation. This option is currently only available on PyTorch nightly + version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with + activation checkpointing. + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In most cases this should halve the memory footprint of full precision (fp32) training, without @@ -222,6 +239,8 @@ def setup(self, cfg: DictConfig) -> None: self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, + enable_activation_offloading=cfg.get("enable_activation_offloading", False), + offload_with_streams=cfg.get("offload_with_streams", False), compile_model=cfg.compile, base_model_state_dict=checkpoint_dict[training.MODEL_KEY], lora_weights_state_dict=( @@ -367,6 +386,8 @@ def _setup_model( self, cfg_model: DictConfig, enable_activation_checkpointing: bool, + enable_activation_offloading: bool, + offload_with_streams: bool, compile_model: bool, base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, @@ -420,6 +441,25 @@ def _setup_model( self.adapter_params.items(), dtype=self._dtype ) + self.activations_handling_ctx = contextlib.nullcontext() + if enable_activation_offloading: + self.activations_handling_ctx = OffloadActivations( + use_streams=offload_with_streams + ) + + # Below is our hack to disable offloading the last output Linear in every + # step, as the cost for offloading the activation and then soon after bringing + # it back is expensive. Moreover, due to heuristics in our streaming API, + # we actually use more memory if we offload it as it interferes with chunkedCE. + if hasattr(model, "output") and isinstance(model.output, nn.Module): + noop_ctx = NoOpManager() + model.output.register_forward_pre_hook( + lambda *args: noop_ctx.__enter__() + ) + model.output.register_forward_hook( + lambda *args: noop_ctx.__exit__(), always_call=True + ) + log.info(f"Model is initialized with precision {self._dtype}.") if self._device.type == "cuda": @@ -576,7 +616,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: input_pos = batch.get("input_pos", None) # shape [b, s] # run model - logits = self._model(tokens, mask=mask, input_pos=input_pos) + with self.activations_handling_ctx: + logits = self._model(tokens, mask=mask, input_pos=input_pos) # Shift labels to compute loss # equivalent to doing labels[..., 1:] and logits[..., :-1, :] diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py index 241326cfce..7a1c4651cc 100644 --- a/torchtune/training/__init__.py +++ b/torchtune/training/__init__.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torchtune.training._activation_offloading import NoOpManager, OffloadActivations from torchtune.training._compile import compile_loss, compile_model from torchtune.training._distributed import ( contains_fsdp, @@ -122,4 +123,6 @@ "setup_torch_profiler", "compile_loss", "compile_model", + "NoOpManager", + "OffloadActivations", ] diff --git a/torchtune/training/_activation_offloading.py b/torchtune/training/_activation_offloading.py new file mode 100644 index 0000000000..3ebecc96f0 --- /dev/null +++ b/torchtune/training/_activation_offloading.py @@ -0,0 +1,313 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from warnings import warn + +import psutil +import torch +from torch.autograd.graph import saved_tensors_hooks + + +class OffloadActivations(saved_tensors_hooks): + """Context manager under which activation tensors created in the forward pass will be offloaded. + + Enable the memory efficiency technique of activation offloading, where activations bigger than + min_offload_size bytes will be offloaded to CPU in the forward and brought back in the backward. + This is in contrast to maintaining the activation on GPU VRAM throughout the program. + + This manager contains the option of using one additional CUDA stream to handle the communication + between CUDA and CPU, which is intended to overlap with the default computation stream to improve + runtime. We designed synchronization with a few heuristics for optimizing the tradeoff between + runtime vs memory usage. + + Args: + use_pin_memory (bool): Whether or not the offloaded Tensor will be placed in pinned + memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly + but is a limited resource. Default: True. + + use_streams (bool): Whether or not to use streams for performance optimization where + the communications get overlapped with the computation. Requires a torch build + after torch-2.5.0.dev20240907. Default: False. + + max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of + consecutive activations to keep alive during the forward pass. This number must be at + least 1. Keeping alive more activations will potentially allow more overlap between the + communication and compute streams at the cost of increasing memory usage. Keeping alive + fewer activations will conserve memory, but may cause poor overlap between the streams, + increasing runtime. Default: 5. + + min_offload_size (int): The minimum number of bytes a Tensor must be in order to qualify + for offloading. If the tensor is too small, we do not want to waste bandwidth and resources + moving it to CPU and back. Default: 1024 bytes. + + Raises: + ValueError: if max_fwd_stash_size is not at least 1. + + Example: + >>> with OffloadActivations(): + >>> logits = model(inputs) + >>> loss = ... + >>> loss.backward() + """ + + def __init__( + self, + use_pin_memory: bool = True, + use_streams: bool = False, + max_fwd_stash_size: int = 5, + min_offload_size: int = 1024, + ) -> None: + self.min_tensor_size_bytes = ( + min_offload_size # we don't want to bother with small tensors + ) + self.tracker = ( + {} + ) # tensor_id => (new_tensor, if_modified) ---> track what saved/offloaded tensors are where + self.tensor_id: int = 0 + self.is_first_forward_call = True + self.is_first_backward_call = True + self.is_first_forward_pass = True + + # managing cpu memory + self.use_pin_memory: bool = use_pin_memory + self.virtual_memory_safe_pct = ( + 60 # we should not exceed this percentage of memory + ) + + self.s0 = torch.cuda.default_stream() # comp stream + + # for streaming + if use_streams: + if torch.__version__ < "2.5.0.dev20240907": + raise RuntimeError( + "OffloadActivations with use_streams=True requires PyTorch 2.5.0.dev20240907 or later." + ) + self.use_streams = use_streams + self.s1 = torch.cuda.Stream() # comms stream + self.fwd_stash = {} # tensor_id => (activation, ev1) + if max_fwd_stash_size < 1: + raise ValueError( + f"max_fwd_stash_size should be at least 1 but is {max_fwd_stash_size}" + ) + self.max_fwd_stash_size = max_fwd_stash_size + self.bwd_tensor_stash = {} # tensor_id => activation + self.bwd_ev_stash = {} # tensor_id => ev0 + self.curr_graph_id = None + self.curr_autograd_node = None + + # -------- platform util functions -------- # + def verify_sufficient_virtual_memory(): + curr_pct = get_cpu_ram_pct() + if curr_pct > self.virtual_memory_safe_pct: + warn( + f"***** WARNING: {curr_pct=}% > {self.virtual_memory_safe_pct=}% of virtual memory used" + ) + + def get_cpu_ram_pct() -> float: + # get the percentage of memory used by the system + return psutil.virtual_memory().percent + + def get_tensor_id() -> int: + # create a unique id for each tensor we are managing + self.tensor_id += 1 + return self.tensor_id + + def get_num_bytes_tensor(x: torch.Tensor) -> int: + # get the number of bytes in a tensor, for memory management purposes + return ( + x.element_size() * x.nelement() + ) # x.element_size() * x._base_storage().nbytes() + + # -------- core pack / unpack work -------- # + def pack_tensor(activation: torch.Tensor) -> int: + # activations are passed in during forward pass - from here we take over and return a unique id + if self.is_first_forward_call: + assert ( + len(self.tracker) == 0 + ), "backward pass should have cleared tracker of all tensors" + + # set training phase trackers + self.is_first_forward_call = False + self.is_first_backward_call = True + + # query for basic tensor info + num_bytes = get_num_bytes_tensor(activation) + tensor_id = get_tensor_id() + + # only offload hefty bois + if num_bytes >= self.min_tensor_size_bytes: + if use_streams: + # First, sync back and dereference previously offloaded tensors + # as the offloading should be done sufficiently long ago. + for id in [k for k in self.fwd_stash.keys()]: + if id <= tensor_id - self.max_fwd_stash_size: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + else: + break + + # Sync in, offload, and add an event to sync back later + self.s1.wait_stream(self.s0) + + stream = self.s1 if use_streams else self.s0 + with torch.cuda.stream(stream): + cpu_tensor = torch.empty_like( + activation, + pin_memory=self.use_pin_memory, + device=torch.device("cpu"), + ) + + cpu_tensor.copy_(activation, non_blocking=True) + self.tracker[tensor_id] = ( + cpu_tensor, + True, + ) # True = (in future) modified + + if use_streams: + event = self.s1.record_event() + + # Stash to keep activation alive til s1 is done + self.fwd_stash[tensor_id] = (activation, event) + else: + self.tracker[tensor_id] = ( + activation, + False, + ) # False = not modified, tensor is as is + + return tensor_id + + def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + gpu_tensor = maybe_gpu_tensor.to(device="cuda", non_blocking=True) + maybe_gpu_tensor = gpu_tensor + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor: + # backward pass - we are called with the tensor_id, which + # we will use to retrieve the saved/offloaded tensor + if self.is_first_backward_call: + self.curr_graph_id = torch._C._current_graph_task_id() + + def wait_and_del_remaining_references() -> None: + for id in [k for k in self.bwd_tensor_stash.keys()]: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + # Register a callback to the end of autograd to clean everything up + torch.autograd.variable.Variable._execution_engine.queue_callback( + wait_and_del_remaining_references + ) + + if self.is_first_forward_pass: + self.is_first_forward_pass = False + if self.use_pin_memory: + verify_sufficient_virtual_memory() + + self.is_first_backward_call = False + self.is_first_forward_call = True + + assert ( + unpack_tensor_id in self.tracker + ), f"untracked tensor with id {unpack_tensor_id}" + + maybe_gpu_tensor, modified = self.tracker[unpack_tensor_id] + if modified: + # Get data on the current autograd node + graph_id = torch._C._current_graph_task_id() + node = torch._C._current_autograd_node() + prev_node_ids = [] + + # If we're on a new node, mark prev node's tensors to be freed later + if graph_id == self.curr_graph_id and self.curr_autograd_node != node: + self.curr_autograd_node = node + prev_node_ids = [id for id in self.bwd_tensor_stash.keys()] + + brought_back_from_cpu = True + if unpack_tensor_id in self.fwd_stash: + maybe_gpu_tensor = self.fwd_stash[unpack_tensor_id][0] + brought_back_from_cpu = False + else: + # Kick off the process to bring tensors back + with torch.cuda.stream(self.s1): + gpu_tensor = maybe_gpu_tensor.to( + device="cuda", non_blocking=True + ) + maybe_gpu_tensor = gpu_tensor + + # Tell comp stream to wait for the info to be loaded before executing + self.s0.wait_stream(self.s1) + + # Stash the tensor to keep memory alive until compute stream is complete + self.bwd_tensor_stash[unpack_tensor_id] = maybe_gpu_tensor + + def hook(outputs, inputs): + # create events for the current node inputs/outputs if they were streamed in + if brought_back_from_cpu: + event = self.s0.record_event() + self.bwd_ev_stash[unpack_tensor_id] = event + + # if there are still things in the fwd_stash, get rid of them as we're in bwd now + for id in [k for k in self.fwd_stash.keys()]: + _, ev = self.fwd_stash[id] + self.s0.wait_event(ev) + del self.fwd_stash[id] + + # wait on prev node's events and del those + for id in prev_node_ids: + event = self.bwd_ev_stash[id] + self.s1.wait_event(event) + del self.bwd_tensor_stash[id] + + return outputs + + node.register_hook(hook) + + # clear tensor from tracking + del self.tracker[unpack_tensor_id] + return maybe_gpu_tensor + + unpack_tensor = ( + unpack_tensor_with_streams if use_streams else unpack_tensor_single_stream + ) + super().__init__(pack_tensor, unpack_tensor) + + +class NoOpManager(saved_tensors_hooks): + """ + A saved_tensors_hook manager used to disable any other saved_tensors_hook manager + applied before. This relies on the behavior that only the most recently registered + saved_tensors_hook will run. + + One example usage is to opt a local region of code out of activations offloading, + which is usually applied globally to best track state. + """ + + def __init__(self) -> None: + def noop(tensor): + return tensor + + super().__init__(noop, noop)