diff --git a/test/float8/test_compile.py b/test/float8/test_compile.py
index bae62bf77..5106bd778 100644
--- a/test/float8/test_compile.py
+++ b/test/float8/test_compile.py
@@ -25,8 +25,15 @@
     get_float8_layers,
     sync_float8_amax_and_scale_history,
 )
-from torchao.float8.float8_scaling_utils import hp_tensor_to_float8_delayed
-from torchao.float8.float8_tensor import LinearMMConfig
+from torchao.float8.float8_scaling_utils import (
+    hp_tensor_to_float8_delayed,
+    hp_tensor_to_float8_dynamic,
+)
+from torchao.float8.float8_tensor import (
+    LinearMMConfig,
+    GemmInputRole,
+    ScaledMMConfig,
+)
 from torchao.float8.float8_utils import e4m3_dtype
 
 from torch._dynamo.test_case import TestCase as DynamoTestCase
@@ -353,5 +360,65 @@ def test_sync_amax_func_cuda_graph_success():
     assert "skipping cudagraphs due to mutaton on input" not in stderr[0]
 
 
+@unittest.skipIf(
+    not is_cuda_8_9,
+    "CUDA with float8 support not available",
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        torch.float32,
+        torch.bfloat16,
+        torch.float16,
+    ],
+)
+def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
+    scaling_type_weight = ScalingType.DYNAMIC
+    torch.manual_seed(42)
+    hp_tensor1 = torch.randn(16, 16, device="cuda", dtype=dtype)
+    hp_tensor2 = hp_tensor1.detach().clone()
+    float8_config = Float8LinearConfig(
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+    )
+    linear_mm_config = LinearMMConfig(
+        # output
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_output.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_input
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_input.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+        # grad_weight
+        ScaledMMConfig(
+            False,
+            float8_config.gemm_config_grad_weight.use_fast_accum,
+            False,
+            float8_config.pad_inner_dim,
+        ),
+    )
+    float8_eager = hp_tensor_to_float8_dynamic(
+        hp_tensor1,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
+    torch._dynamo.reset()
+    float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
+        hp_tensor2,
+        torch.float8_e4m3fn,
+        linear_mm_config,
+        gemm_input_role=GemmInputRole.WEIGHT,
+    )
+    assert torch.equal(float8_eager._scale, float8_compile._scale)
+    assert torch.equal(float8_eager._data, float8_compile._data)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/torchao/float8/float8_tensor.py b/torchao/float8/float8_tensor.py
index afeaf6462..8927cf4e7 100644
--- a/torchao/float8/float8_tensor.py
+++ b/torchao/float8/float8_tensor.py
@@ -163,7 +163,10 @@ def forward(
 
         DTensor Invariant: DTensor must always be the outer most tensor subclass
         """
-        tensor_scaled = tensor * scale
+        # Note: when the line below is compiled with `torch.compile`, `tensor` is
+        # automatically upcast to `float32` to multiply with the scale.
+        # In order to match numerics between eager and compile, we upcast manually here.
+        tensor_scaled = tensor.to(torch.float32) * scale
         bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)
 
         if isinstance(bits_fp8, DTensor):
diff --git a/torchao/float8/float8_utils.py b/torchao/float8/float8_utils.py
index d0c1d2c01..49a2a1152 100644
--- a/torchao/float8/float8_utils.py
+++ b/torchao/float8/float8_utils.py
@@ -42,6 +42,9 @@ def amax_to_scale(
         float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
     """
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # so upcast to float64 to ensure the same numerics between compile and eager.
+    amax = amax.to(torch.float64)
     if float8_dtype in FP8_TYPES:
         res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
     else:
diff --git a/torchao/float8/fsdp_utils.py b/torchao/float8/fsdp_utils.py
index 5939f721f..201e9fdfe 100644
--- a/torchao/float8/fsdp_utils.py
+++ b/torchao/float8/fsdp_utils.py
@@ -64,12 +64,17 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
+    # keep consistent with float8_utils.amax_to_scale:
+    # torch.compile and eager show different numerics for 1.0 / float32,
+    # so upcast to float64 to ensure the same numerics between compile and eager.
+    origin_dtype = amax_tensor.dtype
+    amax_tensor = amax_tensor.to(torch.float64)
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
-    if amax_tensor.dtype is torch.float16:
+    if origin_dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    local_scale_tensor = scale_tensor.to_local()
+    local_scale_tensor = scale_tensor.to_local().to(torch.float32)
     for i, float8_linear in enumerate(float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]
 
 
 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
diff --git a/torchao/testing/float8/fsdp2_utils.py b/torchao/testing/float8/fsdp2_utils.py
index f558bb11f..62a571e15 100644
--- a/torchao/testing/float8/fsdp2_utils.py
+++ b/torchao/testing/float8/fsdp2_utils.py
@@ -48,10 +48,7 @@ def check_parity_no_mp(
             ):
                 precompute_float8_dynamic_scale_for_fsdp(model)
 
-        if compile_transformer_block:
-            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4, msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
-        else:
-            test_cls.assertEqual(losses[0], losses[1], msg = f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
+        test_cls.assertEqual(losses[0], losses[1], msg=f"iter: {iter_idx}, loss-ref: {losses[0]}, loss-fp8: {losses[1]}")
 
 
 def check_parity_bf16_mp(