Commit 84fb834

for now, delete the float8-all-gather-only functionality from float8 training (#1451)

Summary:

In #1093 we added a config option, off by default, to use only float8 all-gather for training and do the matrix multiply in high precision. This seems generally useful for communication-bound workloads, but we can probably think of a cleaner way to add this functionality (such as a weight wrapper tensor subclass). The current implementation adds non-trivial complexity and doesn't jibe well with where we want to take this codebase.

Since no one is using this internally or externally yet and we haven't talked about it in the release notes, I think we should do a BC-breaking delete as a one-off. However, if people have concerns, let me know and we can talk about less aggressive options.

Test Plan:

```
./test/float8/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
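For context, the deleted mode kept the weight all-gather in float8 (the communication savings) but converted the weight back before the matmul (no compute savings). A minimal sketch of that compute path, based on the `cast_weight_to_original_t` / `forward_original_precision_matmul` code removed below; the standalone function name is hypothetical:

```
import torch
from torchao.float8.float8_tensor import Float8Tensor

def matmul_with_fp8_comm_only(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # After an fp8 all-gather, the weight arrives as a Float8Tensor; convert it
    # back so the gemm itself still runs in the original (high) precision.
    if isinstance(weight, Float8Tensor):
        weight_t = weight.to_original_precision().t()
    else:
        weight_t = weight.t()
    return torch.matmul(input, weight_t)
```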
1 parent e474839 commit 84fb834

File tree: 5 files changed, +5 additions, -265 deletions


test/float8/test_fsdp2/test_fsdp2_fp8_comm_only.py

Lines changed: 0 additions & 180 deletions
This file was deleted.

test/quantization/test_quant_api.py

Lines changed: 4 additions & 4 deletions

@@ -15,25 +15,25 @@
 import torch
 from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
 from torch.ao.quantization.quantizer.xnnpack_quantizer import (
-    get_symmetric_quantization_config,
     XNNPACKQuantizer,
+    get_symmetric_quantization_config,
 )
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase
 
 from torchao import quantize_
-from torchao._models.llama.model import prepare_inputs_for_model, Transformer
+from torchao._models.llama.model import Transformer, prepare_inputs_for_model
 from torchao._models.llama.tokenizer import get_tokenizer
 from torchao.dtypes import AffineQuantizedTensor
 from torchao.quantization import LinearActivationQuantizedTensor
 from torchao.quantization.quant_api import (
+    Quantizer,
+    TwoStepQuantizer,
     _replace_with_custom_fn_if_matches_filter,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
-    Quantizer,
-    TwoStepQuantizer,
 )
 from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.subclass import (

torchao/float8/config.py

Lines changed: 0 additions & 9 deletions

@@ -234,12 +234,6 @@ class Float8LinearConfig:
     # tests so that the warning does not spam the CI stdout.
     force_recompute_fp8_weight_in_bwd: bool = False
 
-    # If True, we only use fp8-all-gather to reduce the communication cost.
-    # The gemm computation is still done in the original precision.
-    # `cast_config_weight` is used to decide how to cast the weight to fp8,
-    # other casting configs will be ignored.
-    use_fp8_all_gather_only: bool = False
-
     def __post_init__(self):
         # Populate the additional cast overrides, if the user did not specify them
         # Note: this hacks around the frozen-ness of this dataclass
@@ -301,9 +295,6 @@ def __post_init__(self):
                 cc1.target_dtype == cc2.target_dtype
             ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
 
-        if self.use_fp8_all_gather_only:
-            assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"
-
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
             self.enable_fsdp_float8_all_gather
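For reference, before this commit the deleted mode would have been enabled roughly as below. The two config flags are exactly the ones shown in the removed field and assertion above; `convert_to_float8_training` is assumed to be the conversion entry point and is not part of this diff:

```
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# Sketch of the now-removed usage: both flags had to be set together,
# per the assertion deleted from __post_init__ above.
config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    use_fp8_all_gather_only=True,  # this field no longer exists after this commit
)
model = nn.Sequential(nn.Linear(1024, 1024, bias=False))
convert_to_float8_training(model, config=config)
```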

torchao/float8/float8_linear.py

Lines changed: 1 addition & 26 deletions

@@ -20,7 +20,6 @@
     hp_tensor_to_float8_dynamic,
 )
 from torchao.float8.float8_tensor import (
-    Float8Tensor,
     GemmInputRole,
     LinearMMConfig,
     ScaledMMConfig,
@@ -344,12 +343,6 @@ def cast_weight_to_float8_t(
         )
         return weight_fp8.t()
 
-    def cast_weight_to_original_t(self, weight: torch.Tensor):
-        if isinstance(weight, Float8Tensor):
-            return weight.to_original_precision().t()
-        else:
-            return weight.t()
-
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         assert self.scaling_type_grad_output is ScalingType.DYNAMIC
         output = NoopFwToFloat8BwDynamic.apply(
@@ -359,7 +352,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
         )
         return output
 
-    def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         has_any_axiswise_scaling = any(
             cc.scaling_granularity is ScalingGranularity.AXISWISE
             for cc in [
@@ -403,24 +396,6 @@ def forward_fp8_matmul(self, input: torch.Tensor) -> torch.Tensor:
             self.linear_mm_config,
             self.config,
         )
-        return output
-
-    def forward_original_precision_matmul(self, input: torch.Tensor) -> torch.Tensor:
-        if self.config.force_recompute_fp8_weight_in_bwd:
-            orig_weight_t = checkpoint.checkpoint(
-                self.cast_weight_to_original_t, self.weight
-            )
-        else:
-            orig_weight_t = self.cast_weight_to_original_t(self.weight)
-
-        output = torch.matmul(input, orig_weight_t)
-        return output
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.config.use_fp8_all_gather_only:
-            output = self.forward_original_precision_matmul(input)
-        else:
-            output = self.forward_fp8_matmul(input)
 
         if self.bias is not None:
             output = output + self.bias.to(output.dtype)
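The deleted `forward_original_precision_matmul` wrapped the weight cast in `torch.utils.checkpoint` when `force_recompute_fp8_weight_in_bwd` was set, so the cast result is recomputed in backward instead of being saved. A minimal standalone sketch of that recompute pattern; the toy cast function and shapes are illustration-only assumptions:

```
import torch
import torch.utils.checkpoint as checkpoint

def cast_weight_to_original_t(weight: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for the deleted helper: an (expensive) cast back to full
    # precision, returned transposed for the matmul.
    return weight.float().t()

weight = torch.randn(128, 64, dtype=torch.bfloat16, requires_grad=True)
x = torch.randn(32, 64)

# Recompute the cast during backward instead of saving its output.
w_t = checkpoint.checkpoint(cast_weight_to_original_t, weight, use_reentrant=False)
out = torch.matmul(x, w_t)
out.sum().backward()
```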

torchao/testing/float8/fsdp2_utils.py

Lines changed: 0 additions & 46 deletions

@@ -93,49 +93,3 @@ def check_parity_bf16_mp(
             losses[1],
             msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
         )
-
-
-def check_parity_fp8_comm_only(
-    test_cls,
-    ref_model: nn.Module,
-    ref_optim: torch.optim.Optimizer,
-    fsdp_model: nn.Module,
-    fsdp_optim: torch.optim.Optimizer,
-    local_inp: torch.Tensor,
-    config: Float8LinearConfig,
-    precompute: bool = False,
-    compile: bool = False,
-):
-    for iter_idx in range(10):
-        losses: List[torch.Tensor] = []
-        for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
-            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
-            losses.append(model(local_inp).sum())
-            losses[-1].backward()
-            if model is ref_model:
-                for name, param in model.named_parameters():
-                    dist.all_reduce(param.grad)
-                    param.grad.div_(dist.get_world_size())
-
-            if linear_requires_sync(config):
-                sync_float8_amax_and_scale_history(model)
-
-            optim.step()
-            if (
-                model is fsdp_model
-                and precompute
-                and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
-            ):
-                precompute_float8_dynamic_scale_for_fsdp(model)
-
-        if compile:
-            # When compile, the ref loss and fsdp loss are not exactly the same, only check the loss values are valid for now.
-            assert (
-                torch.isfinite(losses[0]).any() and torch.isfinite(losses[1]).any()
-            ), f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}"
-        else:
-            test_cls.assertEqual(
-                losses[0],
-                losses[1],
-                f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}",
-            )
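The deleted `check_parity_fp8_comm_only` manually averaged the reference model's gradients across ranks before comparing it against the FSDP model. A minimal sketch of that gradient-averaging pattern, assuming `torch.distributed` is already initialized (the helper name is hypothetical):

```
import torch.distributed as dist
import torch.nn as nn

def average_grads_like_ddp(ref_model: nn.Module) -> None:
    # Mirror what the deleted helper did for the non-sharded reference model:
    # sum gradients across ranks, then divide by world size so they match a
    # data-parallel average.
    world_size = dist.get_world_size()
    for param in ref_model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad)
            param.grad.div_(world_size)
```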
