
Commit c4f43de

expose emulation in mxfp4 inference workflow

Summary:

Enables emulation mode in the mxfp4 inference workflow. This is useful for debugging on an H100.

Test Plan:

```bash
pytest test/prototype/mx_formats/test_inference_workflow.py -s -k test_inference_workflow_mx
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: db1bc94
ghstack-comment-id: 3335736147
Pull Request resolved: #3066
1 parent 8c5c33e commit c4f43de
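
For context, a minimal sketch of how emulation mode might be exercised end to end on an H100, outside the test. It assumes torchao's `quantize_` entry point and a `gemm_kernel_choice` field on `MXFPInferenceConfig`, and the import paths are inferred from the file layout in this PR; none of these names are confirmed by the diff itself, so treat them as assumptions.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXGemmKernelChoice
from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_

# Sketch: an H100 (sm90) has no native mxfp4 gemm, so fall back to the
# emulated kernel to debug the workflow's numerics.
m = nn.Linear(32, 128, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx = copy.deepcopy(m)

config = MXFPInferenceConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,  # assumed field name
)
quantize_(m_mx, config)

# Compare emulated mxfp4 output against the bfloat16 reference.
x = torch.randn(4, 32, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y_ref = m(x)
    y_mx = m_mx(x)
print(torch.nn.functional.mse_loss(y_mx, y_ref))
```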

File tree

2 files changed: +13 −13 lines


test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 13 additions & 9 deletions
```diff
@@ -50,12 +50,12 @@ def run_around_tests():
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
+@pytest.mark.parametrize("emulate", [True, False])
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-@pytest.mark.skipif(not is_sm_at_least_100(), reason="CUDA capability >= 10.0 required")
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
+def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
     """
     Smoke test for inference compile
     """
@@ -64,17 +64,21 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool):
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
+        elif not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
     elif elem_dtype == torch.float4_e2m1fn_x2:
-        if not is_sm_at_least_100():
-            pytest.skip("CUDA capability >= 10.0 required for float4 gemm")
+        if not is_sm_at_least_100() and not emulate:
+            pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
 
     m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
     m_mx = copy.deepcopy(m)
-    kernel_choice = (
-        MXGemmKernelChoice.CUTLASS
-        if elem_dtype == torch.float4_e2m1fn_x2
-        else MXGemmKernelChoice.CUBLAS
-    )
+
+    if emulate:
+        kernel_choice = MXGemmKernelChoice.EMULATED
+    elif elem_dtype == torch.float4_e2m1fn_x2:
+        kernel_choice = MXGemmKernelChoice.CUTLASS
+    else:
+        kernel_choice = MXGemmKernelChoice.CUBLAS
     config = MXFPInferenceConfig(
         activation_dtype=elem_dtype,
         weight_dtype=elem_dtype,
```

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -96,10 +96,6 @@ def _linear_extra_repr(self):
 def _mx_inference_linear_transform(
     module: torch.nn.Module, config: MXFPInferenceConfig
 ):
-    # TODO Sm120 has slightly more restrictive reqs
-    # TODO handle AMD
-    assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
-
     weight = module.weight
 
     assert weight.dtype == torch.bfloat16, (
```