
Commit 135e81b

y-sq authored and facebook-github-bot committed
Use checkpoint to enforce the recomputation of fp8 weight (#936)
Summary:
Pull Request resolved: #936

The issue: When using float8 training with FSDP, we have these tensors in the forward/backward graph:
- Without fp8-all-gather: original_weight (all-gather output, sharded), fp8_weight, fp8_weight_transpose (needed in backward)
- With fp8-all-gather: original_weight (sharded), fp8_weight (all-gather output, sharded), fp8_weight_transpose (needed in backward)

`torch.compile` decides how to partition the graph and which tensors to save for backward. Both with and without fp8-all-gather, it decides to save "fp8_weight_transpose" for backward. This is fine in the single-GPU case, where computing both fp8_weight and fp8_weight_transpose in the forward can be fused into one kernel. However, when FSDP shards the weights, the weight itself is sharded but the "fp8_weight_transpose" tensors are not, so saving them for backward results in high memory utilization.

To fix this, there are several options:
- In the user code, enforce which tensors to save for backward. `save_for_backward` in a custom autograd.Function is one way to specify which tensors to save; however, torch.compile ignores what is manually saved for backward in a custom autograd.Function and just runs the partitioner.
- **[This PR]** Use "torch.utils.checkpoint", which is the API that compile does promise to respect today. It instructs compile to save only the inputs for backward (the weight and activation), and not the intermediate values from the float8 cast.
- Rely on torch.compile to find the best partition that optimizes both computation and memory. This may be a longer-term solution, to be fixed in compile itself.

Differential Revision: D63345959
1 parent 2dea315 commit 135e81b
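To illustrate the approach described in the summary, here is a minimal standalone sketch (not the torchao code) of wrapping a cast-plus-matmul in torch.utils.checkpoint so that only the function inputs are saved for backward and the cast result is recomputed; fake_fp8_cast and the tensor shapes below are made-up placeholders:

import torch
import torch.utils.checkpoint as checkpoint

def fake_fp8_cast(w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the float8 cast; in torchao this would produce a Float8Tensor.
    return w.to(torch.bfloat16).to(torch.float32)

def cast_and_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Without checkpointing, the partitioner may decide to save this intermediate
    # (or its transpose) for backward; under FSDP it is unsharded, so that is costly.
    w_cast = fake_fp8_cast(w)
    return x @ w_cast.t()

x = torch.randn(16, 32, requires_grad=True)
w = torch.randn(64, 32, requires_grad=True)

# checkpoint saves only the inputs (x, w) and re-runs cast_and_matmul in backward,
# so w_cast is recomputed instead of being kept alive between forward and backward.
out = checkpoint.checkpoint(cast_and_matmul, x, w, use_reentrant=False)
out.sum().backward()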

File tree

2 files changed: +67 -47 lines


torchao/float8/config.py

Lines changed: 6 additions & 0 deletions

@@ -132,6 +132,12 @@ class Float8LinearConfig:
     # configuration, this field may move to per-tensor configs.
     delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
 
+    # If True, fp8_weight will always be re-computed in backward.
+    # If False, fp8_weight from forward may be saved for backward.
+    # It's recommended to enable this flag when using FSDP.
+    # Otherwise, the entire fp8_weight, instead of the sharded weight, may be saved.
+    force_recompute_fp8_weight_in_bwd: bool = False
+
     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
torchao/float8/float8_linear.py

Lines changed: 61 additions & 47 deletions

@@ -14,6 +14,8 @@
 
 import torch
 
+import torch.utils.checkpoint as checkpoint
+
 from torchao.float8.config import Float8LinearConfig, ScalingType
 
 from torchao.float8.float8_scaling_utils import (
@@ -29,11 +31,17 @@
 from torchao.float8.float8_tensor import (
     Float8Tensor,
     GemmInputRole,
+    hp_tensor_and_scale_to_float8,
     LinearMMConfig,
     ScaledMMConfig,
 )
 
-from torchao.float8.float8_utils import e4m3_dtype, e5m2_dtype, tensor_to_amax
+from torchao.float8.float8_utils import (
+    e4m3_dtype,
+    e5m2_dtype,
+    tensor_to_amax,
+    tensor_to_scale,
+)
 
 from torchao.float8.fsdp_utils import (
     WeightWithDelayedFloat8CastTensor,
@@ -53,8 +61,9 @@ class manual_float8_matmul(torch.autograd.Function):
     def forward(
         ctx,
         input_fp8,
-        weight_fp8_t,
+        weight_fp8,
     ):
+        weight_fp8_t = weight_fp8.t()
         ctx.save_for_backward(input_fp8, weight_fp8_t)
         # the reshapes are needed in order to make the shapes compatible with
         # torch.mm
@@ -95,7 +104,7 @@ def backward(ctx, grad_output_fp8):
             input_fp8_reshaped,
         )
 
-        return grad_input, grad_weight.t()
+        return grad_input, grad_weight
 
 
 class Float8Linear(torch.nn.Linear):
@@ -180,6 +189,8 @@ def __init__(self, *args, **kwargs):
         # would be initialized in every iteration.
         self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward
 
+        self.force_recompute_fp8_weight_in_bwd = self.config.force_recompute_fp8_weight_in_bwd
+
     def create_buffers(self):
         # Default values for history buffers, see above TODO
         history_len = self.config.delayed_scaling_config.history_len
@@ -299,52 +310,38 @@ def cast_input_to_float8(
 
         return input_fp8
 
-    def cast_weight_to_float8(
-        self, weight: torch.Tensor, is_amax_initialized: bool
-    ) -> torch.Tensor:
+
+    def get_weight_scale(self, weight: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_weight is ScalingType.DELAYED:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
-                _maybe_initialize_amaxes_scales_for_float8_cast(
-                    weight,
-                    self.fp8_amax_weight,
-                    self.fp8_amax_history_weight,
-                    self.fp8_scale_weight,
-                    scale_fn_name,
-                    e4m3_dtype,
-                    is_amax_initialized,
-                    reduce_amax=False,
-                )
-
-                weight_fp8 = hp_tensor_to_float8_delayed(
-                    weight,
-                    self.fp8_scale_weight,
-                    e4m3_dtype,
-                    self.fp8_amax_weight,
-                    linear_mm_config=self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
+            _maybe_initialize_amaxes_scales_for_float8_cast(
+                weight,
+                self.fp8_amax_weight,
+                self.fp8_amax_history_weight,
+                self.fp8_scale_weight,
+                scale_fn_name,
+                e4m3_dtype,
+                self.is_amax_initialized,
+                reduce_amax=True,
+            )
+            self.fp8_amax_weight.fill_(tensor_to_amax(weight))
+            return self.fp8_scale_weight
         elif self.scaling_type_weight is ScalingType.DYNAMIC:
-            if isinstance(self.weight, Float8Tensor):  # cast by FSDP
-                weight_fp8 = self.weight
-            else:
-                weight_fp8 = hp_tensor_to_float8_dynamic(
-                    self.weight,
-                    e4m3_dtype,
-                    self.linear_mm_config,
-                    gemm_input_role=GemmInputRole.WEIGHT,
-                )
+            return tensor_to_scale(weight, e4m3_dtype)
         else:
            assert self.scaling_type_weight is ScalingType.STATIC
-            weight_fp8 = hp_tensor_to_float8_static(
-                self.weight,
-                self.fp8_static_scale_weight,
-                e4m3_dtype,
-                self.linear_mm_config,
-                gemm_input_role=GemmInputRole.WEIGHT,
-            )
+            return self.fp8_static_scale_weight
+
+    def cast_weight_to_float8(
+        self, weight: torch.Tensor, weight_scale: torch.Tensor, is_amax_initialized: bool,
+    ) -> torch.Tensor:
+        weight_fp8 = hp_tensor_and_scale_to_float8(
+            weight,
+            weight_scale,
+            e4m3_dtype,
+            self.linear_mm_config,
+            gemm_input_role=GemmInputRole.WEIGHT,
+        )
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
@@ -390,15 +387,32 @@ def float8_post_forward(self):
         # amaxes and scales
         self.is_amax_initialized = True
         self.amax_and_scale_synced = False
+
+    def cast_weight_and_matmul(self, input_fp8, weight_scale):
+        weight_fp8 = self.cast_weight_to_float8(self.weight, weight_scale, self.is_amax_initialized)
+        output = manual_float8_matmul.apply(input_fp8, weight_fp8)
+        return output
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.has_any_delayed_scaling:
             self.float8_pre_forward(input)
 
         input_fp8 = self.cast_input_to_float8(input, self.is_amax_initialized)
-        weight_fp8 = self.cast_weight_to_float8(self.weight, self.is_amax_initialized)
 
-        output = manual_float8_matmul.apply(input_fp8, weight_fp8.t())
+        if isinstance(self.weight, Float8Tensor):  # cast by FSDP
+            weight_fp8 = self.weight
+            if self.force_recompute_fp8_weight_in_bwd:
+                output = checkpoint.checkpoint(manual_float8_matmul.apply, input_fp8, weight_fp8)
+            else:
+                output = manual_float8_matmul.apply(input_fp8, weight_fp8)
+        else:
+            weight_scale = self.get_weight_scale(self.weight)
+            # We save weight_scale and only recompute weight_fp8
+            if self.force_recompute_fp8_weight_in_bwd:
+                output = checkpoint.checkpoint(self.cast_weight_and_matmul, input_fp8, weight_scale)
+            else:
+                output = self.cast_weight_and_matmul(input_fp8, weight_scale)
+
 
         # Cast grad_output to float8_e5m2 during backward
         output = self.cast_output_to_float8_in_bw(output)
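One design point worth noting in the diff above: get_weight_scale runs outside the checkpointed region and only cast_weight_and_matmul is re-executed, because a checkpointed function is run again during backward, so any side effects inside it (such as updating amax/scale buffers) would happen twice. A small generic illustration of that behavior follows; the counter and toy function are illustrative, not torchao code, and the double execution assumes the non-reentrant checkpoint path:

import torch
import torch.utils.checkpoint as checkpoint

calls = {"n": 0}

def fn(x, w):
    # Any side effect here (e.g. filling amax buffers) runs once in forward and
    # again when checkpoint recomputes the function during backward.
    calls["n"] += 1
    return torch.relu(x @ w.t())

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)

out = checkpoint.checkpoint(fn, x, w, use_reentrant=False)
out.sum().backward()
print(calls["n"])  # expected: 2 (forward + recomputation in backward)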
