diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py
index 5e054aaf35..5811dd9d21 100644
--- a/torchao/prototype/mx_formats/kernels.py
+++ b/torchao/prototype/mx_formats/kernels.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 from typing import Optional, Tuple
 
 import numpy as np
@@ -35,6 +36,8 @@
     F32_EXP_BIAS,
 )
 
+logger = logging.getLogger(__name__)
+
 
 def get_bits(x: torch.Tensor) -> str:
     bits_per_byte = 8
@@ -1476,10 +1479,20 @@ def triton_quantize_nvfp4(
         raise AssertionError("needs torch version 2.8+ and triton")
 
 
-# MXFP8 CUDA kernel is only built on SM100+
+mxfp8_cuda_extension_available = False
 if is_sm_at_least_100():
-    from torchao.prototype import mxfp8_cuda
-
+    try:
+        # The MXFP8 CUDA kernel is only built on SM100+. Furthermore, our CI
+        # runners are not currently SM100+, so users need to build torchao
+        # from source to use it.
+        # TODO(#2932): improve this
+        from torchao.prototype import mxfp8_cuda
+
+        mxfp8_cuda_extension_available = True
+    except ImportError:
+        logger.debug("Skipping import of torchao.prototype.mxfp8_cuda")
+
+if mxfp8_cuda_extension_available:
     # TODO: Make `scaling_mode` a choice (enum-like) rather than arbitrary string.
     # Currently we have to use an arbitrary string because custom ops don't support enum
     # params.
@@ -1599,4 +1612,6 @@ def mxfp8_quantize_cuda(
         colwise: bool = True,
         scaling_mode: str = "floor",
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        raise NotImplementedError("needs torch version 2.8+ and sm100")
+        raise NotImplementedError(
+            "`mxfp8_quantize_cuda` needs (1) torch 2.8+ and (2) torchao built from source on a machine with CUDA capability 10.0+. Please see https://github.com/pytorch/ao/issues/2932 for more details."
+        )
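
Usage note (not part of the diff): a minimal sketch of how a caller might gate on the new `mxfp8_cuda_extension_available` flag. The input tensor, the first positional argument to `mxfp8_quantize_cuda`, and the fallback branch are assumptions for illustration; only `colwise`, `scaling_mode`, and the four-tensor return type are visible in the hunks themselves.

    import torch

    from torchao.prototype.mx_formats import kernels

    x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)

    if kernels.mxfp8_cuda_extension_available:
        # Extension was built from source on an SM100+ (CUDA capability 10.0+) GPU.
        # Passing `x` as the first positional argument is an assumption; the
        # leading parameters are elided in the diff.
        data_rowwise, data_colwise, scale_rowwise, scale_colwise = (
            kernels.mxfp8_quantize_cuda(x, colwise=True, scaling_mode="floor")
        )
    else:
        # Without the extension, the stub raises NotImplementedError pointing at
        # https://github.com/pytorch/ao/issues/2932; dispatch to a fallback
        # quantization path here instead.
        ...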