Use checkpoint to enforce the recomputation of fp8 weight (#936)
Summary:
Pull Request resolved: #936

The issue:
When using float8 training with FSDP, we have these tensors in the forward/backward graph:
- Without fp8-all-gather:
  original_weight (all-gather output, sharded) -> fp8_weight -> fp8_weight_transpose (needed in backward)
- With fp8-all-gather:
  original_weight (sharded) -> fp8_weight (all-gather output, sharded) -> fp8_weight_transpose (needed in backward)

`torch.compile` decides how to partition the graph and which tensors to save for backward. In both cases, with and without fp8-all-gather, it decides to save "fp8_weight_transpose" for backward. This is fine in the single-GPU case, since computing both fp8_weight and fp8_weight_transpose in the forward can be fused into one kernel. However, when we use FSDP to shard the weights, the weight itself is sharded but the "fp8_weight_transpose" tensors are not. Saving them for backward incurs high memory usage.

----
To fix it, we have a few options:
- In the user code, enforce which tensors to save for backward:
  - `save_for_backward` in a custom autograd.Function is one way to specify which tensors to save. However, torch.compile ignores what is manually saved for backward in a custom autograd.Function and just runs the partitioner.
  - **[This PR]** Use `torch.utils.checkpoint`, which is the API that compile does promise to respect today. It instructs compile to only save the checkpointed region's inputs for backward (the weight and activation), and not the intermediate values from the float8 cast (see the sketch after this list).
- Rely on torch.compile to find the best partition that optimizes both computation and memory. This is a longer-term solution that would have to land in compile itself.
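
To illustrate the mechanism this PR relies on, here is a minimal standalone sketch (not code from this commit): wrapping a cast function in `torch.utils.checkpoint.checkpoint` means only the function's inputs are saved for backward, and its intermediates are recomputed during the backward pass. The `fake_cast` helper and the `use_reentrant=False` argument are assumptions for the sketch only.

```python
import torch
import torch.utils.checkpoint as checkpoint

def fake_cast(weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the float8 cast: intermediates created here are not kept
    # when the call is checkpointed; they are rebuilt during backward.
    scaled = weight * 2.0
    return scaled.t()

weight = torch.randn(16, 16, requires_grad=True)

# Direct call: the framework is free to save intermediates for backward.
out_direct = fake_cast(weight)

# Checkpointed call: only the input (`weight`) is saved for backward;
# fake_cast is re-run in backward to recompute its intermediates.
out_ckpt = checkpoint.checkpoint(fake_cast, weight, use_reentrant=False)
out_ckpt.sum().backward()
```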

Differential Revision: D63345959
y-sq authored and facebook-github-bot committed Sep 25, 2024
1 parent fbe97a0 commit 0c14069
Showing 2 changed files with 77 additions and 52 deletions.
6 changes: 6 additions & 0 deletions torchao/float8/config.py
@@ -132,6 +132,12 @@ class Float8LinearConfig:
# configuration, this field may move to per-tensor configs.
delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()

# If True, fp8_weight will always be re-computed in backward.
# If False, fp8_weight from forward may be saved for backward.
# It's recommended to enable this flag when using FSDP.
# Otherwise, the entire fp8_weight, instead of the sharded weight, may be saved.
force_recompute_fp8_weight_in_bwd: bool = False


# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
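For context, a hypothetical usage sketch of the new flag (not part of this commit); it assumes the `torchao.float8` conversion helper `convert_to_float8_training` and its `config` keyword:

```python
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# Recommended with FSDP: recompute the fp8 weight in backward so the
# unsharded fp8 weight transpose is not kept alive between forward and backward.
config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)
convert_to_float8_training(model, config=config)
```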
123 changes: 71 additions & 52 deletions torchao/float8/float8_linear.py
@@ -14,6 +14,8 @@

import torch

import torch.utils.checkpoint as checkpoint

from torchao.float8.config import Float8LinearConfig, ScalingType

from torchao.float8.float8_scaling_utils import (
@@ -29,11 +31,17 @@
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
)

from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)

from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
@@ -180,6 +188,10 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

self.force_recompute_fp8_weight_in_bwd = (
self.config.force_recompute_fp8_weight_in_bwd
)

def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
@@ -226,17 +238,17 @@ def create_buffers(self):

if self.config.cast_config_input.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_input",
self.config.cast_config_input.static_scale.to(device),
)
if self.config.cast_config_weight.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_weight",
self.config.cast_config_weight.static_scale.to(device),
)
if self.config.cast_config_grad_output.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_grad_output",
self.config.cast_config_grad_output.static_scale.to(device),
)

@@ -296,56 +308,48 @@ def cast_input_to_float8(
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
)

return input_fp8

def cast_weight_to_float8(
self, weight: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
if isinstance(weight, Float8Tensor):
return None
if self.scaling_type_weight is ScalingType.DELAYED:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

weight_fp8 = hp_tensor_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
self.fp8_amax_weight,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
weight_fp8 = hp_tensor_to_float8_dynamic(
self.weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return tensor_to_scale(weight, e4m3_dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
weight_fp8 = hp_tensor_to_float8_static(
self.weight,
self.fp8_static_scale_weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8
return self.fp8_static_scale_weight

def cast_weight_to_float8_t(
self,
weight: torch.Tensor,
is_amax_initialized: bool,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(weight, Float8Tensor):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
@@ -364,8 +368,8 @@ def cast_output_to_float8_in_bw(output: torch.Tensor) -> torch.Tensor:
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
)
return output
@@ -396,9 +400,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)

input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
# If force_recompute_fp8_weight_in_bwd, we only recompute the fp8 weight in
# backward; weight_scale is computed here, outside the checkpoint, and saved.
weight_scale = self.get_weight_scale(self.weight)

if self.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
self.cast_weight_to_float8_t,
self.weight,
self.is_amax_initialized,
weight_scale,
)
else:
weight_fp8_t = self.cast_weight_to_float8_t(
self.weight, self.is_amax_initialized, weight_scale
)

output = manual_float8_matmul.apply(input_fp8, weight_fp8_t)

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)
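
To summarize the new forward-path structure, here is a minimal standalone sketch (the stand-in bodies of `get_weight_scale` and `cast_weight_to_float8_t`, and the `use_reentrant=False` argument, are assumptions, not the code above): the scale is computed outside the checkpointed region and passed in, so only the elementwise cast and transpose are recomputed in backward.

```python
import torch
import torch.utils.checkpoint as checkpoint

def get_weight_scale(weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for the scale computation: a small reduction whose result
    # we want to keep for backward, not recompute.
    return weight.abs().amax().clamp(min=1e-12)

def cast_weight_to_float8_t(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fp8 cast + transpose: cheap elementwise work that is
    # safe to recompute in backward.
    return (weight / scale).t()

def weight_for_matmul(weight: torch.Tensor, force_recompute: bool) -> torch.Tensor:
    scale = get_weight_scale(weight)  # computed once, saved for backward
    if force_recompute:
        # Only `weight` and `scale` are saved; the cast output is rebuilt in backward.
        return checkpoint.checkpoint(
            cast_weight_to_float8_t, weight, scale, use_reentrant=False
        )
    return cast_weight_to_float8_t(weight, scale)

w = torch.randn(32, 32, requires_grad=True)
weight_for_matmul(w, force_recompute=True).sum().backward()
```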
