[float8] improve eager numerics for dynamic scales and get on par with torch.compile (#904)

Merged Oct 1, 2024 (43 commits).

Commits:
6bf0f5c - [float8] improve eager numerics for dynamic scales (weifengpy, Sep 19, 2024)
553687f - leave torch.linalg.vector_norm for another PR (weifengpy, Sep 19, 2024)
19a592d - cuda (weifengpy, Sep 19, 2024)
218290e - remove _data and investigate (weifengpy, Sep 19, 2024)
24ec914 - remove _data comment (weifengpy, Sep 19, 2024)
c099486 - upcast to float32 is enough (weifengpy, Sep 21, 2024)
b93ffc8 - explain why float32 (weifengpy, Sep 21, 2024)
ebff416 - _data parity (weifengpy, Sep 21, 2024)
8978ab2 - handle sm8.9 (weifengpy, Sep 21, 2024)
f17dc12 - fix transformer unit test (weifengpy, Sep 22, 2024)
511c751 - print if error (weifengpy, Sep 26, 2024)
9becda1 - Add tutorial for trainable tensor subclass (#908) (andrewor14, Sep 20, 2024)
e4fdca9 - Introducing 1-bit quantization for Llama in torchchat (#910) (vaishnavi17, Sep 20, 2024)
0cd4d37 - Rename Floating point to fp8 (#909) (jainapurva, Sep 20, 2024)
014558d - [float8] fix typo in bitwise_identical unit test (#918) (weifengpy, Sep 23, 2024)
3267402 - Adding example for quantized tensor + tensor parallelism (#785) (jerryzh168, Sep 23, 2024)
1e07eff - rename cuda mode -> gpu mode (#925) (msaroufim, Sep 24, 2024)
ebdeed0 - Add workaround to recover the perf for quantized vit in torch.compile… (jerryzh168, Sep 24, 2024)
09ffa22 - clean up device checks in float8 unit test files (#923) (vkuzo, Sep 24, 2024)
0b8dd85 - [low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 25… (gau-nernst, Sep 24, 2024)
87faf04 - Float8 autoquant weight only (#866) (jainapurva, Sep 24, 2024)
3a9fdb0 - Fix failing FP6 benchmark (#931) (tobiasvanderwerff, Sep 25, 2024)
fc6c393 - Remove two if statements in fp8 padding (#935) (y-sq, Sep 25, 2024)
0043ace - [Distributed] Improve sharding example (#937) (kwen2501, Sep 25, 2024)
ab3435c - Add composable QAT quantizer (#938) (andrewor14, Sep 25, 2024)
a05a40f - resolve conflict with latest main (weifengpy, Sep 26, 2024)
334891b - Add torchchat quantizer (metascroy, Sep 25, 2024)
c706139 - Add compile tests to test suite (#906) (jerryzh168, Sep 26, 2024)
93554c0 - Fix up CMakeLists and reorganize some code locations (metascroy, Sep 26, 2024)
efd9bb9 - [float8] all-reduce amax on dp mesh instead of global pg (#933) (weifengpy, Sep 26, 2024)
85126cc - int8 dynamic quant + bsr support (#821) (jcaip, Sep 26, 2024)
a5a426e - fixing some issues with our support for 70/405B models (#941) (HDCharles, Sep 26, 2024)
e7270f1 - Update INT8 mixed-precision training test to be less flaky (#950) (gau-nernst, Sep 26, 2024)
352685c - Add executorch parallel (metascroy, Sep 26, 2024)
168cfe9 - Merge branch 'weifengpy-dynamic_scale_numerics' into dynamic_scale_nu… (weifengpy, Sep 26, 2024)
5900c3e - Merge branch 'main' into dynamic_scale_numerics (weifengpy, Sep 26, 2024)
37e1479 - test CI (weifengpy, Sep 26, 2024)
2efde49 - better comment on why upcasting (weifengpy, Sep 26, 2024)
8c04f4f - control seed (weifengpy, Sep 26, 2024)
04b229b - move unit test to test_compile (weifengpy, Sep 26, 2024)
8b7c2ef - fix typo (weifengpy, Sep 26, 2024)
9346afd - float64 upcasting after allreduce (weifengpy, Sep 27, 2024)
3d0da20 - use LinearMMConfig (weifengpy, Sep 30, 2024)
Changes from all commits:
test/float8/test_compile.py (69 additions, 2 deletions)

@@ -25,8 +25,15 @@
    get_float8_layers,
    sync_float8_amax_and_scale_history,
)
- from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed
- from torchao.float8.float8_tensor import LinearMMConfig
+ from torchao.float8.float8_scaling_utils import (
+     hp_tensor_to_float8_delayed,
+     hp_tensor_to_float8_dynamic,
+ )
+ from torchao.float8.float8_tensor import (
+     LinearMMConfig,
+     GemmInputRole,
+     ScaledMMConfig,
+ )
from torchao.float8.float8_utils import e4m3_dtype

from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -353,5 +360,65 @@ def test_sync_amax_func_cuda_graph_success():
assert "skipping cudagraphs due to mutaton on input" not in stderr[0]


@unittest.skipIf(
    not is_cuda_8_9,
    "CUDA not available",
)
@pytest.mark.parametrize(
    "dtype",
    [
        torch.float32,
        torch.bfloat16,
        torch.float16,
    ],
)
def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
    scaling_type_weight = ScalingType.DYNAMIC
    torch.manual_seed(42)
    hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
    hp_tensor2 = hp_tensor1.detach().clone()
    float8_config = Float8LinearConfig(

Contributor: this should be an object of type LinearMMConfig, I'm actually kind of surprised passing in Float8LinearConfig works :(

Contributor Author (weifengpy): good catch! I switched to LinearMMConfig. Float8LinearConfig was working because I did not call matmul, which requires access to self._linear_mm_config.

        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
    )
    linear_mm_config = LinearMMConfig(
        # output
        ScaledMMConfig(
            False,
            float8_config.gemm_config_output.use_fast_accum,
            False,
            float8_config.pad_inner_dim,
        ),
        # grad_input
        ScaledMMConfig(
            False,
            float8_config.gemm_config_grad_input.use_fast_accum,
            False,
            float8_config.pad_inner_dim,
        ),
        # grad_weight
        ScaledMMConfig(
            False,
            float8_config.gemm_config_grad_weight.use_fast_accum,
            False,
            float8_config.pad_inner_dim,
        ),
    )
    float8_eager = hp_tensor_to_float8_dynamic(
        hp_tensor1,
        torch.float8_e4m3fn,
        linear_mm_config,
        gemm_input_role=GemmInputRole.WEIGHT,
    )
    torch._dynamo.reset()
    float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
        hp_tensor2,
        torch.float8_e4m3fn,
        linear_mm_config,
        gemm_input_role=GemmInputRole.WEIGHT,
    )
    assert torch.equal(float8_eager._scale, float8_compile._scale)
    assert torch.equal(float8_eager._data, float8_compile._data)


if __name__ == "__main__":
    pytest.main([__file__])
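
As a usage note (not part of the diff): the new test can be selected on its own with pytest's -k filter, assuming a GPU that satisfies the is_cuda_8_9 skip condition above; the exact invocation below is hypothetical.

# Hypothetical invocation, equivalent to running
#   pytest test/float8/test_compile.py -k test_dynamic_scale_numeric_parity
# from the repository root.
import pytest

exit_code = pytest.main(
    ["test/float8/test_compile.py", "-k", "test_dynamic_scale_numeric_parity"]
)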
torchao/float8/float8_tensor.py (4 additions, 1 deletion)

@@ -163,7 +163,10 @@ def forward(

DTensor Invariant: DTensor must always be the outer most tensor subclass
"""
- tensor_scaled = tensor * scale
+ # Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
+ # upcasted to `float32` to multiply with the scale
+ # In order to match numerics between eager and compile, we upcast manually here.
+ tensor_scaled = tensor.to(torch.float32) * scale

Contributor Author (weifengpy): without upcasting, the eager numeric is like -157.00000000000000000000, while compile gives -157.06507873535156250000.

Contributor Author (weifengpy): torch.compile upcasts the tensor ahead of the multiply; see `tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)` in the following output code:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 24576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)

bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
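
To make the thread above concrete, here is a minimal, self-contained sketch of the upcast-then-scale-then-saturated-cast step and the eager-vs-compile comparison. It is not code from this PR; the tensor shape, the fixed example scale (122.0), and the helper name cast_with_scale are assumptions for illustration only.

import torch

def cast_with_scale(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # upcast manually so eager matches the compiled kernel, as in the diff above
    tensor_scaled = tensor.to(torch.float32) * scale
    # saturated cast: clamp into the representable float8 range before converting
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    return torch.clamp(tensor_scaled, -fp8_max, fp8_max).to(torch.float8_e4m3fn)

x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
scale = torch.full((), 122.0, device="cuda", dtype=torch.float32)
eager_fp8 = cast_with_scale(x, scale)
compiled_fp8 = torch.compile(cast_with_scale)(x.clone(), scale)
# with the explicit upcast, eager and compile are expected to match bitwise
assert torch.equal(eager_fp8.to(torch.float32), compiled_fp8.to(torch.float32))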
torchao/float8/float8_utils.py (3 additions, 0 deletions)

@@ -42,6 +42,9 @@ def amax_to_scale(
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
# torch.compile and eager show different numerics for 1.0 / float32,
# upcast to float64 to ensure same numeric between compile and eager
amax = amax.to(torch.float64)
Contributor Author (weifengpy, Sep 19, 2024): upcast amax in amax_to_scale instead of tensor_to_amax for 2 reasons:

  • we can still do a bfloat16 all-reduce for amax
  • it is safer for delayed scaling, since it does not change the dtype of the amax_buffer

(A small sketch contrasting the two placements follows after this file's diff.)

Contributor: could you share why the upcasting happens?

Contributor Author (weifengpy): I can look into inductor more on how it achieved fp64.

Contributor Author (weifengpy): torch.compile actually upcasts to float32 with `tl.load(in_ptr0 + (x0), None).to(tl.float32)`. Upcasting to float64 helps further because torch.compile and eager show different numerics for 1.0 / float32 (but the same numerics for float64).

The float32 numeric difference can be verified with:
import torch
def upcast_reciprocal(inp: torch.Tensor):
    return inp.reciprocal()
inp = torch.full([], 0.00817871093750000000, device="cuda", dtype=torch.float32)
eager_scale = upcast_reciprocal(inp)
compile_scale = torch.compile(upcast_reciprocal)(inp)
fp64_ground_truth = inp.to(torch.float64).reciprocal()
assert torch.equal(eager_scale, compile_scale), f"{eager_scale=} vs {compile_scale=}, {fp64_ground_truth=}"

if float8_dtype in FP8_TYPES:
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
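
Following up on the review thread above about where to place the upcast, here is a minimal sketch of the split the author describes: amax is computed (and can be all-reduced) in the original dtype, and the float64 upcast is confined to amax_to_scale. This is a reconstruction for illustration, not the torchao implementation; the EPS value and the simplified float32 return path are assumptions.

import torch

EPS = 1e-12  # assumed epsilon, mirroring the clamp in the diff above

def tensor_to_amax(x: torch.Tensor) -> torch.Tensor:
    # amax stays in the input dtype (e.g. bfloat16), so a bfloat16 all-reduce
    # across data-parallel ranks is still possible and a delayed-scaling
    # amax_buffer keeps its dtype
    return torch.max(torch.abs(x))

def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # upcast only here: 1.0 / float32 can differ between eager and compile,
    # while float64 gives the same result in both
    amax = amax.to(torch.float64)
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    return res.to(torch.float32)

x = torch.randn(16, 16, dtype=torch.bfloat16)
scale = amax_to_scale(tensor_to_amax(x), torch.float8_e4m3fn)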
torchao/float8/fsdp_utils.py (8 additions, 3 deletions)

@@ -64,12 +64,17 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
+ # keep consistent with float8_utils.amax_to_scale
+ # torch.compile and eager show different numerics for 1.0 / float32,
+ # upcast to float64 to ensure same numeric between compile and eager
+ origin_dtype = amax_tensor.dtype
+ amax_tensor = amax_tensor.to(torch.float64)
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
- if amax_tensor.dtype is torch.float16:
+ if origin_dtype is torch.float16:
    scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
- local_scale_tensor = scale_tensor.to_local()
+ local_scale_tensor = scale_tensor.to_local().to(torch.float32)
for i, float8_linear in enumerate(float8_linears):
-     float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
+     float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
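
For illustration, here is a minimal single-process sketch of the dtype flow in the fsdp_utils diff above, using plain tensors instead of DTensor. The helper name precomputed_scale, the stacked-amax construction, and the epsilon value are assumptions and not part of the PR.

import torch

def precomputed_scale(amax_tensor: torch.Tensor) -> torch.Tensor:
    # clamp (and, in the real code, all-reduce) the amax in its original dtype
    amax_tensor = torch.clamp(amax_tensor, 1e-12)
    origin_dtype = amax_tensor.dtype
    # divide in float64 so eager matches torch.compile, as in amax_to_scale
    amax_tensor = amax_tensor.to(torch.float64)
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
    if origin_dtype is torch.float16:
        scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
    # store the precomputed per-weight scales in float32
    return scale_tensor.to(torch.float32)

weights = [torch.randn(32, 32, dtype=torch.bfloat16) for _ in range(3)]
amax_stack = torch.stack([w.abs().amax() for w in weights])
scales = precomputed_scale(amax_stack)  # one float32 scale per weight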
torchao/testing/float8/fsdp2_utils.py (1 addition, 4 deletions)

@@ -48,10 +48,7 @@ def check_parity_no_mp(
):
precompute_float8_dynamic_scale_for_fsdp(model)

- if compile_transformer_block:
-     test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
- else:
-     test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+ test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")


def check_parity_bf16_mp(