From c3f0e65e36131dcc03576a53079a7f919179767c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 3 Sep 2025 12:15:07 -0700 Subject: [PATCH] better check for mxfp8 cuda kernel presence Summary: Short term fix for https://github.com/pytorch/ao/issues/2932. If torchao was build without CUDA 10.0 (such as in our CI), ensures that: a. only callsites which actually use the mxfp8 dim1 kernel see the error message. Using NVFP4 no longer hits this error. b. make the error message point to github issue for more info on the workaround (for now, build from souce). Test Plan: 1. hardcode mxfp8 kernel from being built: https://github.com/pytorch/ao/blob/85557135c93d3429320a4a360c0ee9cb49f84a00/setup.py#L641 2. build torchao from source, verify `torchao/prototype` does not have any `.so` files 3. run nvfp4 tests, verify they now pass: `pytest test/prototype/mx_formats/test_nvfp4_tensor.py -s -x` 4. run mxfp8 linear tests, verify the new error message is displayed for dim1 kernel tests: `pytest test/prototype/mx_formats/test_mx_linear.py -s -x -k test_linear_eager_vs_hp` 5. undo the change in (1), rebuild torchao, verify all mx tests pass: `pytest test/prototype/mx_formats/ -s -x` Reviewers: Subscribers: Tasks: Tags: --- torchao/prototype/mx_formats/kernels.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 5e054aaf35..5811dd9d21 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import Optional, Tuple import numpy as np @@ -35,6 +36,8 @@ F32_EXP_BIAS, ) +logger = logging.getLogger(__name__) + def get_bits(x: torch.Tensor) -> str: bits_per_byte = 8 @@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4( raise AssertionError("needs torch version 2.8+ and triton") -# MXFP8 CUDA kernel is only built on SM100+ +mxfp8_cuda_extension_available = False if is_sm_at_least_100(): - from torchao.prototype import mxfp8_cuda - + try: + # MXFP8 CUDA kernel is only built on SM100+. Furthermore, + # currently our CI runners are not SM100+, so the user needs to build + # from source. + # TODO(#2932): improve this + from torchao.prototype import mxfp8_cuda + + mxfp8_cuda_extension_available = True + except ImportError: + logging.debug("Skipping import of torchao.prototype.mxfp8_cuda") + +if mxfp8_cuda_extension_available: # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string. # Currently we have to use an arbitrary string because custom ops don't support enum # params. @@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda( colwise: bool = True, scaling_mode: str = "floor", ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - raise NotImplementedError("needs torch version 2.8+ and sm100") + raise NotImplementedError( + "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details." + )