Revert "[ROCm] use dataclass for fnuz type setting" (pytorch#1148)
Revert "[ROCm] use dataclass for fnuz type setting (pytorch#1142)"

This reverts commit eb1fb3a.
msaroufim authored Oct 23, 2024
1 parent eb1fb3a commit d252612
Showing 4 changed files with 38 additions and 55 deletions.
28 changes: 14 additions & 14 deletions test/float8/test_base.py
@@ -24,8 +24,8 @@
 
 
 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
+    CastConfig,
+    Float8LinearConfig,
     ScalingGranularity,
     ScalingType,
     Float8LinearRecipeName,
@@ -109,15 +109,15 @@ def test_split_cat(self):
 
     def test_index_put(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         index = torch.randint(0, 15, (16,), dtype=torch.long)
 
         b = torch.rand(16, 16, dtype=torch.bfloat16)
-        scale_b = tensor_to_scale(b, e4m3_dtype)
-        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, e4m3_dtype)
-        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, e4m3_dtype)
+        scale_b = tensor_to_scale(b, torch.float8_e4m3fn)
+        fp8_b = hp_tensor_and_scale_to_float8(b, scale_a, torch.float8_e4m3fn)
+        fp8_b_bad = hp_tensor_and_scale_to_float8(b, scale_b, torch.float8_e4m3fn)
 
         with pytest.raises(AssertionError):
             b[index] = fp8_a
@@ -127,8 +127,8 @@ def test_index_put(self):
 
     def test_copy_(self):
         a = torch.rand(16, dtype=torch.bfloat16)
-        scale_a = tensor_to_scale(a, e4m3_dtype)
-        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, e4m3_dtype)
+        scale_a = tensor_to_scale(a, torch.float8_e4m3fn)
+        fp8_a = hp_tensor_and_scale_to_float8(a, scale_a, torch.float8_e4m3fn)
 
         b = torch.empty(16, dtype=torch.bfloat16)
         b.copy_(fp8_a)  # Should work
@@ -137,7 +137,7 @@ def test_copy_(self):
             fp8_a.copy_(b)  # Should fail
 
         fp8_b = Float8Tensor(
-            torch.empty(16, dtype=e4m3_dtype),
+            torch.empty(16, dtype=torch.float8_e4m3fn),
             scale_a,
             torch.bfloat16,
             fp8_a._linear_mm_config,
@@ -332,11 +332,11 @@ def _test_linear_impl(
     @pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
     @pytest.mark.parametrize(
-        "scaling_type_input",
+        "scaling_type_input",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
-        "scaling_type_weight",
+        "scaling_type_weight",
         [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
     )
     @pytest.mark.parametrize(
@@ -377,7 +377,7 @@ def test_linear_from_config_params(
     # to combine with the main testing function.
     # TODO(future PR): make this cleaner.
     @pytest.mark.parametrize(
-        "recipe_name",
+        "recipe_name",
         [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
     )
     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@@ -610,7 +610,7 @@ def test_different_configs_error(self):
     @pytest.mark.parametrize("use_fast_accum", [True, False])
     def test_pad_inner_dim(self, base_dtype, use_fast_accum):
         torch.manual_seed(42)
-        input_dtype = e4m3_dtype
+        input_dtype = torch.float8_e4m3fn
         compare_type = torch.float32
 
         a = torch.randn(16, 41, device="cuda", dtype=base_dtype)
14 changes: 7 additions & 7 deletions test/float8/test_compile.py
@@ -20,9 +20,9 @@
 import torch
 import torch.nn as nn
 from torchao.float8.config import (
-    CastConfig,
-    Float8LinearConfig,
-    ScalingType,
+    CastConfig,
+    Float8LinearConfig,
+    ScalingType,
     Float8LinearRecipeName,
     recipe_name_to_linear_config,
 )
@@ -77,7 +77,7 @@ def _test_compile_base(
     y_fp8.sum().backward()
     y_ref = m_ref(x_ref)
     y_ref.sum().backward()
-    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
+    # TODO(future PR): can also test fp8 eager vs compile here with a tigher
     # tolerance
     torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
     torch.testing.assert_close(
@@ -199,7 +199,7 @@ def test_inductor_from_config_params(
 # to combine with the main testing function.
 # TODO(future PR): make this cleaner.
 @pytest.mark.parametrize(
-    "recipe_name",
+    "recipe_name",
     [Float8LinearRecipeName.ALL_AXISWISE, Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP],
 )
 @unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
@@ -412,14 +412,14 @@ def test_dynamic_scale_numeric_parity(dtype: torch.dtype):
     )
     float8_eager = hp_tensor_to_float8_dynamic(
         hp_tensor1,
-        e4m3_dtype,
+        torch.float8_e4m3fn,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
     torch._dynamo.reset()
     float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
         hp_tensor2,
-        e4m3_dtype,
+        torch.float8_e4m3fn,
         linear_mm_config,
         gemm_input_role=GemmInputRole.WEIGHT,
     )
43 changes: 13 additions & 30 deletions torchao/float8/config.py
@@ -96,29 +96,6 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
-@dataclass
-class Float8TypeConfig:
-    """
-    Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
-    Currently, ROCm only supports fnuz variants.
-    """
-
-    # The preferred e4m3 type.
-    e4m3_dtype = torch.float8_e4m3fn
-
-    # The preferred e5m2 type.
-    e5m2_dtype = torch.float8_e5m2
-
-    def __post_init__(self):
-        if torch.version.hip:
-            prop = torch.cuda.get_device_properties(0)
-            MI300_ARCH = ("gfx940", "gfx941", "gfx942")
-            if prop.gcnArchName.split(":")[0] in MI300_ARCH:
-                self.e4m3_dtype = torch.float8_e4m3fnuz
-                self.e5m2_dtype = torch.float8_e5m2fnuz
-
-
 @dataclass(frozen=True)
 class Float8GemmConfig:
     """
@@ -141,11 +118,11 @@ class Float8LinearConfig:
     # Per-tensor configuration for casting of `input`, `weight`, `grad_output`
     # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`.
     #
-    # Note:
-    # 1. if `cast_config_input_for_grad_weight` is None, then
+    # Note:
+    # 1. if `cast_config_input_for_grad_weight` is None, then
     #    `cast_config_input` is used for scaling `input` for both gemms that
-    #    use `input.
-    # 2. if `cast_config_input_for_grad_weight` is specified, then
+    #    use `input.
+    # 2. if `cast_config_input_for_grad_weight` is specified, then
     #    a. `cast_config_input` is used for scaling `input` for the gemm that calculates
     #       `output`
     #    b. `cast_config_input_for_grad_weight` is used for scaling `input` for
@@ -263,6 +240,12 @@ def __post_init__(self):
                 f"incompatible operand precision for {gemm_name}"
 
 
+# If True, use 'fnuz' float8 types for calculations.
+# Currently, ROCm only supports fnuz variants.
+# TODO(future PR): move this to Float8LinearConfig
+use_fnuz_dtype = False
+
+
 # Pre-made recipes for common configurations
 # TODO(future PR): go through a round of design on this, and eventually expose
 # as a top level public API.
@@ -289,7 +272,7 @@ def recipe_name_to_linear_config(
         cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
         cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
-
+
         # The current rowwise CUTLASS kernels in `torch._scaled_mm` are only
         # fast with `use_fast_accum=True`. Note that rowwise scaling is more
         # accurate than tensorwise scaling, so the overall impact on accuracy
@@ -317,8 +300,8 @@
         #
         # key characteristics:
         # * increased accuracy for grad_weight
-        # * `input`, `weight` and `grad_output` now only need to be scaled
-        #   axiswise across a single dim compared to vanilla all-axiswise,
+        # * `input`, `weight` and `grad_output` now only need to be scaled
+        #   axiswise across a single dim compared to vanilla all-axiswise,
         #   which is more amenable to fast kernels
 
         # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
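
With Float8TypeConfig removed, the MI300 auto-detection it performed in __post_init__ no longer runs; opting into the fnuz types is back to the module-level use_fnuz_dtype flag restored above. As a rough illustration (not part of torchao), the check the dataclass used to run automatically could be reproduced on the caller's side as below; supports_fnuz_float8 is a hypothetical helper name.

# Illustrative sketch only: mirrors the MI300 check the removed dataclass did.
import torch

def supports_fnuz_float8() -> bool:
    # Only ROCm builds running on MI300-class GPUs (gfx940/gfx941/gfx942)
    # support the float8 fnuz variants.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    return arch in ("gfx940", "gfx941", "gfx942")

The result of such a check would then be reflected by setting use_fnuz_dtype in torchao/float8/config.py before the float8 modules are imported.
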
8 changes: 4 additions & 4 deletions torchao/float8/float8_utils.py
@@ -9,7 +9,8 @@
 import torch
 import torch.distributed as dist
 
-from torchao.float8.config import Float8TypeConfig, ScalingGranularity
+import torchao.float8.config as config
+from torchao.float8.config import ScalingGranularity
 
 # Helpful visualizer for debugging (only supports fp32):
 # https://www.h-schmidt.net/FloatConverter/IEEE754.html
@@ -28,9 +29,8 @@
 
 
 # User defined type for using the individual F8 type based on config
-type_config = Float8TypeConfig()
-e4m3_dtype = type_config.e4m3_dtype
-e5m2_dtype = type_config.e5m2_dtype
+e4m3_dtype = torch.float8_e4m3fn if not config.use_fnuz_dtype else torch.float8_e4m3fnuz
+e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz
 
 
 @torch.no_grad()
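
One caveat of the restored flag (and the motivation for the TODO next to it): e4m3_dtype and e5m2_dtype are evaluated once, when float8_utils is first imported, so flipping use_fnuz_dtype afterwards does not change them. A small sketch of that behavior, assuming this revert is applied:

import torchao.float8.config as config
from torchao.float8.float8_utils import e4m3_dtype  # evaluated with the default flag value

config.use_fnuz_dtype = True   # too late to affect the line above
print(e4m3_dtype)              # still torch.float8_e4m3fn
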
