[float8] improve eager numerics for dynamic scales and get on par with torch.compile #904

Merged Oct 1, 2024 (43 commits)

Changes from 2 commits

Commits (43):
6bf0f5c  [float8] improve eager numerics for dynamic scales (weifengpy, Sep 19, 2024)
553687f  leave torch.linalg.vector_norm for another PR (weifengpy, Sep 19, 2024)
19a592d  cuda (weifengpy, Sep 19, 2024)
218290e  remove _data and investigate (weifengpy, Sep 19, 2024)
24ec914  remove _data comment (weifengpy, Sep 19, 2024)
c099486  upcast to float32 is enough (weifengpy, Sep 21, 2024)
b93ffc8  explain why float32 (weifengpy, Sep 21, 2024)
ebff416  _data parity (weifengpy, Sep 21, 2024)
8978ab2  handle sm8.9 (weifengpy, Sep 21, 2024)
f17dc12  fix transformer unit test (weifengpy, Sep 22, 2024)
511c751  print if error (weifengpy, Sep 26, 2024)
9becda1  Add tutorial for trainable tensor subclass (#908) (andrewor14, Sep 20, 2024)
e4fdca9  Introducing 1-bit quantization for Llama in torchchat (#910) (vaishnavi17, Sep 20, 2024)
0cd4d37  Rename Floating point to fp8 (#909) (jainapurva, Sep 20, 2024)
014558d  [float8] fix typo in bitwise_identical unit test (#918) (weifengpy, Sep 23, 2024)
3267402  Adding example for quantized tensor + tensor parallelism (#785) (jerryzh168, Sep 23, 2024)
1e07eff  rename cuda mode -> gpu mode (#925) (msaroufim, Sep 24, 2024)
ebdeed0  Add workaround to recover the perf for quantized vit in torch.compile… (jerryzh168, Sep 24, 2024)
09ffa22  clean up device checks in float8 unit test files (#923) (vkuzo, Sep 24, 2024)
0b8dd85  [low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 25… (gau-nernst, Sep 24, 2024)
87faf04  Float8 autoquant weight only (#866) (jainapurva, Sep 24, 2024)
3a9fdb0  Fix failing FP6 benchmark (#931) (tobiasvanderwerff, Sep 25, 2024)
fc6c393  Remove two if statements in fp8 padding (#935) (y-sq, Sep 25, 2024)
0043ace  [Distributed] Improve sharding example (#937) (kwen2501, Sep 25, 2024)
ab3435c  Add composable QAT quantizer (#938) (andrewor14, Sep 25, 2024)
a05a40f  resolve conflict with latest main (weifengpy, Sep 26, 2024)
334891b  Add torchchat quantizer (metascroy, Sep 25, 2024)
c706139  Add compile tests to test suite (#906) (jerryzh168, Sep 26, 2024)
93554c0  Fix up CMakeLists and reorganize some code locations (metascroy, Sep 26, 2024)
efd9bb9  [float8] all-reduce amax on dp mesh instead of global pg (#933) (weifengpy, Sep 26, 2024)
85126cc  int8 dynamic quant + bsr support (#821) (jcaip, Sep 26, 2024)
a5a426e  fixing some issues with our support for 70/405B models (#941) (HDCharles, Sep 26, 2024)
e7270f1  Update INT8 mixed-precision training test to be less flaky (#950) (gau-nernst, Sep 26, 2024)
352685c  Add executorch parallel (metascroy, Sep 26, 2024)
168cfe9  Merge branch 'weifengpy-dynamic_scale_numerics' into dynamic_scale_nu… (weifengpy, Sep 26, 2024)
5900c3e  Merge branch 'main' into dynamic_scale_numerics (weifengpy, Sep 26, 2024)
37e1479  test CI (weifengpy, Sep 26, 2024)
2efde49  better comment on why upcasting (weifengpy, Sep 26, 2024)
8c04f4f  control seed (weifengpy, Sep 26, 2024)
04b229b  move unit test to test_compile (weifengpy, Sep 26, 2024)
8b7c2ef  fix typo (weifengpy, Sep 26, 2024)
9346afd  float64 upcasting after allreduce (weifengpy, Sep 27, 2024)
3d0da20  use LinearMMConfig (weifengpy, Sep 30, 2024)
34 changes: 33 additions & 1 deletion test/float8/test_base.py
@@ -15,6 +15,9 @@

import torch
import torch.nn as nn
from torchao.float8.float8_scaling_utils import (
    hp_tensor_to_float8_dynamic,
)

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -53,7 +56,7 @@
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
-    assert torch.all(a._data == b._data).item(), "scales are not identical"
+    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
    return True

@@ -604,6 +607,35 @@ def test_small_amax_float16(self, float8_dtype):
        x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
        scale = tensor_to_scale(x, float8_dtype)
        assert not torch.any(torch.isinf(scale))

    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float32,
            torch.bfloat16,
            torch.float16,
        ],
    )
    def test_dynamic_scale_parity(self, dtype: torch.dtype):
Contributor: nit: move to test_compile.py since this is testing compile vs eager?

        scaling_type_weight = ScalingType.DYNAMIC
        torch.manual_seed(42)
        hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype)
        float8_config = Float8LinearConfig(
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
        )
        float8_eager = hp_tensor_to_float8_dynamic(
            hp_tensor,
            torch.float8_e4m3fn,
            float8_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
        float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
            hp_tensor,
            torch.float8_e4m3fn,
            float8_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
        assert bitwise_identical(float8_eager, float8_compile)


class TestFloat8LinearUtils(unittest.TestCase):
5 changes: 4 additions & 1 deletion torchao/float8/float8_tensor.py
@@ -163,7 +163,10 @@ def forward(

        DTensor Invariant: DTensor must always be the outer most tensor subclass
        """
-        tensor_scaled = tensor * scale
+        # Required by scaled_mm, scale is always float32.
+        # Cast tensor to float32 to improve numerics and
+        # get on-par with torch.compile.
+        tensor_scaled = tensor.to(torch.float32) * scale
Contributor Author (weifengpy): without upcasting, the eager numerics look like -157.00000000000000000000 while compile gives -157.06507873535156250000.
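A minimal sketch (my own, not code from the PR) of how that gap arises, assuming the per-tensor scale is a 0-dim float32 tensor as produced by dynamic scaling: PyTorch type promotion lets a 0-dim tensor keep the product in bfloat16, so only an explicit upcast matches what inductor emits. Shapes and the scale value are illustrative.

import torch

hp = torch.randn(768, 32, dtype=torch.bfloat16, device="cuda")
scale = torch.tensor(123.0, dtype=torch.float32, device="cuda")  # 0-dim per-tensor scale

no_upcast = hp * scale                      # 0-dim scale does not promote: product stays in bfloat16
with_upcast = hp.to(torch.float32) * scale  # what inductor effectively generates

# the intermediate products differ at bfloat16 precision; after the fp8 cast the
# difference survives only when a value sits on a rounding boundary, which is
# what the bitwise parity test occasionally catches
print((no_upcast.float() - with_upcast).abs().max())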

Contributor Author (weifengpy): torch.compile upcasts the tensor ahead of time; see tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) in the following output code:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 24576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)
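(For reference, assuming a recent PyTorch 2.x, output like the above can be dumped by enabling inductor's output-code logging before compiling, e.g. via the TORCH_LOGS="output_code" environment variable or the call below; this is a general debugging aid, not something from this PR.)

import torch
torch._logging.set_logs(output_code=True)  # print the Triton/C++ that inductor generates under torch.compile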

        bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

        if isinstance(bits_fp8, DTensor):
3 changes: 3 additions & 0 deletions torchao/float8/float8_utils.py
@@ -42,6 +42,9 @@ def amax_to_scale(
        float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    # Preserve precision in amax-to-scale conversion
    # and ensure on-par numerics with torch.compile
    amax = amax.to(torch.float64)
Contributor Author (weifengpy, Sep 19, 2024): upcast amax in amax_to_scale instead of tensor_to_amax, for two reasons (a toy sketch of the resulting ordering follows the list):

  • we can still do the bfloat16 all-reduce for amax
  • it is safer for delayed scaling, since it does not change the dtype of the amax_buffer
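A toy sketch (my own, not torchao's implementation) of the ordering those two points imply: amax stays in its original dtype for any all-reduce, and the upcast happens only inside the amax-to-scale step. The EPS value and the trailing float32 cast are illustrative assumptions.

import torch

EPS = 1e-12  # illustrative epsilon, not necessarily the library's value

def toy_amax_to_scale(amax: torch.Tensor, float8_dtype=torch.float8_e4m3fn) -> torch.Tensor:
    # upcast only here, after any all-reduce has already run on the low-precision amax
    amax = amax.to(torch.float64)
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    return res.to(torch.float32)  # scaled_mm expects a float32 scale

x = torch.randn(1024, dtype=torch.bfloat16)
amax = x.abs().max()  # 0-dim bfloat16, still cheap to all-reduce
# torch.distributed.all_reduce(amax, op=torch.distributed.ReduceOp.MAX)  # would run here when distributed
scale = toy_amax_to_scale(amax)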

Contributor: could you share why the upcasting happens?

Contributor Author (weifengpy): I can look into inductor more to see how it achieves fp64.

Contributor Author (weifengpy): torch.compile actually upcasts to float32 with tl.load(in_ptr0 + (x0), None).to(tl.float32). Upcasting to float64 helps further because torch.compile and eager show different numerics for 1.0 / float32 (but the same numerics for float64).

The float32 numeric difference can be verified with

import torch

def upcast_reciprocal(inp: torch.Tensor):
    return inp.reciprocal()

# a scale-like value for which eager and inductor round 1/x differently in float32
inp = torch.full([], 0.00817871093750000000, device="cuda", dtype=torch.float32)
eager_scale = upcast_reciprocal(inp)
compile_scale = torch.compile(upcast_reciprocal)(inp)
fp64_ground_truth = inp.to(torch.float64).reciprocal()
# expected to fail in float32, demonstrating the eager vs compile difference
assert torch.equal(eager_scale, compile_scale), f"{eager_scale=} vs {compile_scale=}, {fp64_ground_truth=}"

    if float8_dtype in FP8_TYPES:
        res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    else: