From 40584c8bdba8dac43fa67746869293f34354fabd Mon Sep 17 00:00:00 2001
From: Shuqi Yang
Date: Wed, 2 Oct 2024 00:07:21 -0700
Subject: [PATCH] Use checkpoint to enforce the recomputation of fp8 weight (#936)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/936

The issue:
When using float8 training with FSDP, we have these tensors in the forward-backward graph:
- Without fp8-all-gather: original_weight (all-gather output, unsharded), fp8_weight, fp8_weight_transpose (needed in backward)
- With fp8-all-gather: original_weight (sharded), fp8_weight (all-gather output, unsharded), fp8_weight_transpose (needed in backward)

`torch.compile` decides how to partition the graph and which tensors to save for backward. In both cases (with and without fp8-all-gather), it decides to save "fp8_weight_transpose" for backward. This is fine in the single-GPU case: computing both fp8_weight and fp8_weight_transpose in the forward pass can be fused into one kernel. However, when we use FSDP to shard the weights, the weight itself is sharded but the "fp8_weight_transpose" tensors are not. Saving them for backward leads to high memory utilization.

----

To fix it, we have different options:
- In the user code, enforce which tensors to save for backward.
  - The `save_for_backward` in a custom autograd.Function is one way to specify which tensors to save. However, torch.compile ignores what is manually saved for backward in a custom autograd.Function and just runs the partitioner.
- **[This PR]** Use `torch.utils.checkpoint`, which is the API that compile promises to respect today. It instructs compile to save only its inputs for backward (the weight and activation), and not the intermediate values from the float8 cast.
- Rely on torch.compile to find the best partition that optimizes both computation and memory. This is a longer-term solution to fix in compile.

Reviewed By: vkuzo

Differential Revision: D63345959
---
 torchao/float8/README.md        |   2 +
 torchao/float8/config.py        |  21 ++++-
 torchao/float8/float8_linear.py | 131 +++++++++++++++++++-------
 3 files changed, 100 insertions(+), 54 deletions(-)

diff --git a/torchao/float8/README.md b/torchao/float8/README.md
index 57bb7c77f..b9b40d7e4 100644
--- a/torchao/float8/README.md
+++ b/torchao/float8/README.md
@@ -116,6 +116,8 @@ We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/
 such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
 on using `torchao.float8` in a distributed setting.
 
+:warning: When using FSDP, it's recommended to enable `config.force_recompute_fp8_weight_in_bwd`, which prevents the un-sharded fp8 weights from being saved for backward. If you are using customized activation checkpointing, you may ignore this config and handle the recomputation of fp8 weights in the customized AC code.
+
 # Performance
 A common question about float8 training is "when is float8 linear faster vs bfloat16?".
 Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark based speedup estimate on NVIDIA H100:
diff --git a/torchao/float8/config.py b/torchao/float8/config.py
index eb28dcbd8..0fa25b9bb 100644
--- a/torchao/float8/config.py
+++ b/torchao/float8/config.py
@@ -37,8 +37,10 @@ class CastConfig:
 
     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
-            assert self.static_scale is not None, \
-                "static_scale must be specified for static scaling"
+            assert (
+                self.static_scale is not None
+            ), "static_scale must be specified for static scaling"
+
 
 @dataclass(frozen=True)
 class DelayedScalingConfig:
@@ -132,6 +134,21 @@ class Float8LinearConfig:
     # configuration, this field may move to per-tensor configs.
     delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
 
+    # If this option is enabled, fp8_weight will always be re-computed in backward.
+    # It's recommended to enable this flag when using FSDP.
+    # Otherwise, the entire fp8_weight, instead of the sharded weight, may be saved.
+    # If using an outer activation checkpointing context or SAC, you may disable this option
+    # and handle the recomputation of the fp8 weight in your customized AC context.
+    #
+    # Details:
+    # When using float8 training with FSDP, the original weight is sharded; fp8_weight (in forward) and fp8_weight_transpose (in backward) are used by the model.
+    # However, when partitioning the forward-backward graph, torch.compile may decide to
+    # save the fp8_weight_transpose for backward, which is an un-sharded weight and costs high memory utilization.
+    # The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
+    # For now, we use the checkpointing API to force the recomputation of the fp8 weight in backward.
+
+    force_recompute_fp8_weight_in_bwd: bool = False
+
     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
diff --git a/torchao/float8/float8_linear.py b/torchao/float8/float8_linear.py
index cb0ff7afb..dd9255625 100644
--- a/torchao/float8/float8_linear.py
+++ b/torchao/float8/float8_linear.py
@@ -9,11 +9,14 @@
 import dataclasses
 import enum
+import logging
 from typing import Optional
 
 import torch
 
+import torch.utils.checkpoint as checkpoint
+
 from torchao.float8.config import Float8LinearConfig, ScalingType
 
 from torchao.float8.float8_scaling_utils import (
@@ -29,11 +32,17 @@ from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
 )
 
-from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
+from torchao.float8.float8_utils import (
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+    tensor_to_scale,
+)
 
 from torchao.float8.fsdp_utils import (
     WeightWithDelayedFloat8CastTensor,
@@ -41,6 +50,8 @@
     WeightWithStaticFloat8CastTensor,
 )
 
+logger = logging.getLogger(__name__)
+
 
 # this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
 @torch._dynamo.allow_in_graph
@@ -180,6 +191,15 @@ def __init__(self, *args, **kwargs):
         # would be initialized in every iteration.
         self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
 
+        # See the comments in config.py for more details on this option.
+        if (
+            self.config.enable_pre_and_post_forward
+            and not self.config.force_recompute_fp8_weight_in_bwd
+        ):
+            logger.warning(
+                "When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
+            )
+
     def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
@@ -226,17 +246,17 @@ def create_buffers(self):
 
         if self.config.cast_config_input.static_scale is not None:
             self.register_always_float32_buffer(
-                "fp8_static_scale_input", 
+                "fp8_static_scale_input",
                 self.config.cast_config_input.static_scale.to(device),
             )
         if self.config.cast_config_weight.static_scale is not None:
             self.register_always_float32_buffer(
-                "fp8_static_scale_weight", 
+                "fp8_static_scale_weight",
                 self.config.cast_config_weight.static_scale.to(device),
             )
         if self.config.cast_config_grad_output.static_scale is not None:
             self.register_always_float32_buffer(
-                "fp8_static_scale_grad_output", 
+                "fp8_static_scale_grad_output",
                 self.config.cast_config_grad_output.static_scale.to(device),
             )
 
@@ -296,56 +316,48 @@ def cast_input_to_float8(
             input_fp8 = hp_tensor_to_float8_static(
                 input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
             )
-        
+
         return input_fp8
 
-    def cast_weight_to_float8(
-        self, weight: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+    def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
+        if isinstance(weight, Float8Tensor):
+            return None
         if self.scaling_type_weight is ScalingType.DELAYED:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-                _maybe_initialize_amaxes_scales_for_float8_cast(
-                    weight,
-                    self.fp8_amax_weight,
-                    self.fp8_amax_history_weight,
-                    self.fp8_scale_weight,
-                    scale_fn_name,
-                    e4m3_dtype,
-                    is_amax_initialized,
-                    reduce_amax=False,
-                )
-
-                weight_fp8 = hp_tensor_to_float8_delayed(
-                    weight,
-                    self.fp8_scale_weight,
-                    e4m3_dtype,
-                    self.fp8_amax_weight,
-                    linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                weight,
+                self.fp8_amax_weight,
+                self.fp8_amax_history_weight,
+                self.fp8_scale_weight,
+                scale_fn_name,
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.fp8_amax_weight.fill_(tensor_to_amax(weight))
+            return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                weight_fp8 = hp_tensor_to_float8_dynamic(
-                    self.weight,
-                    e4m3_dtype,
-                    self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            return tensor_to_scale(weight, e4m3_dtype)
         else:
             assert self.scaling_type_weight is ScalingType.STATIC
-            weight_fp8 = hp_tensor_to_float8_static(
-                self.weight,
-                self.fp8_static_scale_weight,
-                e4m3_dtype,
-                self.linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-            )
-        return weight_fp8
+            return self.fp8_static_scale_weight
+
+    def cast_weight_to_float8_t(
+        self,
+        weight: torch.Tensor,
+        is_amax_initialized: bool,
+        weight_scale: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(weight, Float8Tensor):
+            return weight.t()
+        weight_fp8 = hp_tensor_and_scale_to_float8(
+            weight,
+            weight_scale,
+            e4m3_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
+        return weight_fp8.t()
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is ScalingType.DELAYED:
@@ -364,8 +376,8 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         else:
            assert self.scaling_type_grad_output is ScalingType.STATIC
             output = NoopFwToFloat8E5M2BwStatic.apply(
-                output, 
-                self.fp8_static_scale_grad_output, 
+                output,
+                self.fp8_static_scale_grad_output,
                 self.linear_mm_config,
             )
         return output
@@ -396,9 +408,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         self.float8_pre_forward(input)
 
         input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
-        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
-        output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
+        # If force_recompute_fp8_weight_in_bwd is enabled, we only recompute the fp8 weight
+        # in backward; weight_scale is computed outside the checkpoint and saved.
+        weight_scale = self.get_weight_scale(self.weight)
+
+        if self.config.force_recompute_fp8_weight_in_bwd:
+            weight_fp8_t = checkpoint.checkpoint(
+                self.cast_weight_to_float8_t,
+                self.weight,
+                self.is_amax_initialized,
+                weight_scale,
+            )
+        else:
+            weight_fp8_t = self.cast_weight_to_float8_t(
+                self.weight, self.is_amax_initialized, weight_scale
+            )
+
+        output = manual_float8_matmul.apply(input_fp8, weight_fp8_t)
 
         # Cast grad_output to float8_e5m2 during backward
         output = self.cast_output_to_float8_in_bw(output)
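
For context on the pattern used in `forward()` above, here is a minimal, standalone sketch (not torchao code) of how `torch.utils.checkpoint` changes what is saved for backward: the checkpointed region saves only its inputs (the high-precision weight) and replays the cast during backward, so the cast output is not kept alive between forward and backward. The `fake_cast_weight_t` helper and its bfloat16 round-trip are stand-ins for the real float8 cast, and the example assumes a PyTorch version that supports `use_reentrant=False`.

```python
import torch
import torch.utils.checkpoint as checkpoint


def fake_cast_weight_t(weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the float8 cast: produces the intermediate tensor that we
    # do not want saved for backward (a lower-precision round-trip + transpose).
    return weight.to(torch.bfloat16).to(weight.dtype).t()


def linear_forward(
    x: torch.Tensor, weight: torch.Tensor, force_recompute: bool
) -> torch.Tensor:
    if force_recompute:
        # checkpoint saves only the inputs of the wrapped function (here: `weight`)
        # and re-runs fake_cast_weight_t during the backward pass, so the cast
        # output is recomputed instead of being stored between forward and backward.
        w_t = checkpoint.checkpoint(fake_cast_weight_t, weight, use_reentrant=False)
    else:
        w_t = fake_cast_weight_t(weight)
    return x @ w_t


x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)
out = linear_forward(x, w, force_recompute=True)
out.sum().backward()
```

Under `torch.compile`, this is the mechanism the commit message refers to: the partitioner respects the checkpoint region and does not save the cast output (the fp8 weight or its transpose) across the forward-backward boundary.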