Use checkpoint to enforce the recomputation of fp8 weight #936

Merged · 1 commit · Oct 2, 2024
2 changes: 2 additions & 0 deletions torchao/float8/README.md
@@ -116,6 +116,8 @@ We compose with the `DTensor` based [distributed APIs](https://pytorch.org/docs/
such as FSDP, TP and SP. Please see the [torchtitan](https://github.com/pytorch/torchtitan) repository for e2e examples
on using `torchao.float8` in a distributed setting.

:warning: <em>When using FSDP, it's recommended to enable `config.force_recompute_fp8_weight_in_bwd`, which prevents the un-sharded fp8 weights from being saved for backward. If you are using customized activation checkpointing, you may ignore this config and handle the recomputation of fp8 weights in the customized AC code.</em>
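For illustration only (not part of this diff), here is a minimal sketch of enabling the flag when converting a model for float8 training; the `convert_to_float8_training` entry point, the toy model, and the CUDA placement are assumptions for the example:

```python
import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Toy model for illustration; any stack of nn.Linear layers works the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.Linear(4096, 4096, bias=False),
).to("cuda")

# Recommended with FSDP: recompute the fp8 weight in backward instead of
# saving the un-sharded fp8 weight produced in forward.
config = Float8LinearConfig(force_recompute_fp8_weight_in_bwd=True)
convert_to_float8_training(model, config=config)

# In a distributed run you would wrap `model` with FSDP here, then compile:
model = torch.compile(model)
```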

# Performance

A common question about float8 training is "when is float8 linear faster vs bfloat16?". Given the M, K, N of the forward pass through your linear, you can reference the table below for a microbenchmark-based speedup estimate on NVIDIA H100:
21 changes: 19 additions & 2 deletions torchao/float8/config.py
@@ -37,8 +37,10 @@ class CastConfig:

def __post_init__(self):
if self.scaling_type is ScalingType.STATIC:
assert self.static_scale is not None, \
"static_scale must be specified for static scaling"
assert (
self.static_scale is not None
), "static_scale must be specified for static scaling"


@dataclass(frozen=True)
class DelayedScalingConfig:
@@ -132,6 +134,21 @@ class Float8LinearConfig:
# configuration, this field may move to per-tensor configs.
delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()

# If the option is enabled, fp8_weight will always be re-computed in backward.
# It's recommended to enable this flag when using FSDP.
# Otherwise, the entire fp8_weight, instead of the sharded weight, may be saved.
# If using an outer activation checkpointing context or SAC, you may disable this option
# and handle the recomputation of the fp8 weight in your customized AC context.
#
# Details:
# When using float8 training with FSDP, the original weight is sharded; fp8_weight (in forward) and fp8_weight_transpose (in backward) are used by the model.
# However, when partitioning the forward/backward graph, torch.compile may decide to
# save the fp8_weight_transpose for backward, which is an un-sharded weight and incurs high memory utilization.
# The longer-term solution is to let compile decide how to partition the graph with optimal computation and memory savings.
# For now, we use the checkpointing API to force the recomputation of the fp8 weight in backward.

force_recompute_fp8_weight_in_bwd: bool = False


# If True, use 'fnuz' float8 types for calculations.
# Currently, ROCm only supports fnuz variants.
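The comment block in this diff relies on the checkpoint API to force recomputation. As a rough, self-contained sketch of that underlying mechanism (the `expensive_cast` helper and the toy shapes are invented for illustration and are not torchao code):

```python
import torch
import torch.utils.checkpoint as checkpoint

def expensive_cast(w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fp8 cast: recomputable from `w` and cheap to redo
    # relative to the memory cost of saving its output for backward.
    return (w * 2).t()

w = torch.randn(1024, 1024, requires_grad=True)
x = torch.randn(8, 1024)

# Without checkpoint: the matmul saves the output of expensive_cast for backward.
y = x @ expensive_cast(w)

# With checkpoint: only the inputs are saved; expensive_cast is re-run in backward,
# which is the behavior force_recompute_fp8_weight_in_bwd opts into for the fp8 cast.
y_ckpt = x @ checkpoint.checkpoint(expensive_cast, w, use_reentrant=False)
y_ckpt.sum().backward()
```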
131 changes: 79 additions & 52 deletions torchao/float8/float8_linear.py
@@ -9,11 +9,14 @@

import dataclasses
import enum
import logging

from typing import Optional

import torch

import torch.utils.checkpoint as checkpoint

from torchao.float8.config import Float8LinearConfig, ScalingType

from torchao.float8.float8_scaling_utils import (
@@ -29,18 +32,26 @@
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
hp_tensor_and_scale_to_float8,
LinearMMConfig,
ScaledMMConfig,
)

from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
from torchao.float8.float8_utils import (
e4m3_dtype,
e5m2_dtype,
tensor_to_amax,
tensor_to_scale,
)

from torchao.float8.fsdp_utils import (
WeightWithDelayedFloat8CastTensor,
WeightWithDynamicFloat8CastTensor,
WeightWithStaticFloat8CastTensor,
)

logger = logging.getLogger(__name__)


# this code was resurrected from https://github.com/pytorch-labs/torchao.float8/pull/128/files
@torch._dynamo.allow_in_graph
@@ -180,6 +191,15 @@ def __init__(self, *args, **kwargs):
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.
if (
self.config.enable_pre_and_post_forward
and not self.config.force_recompute_fp8_weight_in_bwd
):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)

def create_buffers(self):
# Default values for history buffers, see above TODO
history_len = self.config.delayed_scaling_config.history_len
@@ -226,17 +246,17 @@ def create_buffers(self):

if self.config.cast_config_input.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_input",
"fp8_static_scale_input",
self.config.cast_config_input.static_scale.to(device),
)
if self.config.cast_config_weight.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_weight",
"fp8_static_scale_weight",
self.config.cast_config_weight.static_scale.to(device),
)
if self.config.cast_config_grad_output.static_scale is not None:
self.register_always_float32_buffer(
"fp8_static_scale_grad_output",
"fp8_static_scale_grad_output",
self.config.cast_config_grad_output.static_scale.to(device),
)

@@ -296,56 +316,48 @@ def cast_input_to_float8(
input_fp8 = hp_tensor_to_float8_static(
input, self.fp8_static_scale_input, e4m3_dtype, self.linear_mm_config
)

return input_fp8

def cast_weight_to_float8(
self, weight: torch.Tensor, is_amax_initialized: bool
) -> torch.Tensor:
def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
if isinstance(weight, Float8Tensor):
return None
if self.scaling_type_weight is ScalingType.DELAYED:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
is_amax_initialized,
reduce_amax=False,
)

weight_fp8 = hp_tensor_to_float8_delayed(
weight,
self.fp8_scale_weight,
e4m3_dtype,
self.fp8_amax_weight,
linear_mm_config=self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
_maybe_initialize_amaxes_scales_for_float8_cast(
weight,
self.fp8_amax_weight,
self.fp8_amax_history_weight,
self.fp8_scale_weight,
scale_fn_name,
e4m3_dtype,
self.is_amax_initialized,
reduce_amax=True,
)
self.fp8_amax_weight.fill_(tensor_to_amax(weight))
return self.fp8_scale_weight
elif self.scaling_type_weight is ScalingType.DYNAMIC:
if isinstance(self.weight, Float8Tensor): # cast by FSDP
weight_fp8 = self.weight
else:
weight_fp8 = hp_tensor_to_float8_dynamic(
self.weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return tensor_to_scale(weight, e4m3_dtype)
else:
assert self.scaling_type_weight is ScalingType.STATIC
weight_fp8 = hp_tensor_to_float8_static(
self.weight,
self.fp8_static_scale_weight,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8
return self.fp8_static_scale_weight

def cast_weight_to_float8_t(
self,
weight: torch.Tensor,
is_amax_initialized: bool,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(weight, Float8Tensor):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
weight_scale,
e4m3_dtype,
self.linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
return weight_fp8.t()

def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
if self.scaling_type_grad_output is ScalingType.DELAYED:
@@ -364,8 +376,8 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
else:
assert self.scaling_type_grad_output is ScalingType.STATIC
output = NoopFwToFloat8E5M2BwStatic.apply(
output,
self.fp8_static_scale_grad_output,
output,
self.fp8_static_scale_grad_output,
self.linear_mm_config,
)
return output
@@ -396,9 +408,24 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.float8_pre_forward(input)

input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)

output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
# If force_recompute_fp8_weight_in_bwd is enabled, we only recompute the fp8 weight;
# weight_scale should be saved.
weight_scale = self.get_weight_scale(self.weight)

if self.config.force_recompute_fp8_weight_in_bwd:
weight_fp8_t = checkpoint.checkpoint(
self.cast_weight_to_float8_t,
self.weight,
self.is_amax_initialized,
weight_scale,
)
else:
weight_fp8_t = self.cast_weight_to_float8_t(
self.weight, self.is_amax_initialized, weight_scale
)

Review comment on the `checkpoint.checkpoint(...)` call:

Reviewer: Will this force_recompute_fp8_weight_in_bwd flag help ensure that the checkpoint is only done in compile? In eager, if someone has an outer AC context, the semantics are to do recursive checkpointing, i.e., a single tensor is recomputed multiple times. Nesting AC within SAC also hasn't been tested. In the case of compile, the behavior is not really defined, but it might behave as you want, e.g. nested SAC policies where the inner policy overrides the outer policy; I have not tested it. We should also add a note here that even though this might work today, in the future we'd want to replace it with some nicer API.

Contributor (author): Thanks for the comment. The most common use case of float8 is with compile, because eager performance is very poor. Maybe I can add a warning that if "force_recompute_fp8_weight_in_bwd" is enabled, it's recommended to rely on torch.compile for activation checkpointing; otherwise, if users want customized AC, they should be sure to handle the checkpointing of weights themselves. I'll also add the note that even though this might work today, in the future we'd want to replace it with some nicer API. Shall we also open an issue to track the longer-term solution?

output = manual_float8_matmul.apply(input_fp8, weight_fp8_t)

# Cast grad_output to float8_e5m2 during backward
output = self.cast_output_to_float8_in_bw(output)
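Putting the pieces of this file together: the scale is computed once in forward and saved, and only the cast (plus transpose) is re-run in backward when the flag is set. Below is a condensed sketch of that pattern; the `TinyFp8LikeLinear` class and its helpers are invented for illustration and stand in for, rather than reproduce, the real `Float8Linear` logic.

```python
import torch
import torch.utils.checkpoint as checkpoint

class TinyFp8LikeLinear(torch.nn.Module):
    def __init__(self, dim: int, force_recompute: bool = True):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(dim, dim))
        self.force_recompute = force_recompute

    def get_scale(self, w: torch.Tensor) -> torch.Tensor:
        # Cheap to compute and tiny to store, so it is saved rather than recomputed.
        return w.abs().amax().clamp(min=1e-12)

    def cast_t(self, w: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Stand-in for the fp8 cast + transpose; this is the part worth recomputing.
        return (w / scale).t()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.get_scale(self.weight)
        if self.force_recompute:
            # Re-run cast_t in backward instead of saving its full-size output.
            w_t = checkpoint.checkpoint(
                self.cast_t, self.weight, scale, use_reentrant=False
            )
        else:
            w_t = self.cast_t(self.weight, scale)
        return x @ w_t

layer = TinyFp8LikeLinear(64)
out = layer(torch.randn(8, 64))
out.sum().backward()
```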