From 0a18f345012cf9af2236f0e67442fb46f3474a95 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 09:34:47 -0800 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 4 +-- .../prototype/mx_formats/test_nvfp4_tensor.py | 4 +-- torchao/prototype/mx_formats/mx_tensor.py | 28 +++++++-------- torchao/prototype/mx_formats/nvfp4_tensor.py | 34 +++++++++---------- torchao/prototype/mx_formats/utils.py | 22 ++++++------ 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index fc38aac8d8..1d46bff0d3 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -519,7 +519,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx.qdata, x_mx.scale, x_mx._elem_dtype, - x_mx._block_size, + x_mx.block_size, hp_dtype, # noqa: E501 pack_fp6, ) @@ -527,7 +527,7 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c.qdata, x_mx_c.scale, x_mx_c._elem_dtype, - x_mx_c._block_size, + x_mx_c.block_size, hp_dtype, pack_fp6, ) diff --git a/test/prototype/mx_formats/test_nvfp4_tensor.py b/test/prototype/mx_formats/test_nvfp4_tensor.py index 5889019af3..e098edb745 100644 --- a/test/prototype/mx_formats/test_nvfp4_tensor.py +++ b/test/prototype/mx_formats/test_nvfp4_tensor.py @@ -71,7 +71,7 @@ def assert_sqnr_gt_threshold(orig, new, threshold): reconstructed_amax = x_nvfp4.get_hp_scales().view(shape[0], -1, 1) * F4_E2M1_MAX max_abs = torch.amax( - torch.abs(x.reshape(shape[0], -1, x_nvfp4._block_size)), dim=-1 + torch.abs(x.reshape(shape[0], -1, x_nvfp4.block_size)), dim=-1 ).unsqueeze(-1) assert_sqnr_gt_threshold(max_abs, reconstructed_amax, 30.0) @@ -526,7 +526,7 @@ def test_nvfp4_to_copy(): assert y.per_tensor_scale is None assert x.act_per_tensor_scale is None assert y.act_per_tensor_scale is None - assert x._block_size == y._block_size + assert x.block_size == y.block_size assert x.use_triton_kernel == y.use_triton_kernel assert x.act_quant_kwargs == y.act_quant_kwargs assert x.dtype == torch.float32 diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 67aa9d767a..9119ae4b24 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -490,7 +490,7 @@ class MXTensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [ "_elem_dtype", - "_block_size", + "block_size", "_orig_dtype", "_gemm_kernel_choice", "_pack_fp6", @@ -547,7 +547,7 @@ def __new__( self.qdata = qdata self.scale = scale_e8m0_bits self._elem_dtype = elem_dtype - self._block_size = block_size + self.block_size = block_size self._orig_dtype = orig_dtype self._gemm_kernel_choice = gemm_kernel_choice self._pack_fp6 = pack_fp6 @@ -560,7 +560,7 @@ def __repr__(self): return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}" # noqa: E501 def _quantization_type(self): - return f"{self._elem_dtype=}, {self._block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -575,9 +575,9 @@ def 
dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor else: leading_dims, M, K = self.shape[:-2], self.shape[-2], self.shape[-1] scale = from_blocked( - scale, math.prod(leading_dims) * M, K // self._block_size + scale, math.prod(leading_dims) * M, K // self.block_size ) - scale = scale.view(*leading_dims, M, K // self._block_size) + scale = scale.view(*leading_dims, M, K // self.block_size) if is_transposed: scale = scale.transpose(-2, -1) @@ -585,7 +585,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self.qdata, scale, self._elem_dtype, - self._block_size, + self.block_size, output_dtype, self._pack_fp6, ) @@ -699,19 +699,19 @@ def _addmm_mx_dispatch( M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a.qdata.is_contiguous() assert b.qdata.t().is_contiguous() - assert a._block_size == 32, f"Invalid block size {a._block_size}" - assert b._block_size == 32, f"Invalid block size {b._block_size}" + assert a.block_size == 32, f"Invalid block size {a.block_size}" + assert b.block_size == 32, f"Invalid block size {b.block_size}" if a._is_swizzled_scales: a_scale_block = a.scale else: - a_scale = a.scale.view(M, K // a._block_size) + a_scale = a.scale.view(M, K // a.block_size) a_scale_block = to_blocked(a_scale) if b._is_swizzled_scales: b_scale_block = b.scale.t() else: - b_scale = b.scale.t().view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b.block_size) b_scale_block = to_blocked(b_scale) if a._elem_dtype == torch.float8_e4m3fn: @@ -804,7 +804,7 @@ def mx_t(func, types, args, kwargs): old.qdata.t(), old.scale.t(), old._elem_dtype, - old._block_size, + old.block_size, old._orig_dtype, old._gemm_kernel_choice, old._pack_fp6, @@ -849,7 +849,7 @@ def mx_view_op(func, types, args, kwargs): new_data, args[0].scale, args[0]._elem_dtype, - args[0]._block_size, + args[0].block_size, args[0]._orig_dtype, args[0]._gemm_kernel_choice, args[0]._pack_fp6, @@ -875,7 +875,7 @@ def mx_slice(func, types, args, kwargs): sliced_data, sliced_scale, x._elem_dtype, - x._block_size, + x.block_size, x._orig_dtype, x._gemm_kernel_choice, x._pack_fp6, @@ -910,7 +910,7 @@ def mx_select(func, types, args, kwargs): old_mx_tensor.qdata[index], old_mx_tensor.scale[index], old_mx_tensor._elem_dtype, - old_mx_tensor._block_size, + old_mx_tensor.block_size, old_mx_tensor._orig_dtype, old_mx_tensor._gemm_kernel_choice, old_mx_tensor._pack_fp6, diff --git a/torchao/prototype/mx_formats/nvfp4_tensor.py b/torchao/prototype/mx_formats/nvfp4_tensor.py index 18f05290e5..26e48216ee 100644 --- a/torchao/prototype/mx_formats/nvfp4_tensor.py +++ b/torchao/prototype/mx_formats/nvfp4_tensor.py @@ -78,7 +78,7 @@ class NVFP4Tensor(TorchAOBaseTensor): scale: Blockwise scales in float8_e4m3fn format (may be swizzled) per_tensor_scale: Optional global per-tensor scale in float32 format act_per_tensor_scale: Optional global per-tensor scale in float32 format, for activation - _block_size (int): Block size for quantization (fixed at 16) + block_size (int): Block size for quantization (fixed at 16) _orig_dtype (torch.dtype): Original tensor dtype before quantization _is_swizzled_scales (bool): Whether scales are stored in swizzled (blocked) format use_triton_kernel (bool): Whether to use triton kernels @@ -86,7 +86,7 @@ class NVFP4Tensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [ - "_block_size", + "block_size", "_orig_dtype", ] optional_tensor_data_names = ["per_tensor_scale", "act_per_tensor_scale"] @@ -126,7 +126,7 @@ def 
__new__( self.qdata = qdata self.scale = scale - self._block_size = block_size + self.block_size = block_size self._orig_dtype = orig_dtype self.per_tensor_scale = per_tensor_scale self.act_per_tensor_scale = act_per_tensor_scale @@ -238,10 +238,10 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor data_f32 = f4_unpacked_to_f32(data_unpacked) data_f32 = data_f32.view( - *leading_dims, M, K // self._block_size, self._block_size + *leading_dims, M, K // self.block_size, self.block_size ) scale_e4m3_reshaped = self.get_hp_scales().view( - *leading_dims, M, K // self._block_size, 1 + *leading_dims, M, K // self.block_size, 1 ) data_scaled = data_f32 * scale_e4m3_reshaped.to(torch.float32) result = data_scaled.view(*leading_dims, M, K).to(output_dtype) @@ -267,7 +267,7 @@ def get_hp_scales(self) -> torch.Tensor: if self._is_swizzled_scales: scale_e4m3 = from_blocked( - scale_e4m3, math.prod(leading_dims) * M, K // self._block_size + scale_e4m3, math.prod(leading_dims) * M, K // self.block_size ) return ( @@ -297,7 +297,7 @@ def _same_metadata(cls, self: "NVFP4Tensor", src: "NVFP4Tensor") -> bool: return ( isinstance(self, NVFP4Tensor) and isinstance(src, NVFP4Tensor) - and self._block_size == src._block_size + and self.block_size == src.block_size and self._orig_dtype == src._orig_dtype and self._is_swizzled_scales == src._is_swizzled_scales and self.scale.shape == src.scale.shape @@ -341,7 +341,7 @@ def nvfp4_to_copy(func, types, args, kwargs): res = NVFP4Tensor( tensor.qdata, tensor.scale, - tensor._block_size, + tensor.block_size, dtype, tensor.per_tensor_scale, tensor.act_per_tensor_scale, @@ -399,7 +399,7 @@ def nvfp4_slice(func, types, args, kwargs): result = NVFP4Tensor( sliced_data, sliced_scale, - x._block_size, + x.block_size, x._orig_dtype, x.per_tensor_scale, x.act_per_tensor_scale, @@ -418,7 +418,7 @@ def nvfp4_t(func, types, args, kwargs): new = NVFP4Tensor( old.qdata.t(), old.scale.t(), - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -440,7 +440,7 @@ def nvfp4_transpose(func, types, args, kwargs): new = NVFP4Tensor( new_qdata, new_scale, - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -460,7 +460,7 @@ def nvfp4_view_op(func, types, args, kwargs): return NVFP4Tensor( new_data, args[0].scale, - args[0]._block_size, + args[0].block_size, args[0]._orig_dtype, args[0].per_tensor_scale, args[0].act_per_tensor_scale, @@ -478,7 +478,7 @@ def nvfp4_select(func, types, args, kwargs): new = old.__class__( old.qdata[index], old.scale[index], - old._block_size, + old.block_size, old._orig_dtype, old.per_tensor_scale, old.act_per_tensor_scale, @@ -500,8 +500,8 @@ def _addmm_nvfp4_dispatch( assert a.scale.is_contiguous() assert b.qdata.t().is_contiguous() assert b.scale.t().is_contiguous() - assert a._block_size == 16, f"NVFP4 requires block_size=16, got {a._block_size}" - assert b._block_size == 16, f"NVFP4 requires block_size=16, got {b._block_size}" + assert a.block_size == 16, f"NVFP4 requires block_size=16, got {a.block_size}" + assert b.block_size == 16, f"NVFP4 requires block_size=16, got {b.block_size}" assert len(a.shape) == 2 and len(b.shape) == 2 M, K = a.shape[0], a.shape[1] @@ -511,13 +511,13 @@ def _addmm_nvfp4_dispatch( if a._is_swizzled_scales: a_scale_blocked = a.scale # Already swizzled else: - a_scale = a.scale.view(M, K // a._block_size) + a_scale = a.scale.view(M, K // a.block_size) a_scale_blocked = to_blocked(a_scale) if 
b._is_swizzled_scales: b_scale_blocked = b.scale.t() # Already swizzled else: - b_scale = b.scale.t().view(N, K // b._block_size) + b_scale = b.scale.t().view(N, K // b.block_size) b_scale_blocked = to_blocked(b_scale) # Merge double quant scales into 1 scale for Scale_In^D diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 78bfd48ab7..72d8a47b81 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -232,7 +232,7 @@ def _swizzle_aware_slice( if x._is_swizzled_scales: scale_rows = M - scale_cols = K // x._block_size + scale_cols = K // x.block_size n_row_blocks = ceil_div(scale_rows, 128) n_col_blocks = ceil_div(scale_cols, 4) elements_per_block = 32 * 16 # 512 elements @@ -351,7 +351,7 @@ def _swizzle_aware_slice( ) else: - scale_shaped = x.scale.view(M, K // x._block_size) + scale_shaped = x.scale.view(M, K // x.block_size) if dim == 0: sliced_scale = aten.slice.Tensor(scale_shaped, dim, start, end, step) @@ -359,16 +359,16 @@ def _swizzle_aware_slice( elif dim == 1: if start is not None: - assert start % x._block_size == 0, ( - f"Start index {start} must be a multiple of block_size {x._block_size}" + assert start % x.block_size == 0, ( + f"Start index {start} must be a multiple of block_size {x.block_size}" ) assert start % 2 == 0, ( f"Start index {start} must be even for FP4 packing" ) if end is not None and end != sys.maxsize: - assert end % x._block_size == 0, ( - f"End index {end} must be a multiple of block_size {x._block_size}" + assert end % x.block_size == 0, ( + f"End index {end} must be a multiple of block_size {x.block_size}" ) assert end % 2 == 0, f"End index {end} must be even for FP4 packing" @@ -382,8 +382,8 @@ def _swizzle_aware_slice( x.qdata, dim, packed_start, packed_end, step ) - start_block = 0 if start is None else start // x._block_size - end_block = None if end is None else end // x._block_size + start_block = 0 if start is None else start // x.block_size + end_block = None if end is None else end // x.block_size sliced_scale = aten.slice.Tensor( scale_shaped, 1, start_block, end_block, step ) @@ -398,12 +398,12 @@ def _swizzle_aware_slice( # multiply by 2 to convert from bytes to num_elements sliced_K = sliced_data.shape[1] * 2 if x._is_swizzled_scales: - if x._block_size == 16: + if x.block_size == 16: scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_nvfp4( sliced_M, sliced_K ) else: - assert x._block_size == 32, f"unexpected {x._block_size=}" + assert x.block_size == 32, f"unexpected {x.block_size=}" scale_M, scale_K = hp_data_dims_to_swizzled_scale_dims_mx( sliced_M, sliced_K ) @@ -413,7 +413,7 @@ def _swizzle_aware_slice( # mx: a 1x32 unpacked or 1x16 packed qdata tile corresponds to 1 # scale element scale_M = sliced_M - scale_K = sliced_K // x._block_size + scale_K = sliced_K // x.block_size sliced_scale = sliced_scale.view(scale_M, scale_K) return sliced_data, sliced_scale From 92236b71aaf87d6bb778cf37bbc671d1d9d8d042 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 09:34:50 -0800 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 33 - test/prototype/mx_formats/test_mx_tensor.py | 30 +- .../mx_formats/inference_workflow.py | 2 - torchao/prototype/mx_formats/kernels.py | 665 ------------------ torchao/prototype/mx_formats/mx_tensor.py | 85 +-- torchao/prototype/mx_formats/utils.py | 2 - 6 files changed, 7 insertions(+), 810 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py 
b/test/prototype/mx_formats/test_kernels.py index 14de3610b3..1729901933 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -34,9 +34,6 @@ f32_to_f6_e3m2_unpacked, get_bits, pack_uint4, - pack_uint6, - triton_f6_e2m3_to_bf16, - triton_f6_e3m2_to_bf16, triton_mxfp8_dequant_dim0, triton_to_mxfp8_dim0, triton_to_mxfp8_dim1, @@ -423,36 +420,6 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device): assert f6_e3m2_unpacked.item() == (f6_e3m2_enc | 0b100000) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_fp6_e2m3_pack_unpack(): - orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to( - "cuda" - ) - orig_vals_f6_unpacked = f32_to_f6_e2m3_unpacked(orig_vals) - orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) - assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) - orig_vals_f6_packed_unpacked = triton_f6_e2m3_to_bf16(orig_vals_f6_packed).to( - torch.float32 - ) - assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") -def test_fp6_e3m2_pack_unpack(): - orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to( - "cuda" - ) - orig_vals_f6_unpacked = f32_to_f6_e3m2_unpacked(orig_vals) - orig_vals_f6_packed = pack_uint6(orig_vals_f6_unpacked) - assert orig_vals_f6_packed.numel() == (3 * orig_vals.numel() // 4) - orig_vals_f6_packed_unpacked = triton_f6_e3m2_to_bf16(orig_vals_f6_packed).to( - torch.float32 - ) - assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) - - def triton_to_mxfp8_dim0_reference( x_hp: torch.Tensor, block_size ) -> tuple[torch.Tensor, torch.Tensor]: diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 1d46bff0d3..66f8998ea8 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -18,7 +18,7 @@ DTYPE_FP6_E3M2, SUPPORTED_ELEM_DTYPES, ) -from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6 +from torchao.prototype.mx_formats.kernels import pack_uint4 from torchao.prototype.mx_formats.mx_tensor import ( MXTensor, ScaleCalculationMode, @@ -343,14 +343,10 @@ def test_exponent_nan_in(elem_dtype): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) -@pytest.mark.parametrize("pack_fp6", [False, True]) -def test_exponent_nan_out(elem_dtype, pack_fp6): +def test_exponent_nan_out(elem_dtype): """ If block exponent value is NaN, the MX tensor block value is NaN """ - if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): - pytest.skip("invalid configuration") - scale_e8m0 = torch.tensor( [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda" ) @@ -365,9 +361,6 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): data_bits = torch.tensor( [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" ) # noqa: E501 - if pack_fp6: - data_bits = data_bits.reshape(-1, block_size) - data_bits = pack_uint6(data_bits) elif elem_dtype == torch.float4_e2m1fn_x2: data_bits = torch.tensor( [0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.uint8, device="cuda" @@ -383,7 +376,6 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): block_size, torch.float, 
MXGemmKernelChoice.EMULATED, - pack_fp6, None, False, ) @@ -466,21 +458,6 @@ def test_clone(): ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.parametrize("elem_dtype", [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]) -@pytest.mark.parametrize("pack_fp6", [False, True]) -def test_fp6_packing(elem_dtype, pack_fp6): - x = torch.randn(1, 2, 4, device="cuda") - block_size = 4 - x_mx = MXTensor.to_mx(x, elem_dtype, block_size, pack_fp6=pack_fp6) - if pack_fp6: - expected_packed_shape = torch.Size([*x.shape[:-1], 3 * x.shape[-1] // 4]) - else: - expected_packed_shape = x.shape - - assert x_mx.qdata.shape == expected_packed_shape - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) @@ -514,14 +491,12 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): to_dtype_c = torch.compile(to_dtype, fullgraph=True) - pack_fp6 = False x_mx_dq = to_dtype( x_mx.qdata, x_mx.scale, x_mx._elem_dtype, x_mx.block_size, hp_dtype, # noqa: E501 - pack_fp6, ) x_mx_c_dq = to_dtype_c( x_mx_c.qdata, @@ -529,7 +504,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros): x_mx_c._elem_dtype, x_mx_c.block_size, hp_dtype, - pack_fp6, ) torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0) diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 8725c33b44..2ff4eedf5f 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -110,7 +110,6 @@ def _mx_inference_linear_transform( elem_dtype=config.activation_dtype, block_size=config.block_size, gemm_kernel_choice=config.gemm_kernel_choice, - pack_fp6=False, is_swizzled_scales=True, ) @@ -120,7 +119,6 @@ def _mx_inference_linear_transform( config.weight_dtype, block_size=config.block_size, gemm_kernel_choice=config.gemm_kernel_choice, - pack_fp6=False, # TODO act_quant_kwargs=act_quant_kwargs, is_swizzled_scales=True, ) diff --git a/torchao/prototype/mx_formats/kernels.py b/torchao/prototype/mx_formats/kernels.py index 265d890847..b4cd192244 100644 --- a/torchao/prototype/mx_formats/kernels.py +++ b/torchao/prototype/mx_formats/kernels.py @@ -22,20 +22,6 @@ torch_version_at_least, ) -# TODO(future): if needed, make the below work on previous PyTorch versions, -# just need to hunt down the previous location of `libdevice`. An assert -# at the callsite prevents usage of this on unsupported versions. 
-if has_triton(): - from torch._inductor.runtime.triton_helpers import libdevice - -from torchao.prototype.mx_formats.constants import ( - E8M0_EXPONENT_BIAS, - E8M0_EXPONENT_NAN_VAL, - F6_E2M3_EXP_BIAS, - F6_E3M2_EXP_BIAS, - F32_EXP_BIAS, -) - logger = logging.getLogger(__name__) @@ -119,594 +105,6 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor): return _floatx_unpacked_to_f32(x, EBITS_F6_E3M2, MBITS_F6_E3M2) -if has_triton(): - import triton - import triton.language as tl - - @triton.jit - def _fp4_packed_to_bf16( - x_packed, - sign_mask_f4, - mantissa_mask_f4, - mbits_f4_e2m1, - ebits_f4_e2m1, - f4_e2m1_exp_bias, - mbits_f32, - ebits_f32, - f32_exp_bias, - zero_bits_f32, - zero_point_five_bits_f32, - ): - """ - Input: a tensor of packed fp4 values - Output: a tensor of bfloat16 values - """ - - # high-bits: original location 0:3 - # low-bits: original location 4:7 - x_high_bits = x_packed >> 4 - x_low_bits = x_packed & 0xF - x = tl.interleave(x_low_bits, x_high_bits) - - # cast logic below - # output = x_unpacked.to(tl.float32) - - # save the sign - sign_f4 = x & sign_mask_f4 - - # set everything to positive, will add sign back at the end - x_pos = x ^ sign_f4 - - # Special case zero - zero_mask = x_pos == 0 - - # There is only one denormal value in fp4: s001, which is 0.5 in f32 - # Special case it. - # TODO(later): will it be faster to repeat this for all 8 positive - # values instead of the bit manipulations? - denormal_mask = x_pos == 1 - - # calculate the new exponent and shift it to bits 2:9 of the result - exp_biased_f4 = x_pos >> mbits_f4_e2m1 - exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias - exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32 - - # shift the mantissa to bits 10:32 of the result - mantissa_f4 = x_pos & mantissa_mask_f4 - mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1) - output = mantissa_f32 - - # combine the pieces - result = exp_biased_f32 | mantissa_f32 - # result[zero_mask] = ZERO_BITS_F32 - result = tl.where(zero_mask, zero_bits_f32, result) - # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32 - result = tl.where(denormal_mask, zero_point_five_bits_f32, result) - - # add sign back - sign_f32 = sign_f4.to(tl.int32) << ( - mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1 - ) - result = result | sign_f32 - - # The bit shifting above is for float32, so for now we - # bitcast to float32 and then regular cast to bfloat16 - # TODO(later): it should be pretty easy to cast directly to bf16, just - # need to adjust the mbits/ebits/special values. Perf impact is likely - # to be small as we would not be chaning memory access patterns. - output = result.to(tl.float32, bitcast=True) - output = output.to(tl.bfloat16) - return output - - @triton.jit - def _fp6_packed_to_bf16( - packed_4bits_a, - packed_4bits_b, - packed_2bits, - sign_mask_f6, - mbits_f6, - f6_exp_bias, - mbits_f32, - f32_exp_bias, - ): - """ - Input: a tensor of packed fp6 values - Output: a tensor of bfloat16 values - """ - - # L/R shift and combine back into uint8 with first 2 bits empty (i.e. 
unpacked) - x_0 = ((packed_4bits_a >> 2) & 0x3C) | ((packed_2bits & 0xC0) >> 6) - x_1 = ((packed_4bits_a << 2) & 0x3C) | ((packed_2bits & 0x30) >> 4) - x_2 = ((packed_4bits_b >> 2) & 0x3C) | ((packed_2bits & 0xC) >> 2) - x_3 = ((packed_4bits_b << 2) & 0x3C) | (packed_2bits & 0x3) - - # repeat_interleave not supported yet, see https://github.com/triton-lang/triton/issues/1426 - # instead we can interleave(interleave(4*i, 4*i+2), interleave(4*i+1, 4*i+3)) - # TODO: is there a more performant way? - # We could stack all 4, then transpose and ravel and do it that way? - x_02 = tl.interleave(x_0, x_2) # [x_0_0, x_2_0, x_0_1, x_2_1, ...] - x_13 = tl.interleave(x_1, x_3) # [x_1_0, x_3_0, x_1_1, x_3_1, ...] - x = tl.interleave(x_02, x_13) # [x_0_0, x_1_0, x_2_0, x_3_0, x_0_1, ...] - - # save the sign - sign_f6 = x & sign_mask_f6 - - # set everything to positive, will add sign back at the end - x_pos = x ^ sign_f6 - - # shift the exponent and mantissa - result = x_pos.to(tl.int32) << (mbits_f32 - mbits_f6) - - # add sign back - # left shift is always 26 regardless of fp6 variant - sign_f32 = sign_f6.to(tl.int32) << 26 - result = result | sign_f32 - - # The bit shifting above is for float32, so for now we - # bitcast to float32 and then regular cast to bfloat16 - # TODO(later): it should be pretty easy to cast directly to bf16, just - # need to adjust the mbits/ebits/special values. Perf impact is likely - # to be small as we would not be changing memory access patterns. - output = result.to(tl.float32, bitcast=True) - - # Scale the fp32 exponent afterwards, handles the denorms correctly - output *= 2.0 ** (f32_exp_bias - f6_exp_bias) - - output = output.to(tl.bfloat16) - return output - - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), - ], - key=["n_mx_blocks"], - ) - @triton.jit - def triton_f6_to_bf16_kernel( - x_ptr, - output_ptr, - n_mx_blocks, - mx_block_size: tl.constexpr, - packed_mx_block_size: tl.constexpr, - sign_mask_f6: tl.constexpr, - mbits_f6: tl.constexpr, - f6_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE_IN - - offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) - offsets_cols = tl.arange(0, packed_mx_block_size // 3) - mask_in = (offsets_rows[:, None] < n_mx_blocks) & ( - offsets_cols[None, :] < packed_mx_block_size // 3 - ) - offsets_in = ( - offsets_rows[:, None] * packed_mx_block_size + offsets_cols[None, :] - ) - - # packed 4 x fp6 into 3 x uint8 - packed_4bits_a = tl.load(x_ptr + offsets_in, mask=mask_in, other=0) - packed_4bits_b = tl.load( - x_ptr + offsets_in + (packed_mx_block_size // 3), mask=mask_in, other=0 - ) - packed_2bits = tl.load( - x_ptr + offsets_in + (2 * packed_mx_block_size // 3), mask=mask_in, other=0 - ) - - output = _fp6_packed_to_bf16( - packed_4bits_a, - packed_4bits_b, - packed_2bits, - sign_mask_f6, - mbits_f6, - f6_exp_bias, - mbits_f32, - f32_exp_bias, - ) - - # set up output offsets - offsets_rows_out = block_start + tl.arange(0, BLOCK_SIZE_IN) - offsets_cols_out = tl.arange(0, mx_block_size) - offsets_out = ( - offsets_rows_out[:, None] * mx_block_size + offsets_cols_out[None, :] - ) - mask_out = (offsets_rows_out[:, None] < n_mx_blocks) & ( - offsets_cols_out[None, :] < mx_block_size - ) - - tl.store(output_ptr + 
offsets_out, output, mask=mask_out) - - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), - ], - key=["n_mx_blocks"], - ) - @triton.jit - def triton_f6_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_mx_blocks, - mx_block_size: tl.constexpr, - packed_mx_block_size: tl.constexpr, - sign_mask_f6: tl.constexpr, - mbits_f6: tl.constexpr, - f6_exp_bias: tl.constexpr, - mbits_f32: tl.constexpr, - f32_exp_bias: tl.constexpr, - e8m0_exponent_bias: tl.constexpr, - e8m0_exponent_nan_val: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - - block_start = pid * BLOCK_SIZE_IN - - offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) - offsets_cols = tl.arange(0, packed_mx_block_size // 3) - mask_in = (offsets_rows[:, None] < n_mx_blocks) & ( - offsets_cols[None, :] < packed_mx_block_size // 3 - ) - offsets_in = ( - offsets_rows[:, None] * packed_mx_block_size + offsets_cols[None, :] - ) - - # packed 4 x fp6 into 3 x uint8 - packed_4bits_a = tl.load(x_ptr + offsets_in, mask=mask_in, other=0) - packed_4bits_b = tl.load( - x_ptr + offsets_in + (packed_mx_block_size // 3), mask=mask_in, other=0 - ) - packed_2bits = tl.load( - x_ptr + offsets_in + (2 * packed_mx_block_size // 3), mask=mask_in, other=0 - ) - - output = _fp6_packed_to_bf16( - packed_4bits_a, - packed_4bits_b, - packed_2bits, - sign_mask_f6, - mbits_f6, - f6_exp_bias, - mbits_f32, - f32_exp_bias, - ) - - # load scale - offsets_s = block_start + tl.arange(0, BLOCK_SIZE_IN) - mask_s = offsets_s < n_mx_blocks - s = tl.load(s_ptr + offsets_s, mask=mask_s) - - # create the scale in bf16 - s_offset = s.to(tl.float32) - e8m0_exponent_bias - s_fp = libdevice.pow(2.0, s_offset).to(tl.bfloat16) - s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float("nan")) - - # multiply output by scale - # TODO(later): see if manipulating the exponent instead of fp - # multiplication is going to give a significant speedup - output = tl.reshape(output, (BLOCK_SIZE_IN, mx_block_size)) # noqa: E501 - s_fp = tl.reshape(s_fp, (BLOCK_SIZE_IN // 1, 1)) - output = output * s_fp - output = tl.reshape(output, (BLOCK_SIZE_IN, mx_block_size)) - - # set up output offsets - offsets_rows_out = block_start + tl.arange(0, BLOCK_SIZE_IN) - offsets_cols_out = tl.arange(0, mx_block_size) - offsets_out = ( - offsets_rows_out[:, None] * mx_block_size + offsets_cols_out[None, :] - ) - mask_out = (offsets_rows_out[:, None] < n_mx_blocks) & ( - offsets_cols_out[None, :] < mx_block_size - ) - - tl.store(output_ptr + offsets_out, output, mask=mask_out) - - @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_IN": 2}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 4}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 8}, num_warps=1), - triton.Config({"BLOCK_SIZE_IN": 16}, num_warps=1), - ], - key=["n_mx_blocks"], - ) - @triton.jit - def triton_pack_uint6_kernel( - input_ptr, - output_ptr, - n_mx_blocks, - MX_BLOCK_SIZE: tl.constexpr, - PACKED_MX_BLOCK_SIZE: tl.constexpr, - BLOCK_SIZE_IN: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE_IN - - # input_ptr is shape [n_mx_blocks, MX_BLOCK_SIZE] - # Load BLOCK_SIZE rows of input_ptr - offsets_rows = block_start + tl.arange(0, BLOCK_SIZE_IN) - offsets_cols = tl.arange(0, MX_BLOCK_SIZE // 4) - offsets = offsets_rows[:, None] * MX_BLOCK_SIZE + (4 * offsets_cols[None, :]) - mask = 
(offsets_rows[:, None] < n_mx_blocks) & ( - offsets_cols[None, :] < MX_BLOCK_SIZE // 4 - ) - - # x is shape [BLOCK_SIZE, MX_BLOCK_SIZE] - x_0 = tl.load(input_ptr + offsets, mask=mask) - x_1 = tl.load(input_ptr + offsets + 1, mask=mask) - x_2 = tl.load(input_ptr + offsets + 2, mask=mask) - x_3 = tl.load(input_ptr + offsets + 3, mask=mask) - - # OR between remainder 0/1, 2/3 elements to pack 2 x first-4-bit partial representations - # next to each other. These are the middle 4 bits of the uint8, so some gymnastics required. - # i.e. (00abcd00 >> 2) | (00wxyz00 << 2) = 0000abcd | wxyz0000 = wxyzabcd - bits_packed_4_a = (x_1 >> 2) | ((x_0 << 2) & 0xF0) - bits_packed_4_b = (x_3 >> 2) | ((x_2 << 2) & 0xF0) - # Similarly pack 4 remaining 2-bit partial representations into one uint8 - # e.g. 000000ab, 0000cd00, 00ef0000, gh000000 --> abcdefgh - bits_packed_2 = ( - (x_0 << 6) | ((x_1 << 4) & 0x30) | ((x_2 << 2) & 0xC) | (x_3 & 0x3) - ) - - # Store values in a uint8 tensor of length `3 * MX_BLOCK_SIZE / 4` - offsets_out_4_a = ( - offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE + offsets_cols[None, :] - ) - offsets_out_4_b = ( - offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE - + offsets_cols[None, :] - + (MX_BLOCK_SIZE // 4) - ) - offsets_out_2 = ( - offsets_rows[:, None] * PACKED_MX_BLOCK_SIZE - + offsets_cols[None, :] - + (MX_BLOCK_SIZE // 2) - ) - - # Store into output tensor - tl.store( - output_ptr + offsets_out_4_a, - bits_packed_4_a, - mask=mask, - ) - - tl.store( - output_ptr + offsets_out_4_b, - bits_packed_4_b, - mask=mask, - ) - - tl.store( - output_ptr + offsets_out_2, - bits_packed_2, - mask=mask, - ) - -else: - - def triton_f6_to_bf16_kernel( - x_ptr, - output_ptr, - n_elements_in, - sign_mask_f6, - mbits_f6, - f6_exp_bias, - mbits_f32, - f32_exp_bias, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - - def triton_f6_to_scaled_bf16_kernel( - x_ptr, - s_ptr, - output_ptr, - n_elements_in, - mx_block_size, - sign_mask_f6, - mbits_f6, - f6_exp_bias, - mbits_f32, - f32_exp_bias, - e8m0_exponent_bias, - e8m0_exponent_nan_val, - BLOCK_SIZE_IN, - ): - raise AssertionError("unsupported without triton") - - def triton_pack_uint6_kernel( - input_ptr, - output_ptr, - n_mx_blocks, - MX_BLOCK_SIZE, - PACKED_MX_BLOCK_SIZE, - BLOCK_SIZE, - ): - raise AssertionError("unsupported without triton") - - -def triton_f6_e2m3_to_bf16(x: torch.Tensor) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values - Output: a tensor of bfloat16 values - - Note: this function is only used in testing, so we can test - the numerical correctness of the cast without the scaling. 
- """ - packed_mx_block_size = x.shape[-1] - mx_block_size = 4 * packed_mx_block_size // 3 - - x = x.view(-1, packed_mx_block_size) - new_shape = (x.shape[0], mx_block_size) - - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_bf16_kernel[grid]( - x, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E2M3, - mbits_f6=MBITS_F6_E2M3, - f6_exp_bias=F6_E2M3_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - ) - return output - - -def triton_f6_e3m2_to_bf16(x: torch.Tensor) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values - Output: a tensor of bfloat16 values - - Note: this function is only used in testing, so we can test - the numerical correctness of the cast without the scaling. - """ - packed_mx_block_size = x.shape[-1] - mx_block_size = 4 * packed_mx_block_size // 3 - - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_bf16_kernel[grid]( - x, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E3M2, - mbits_f6=MBITS_F6_E3M2, - f6_exp_bias=F6_E3M2_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - ) - return output - - -@torch.library.custom_op("ao::triton_f6_e2m3_to_scaled_bf16", mutates_args=()) -def triton_f6_e2m3_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, -) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. - Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - - packed_mx_block_size = 3 * mx_block_size // 4 - - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - - n_mx_blocks = x.shape[0] - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E2M3, - mbits_f6=MBITS_F6_E2M3, - f6_exp_bias=F6_E2M3_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output - - -@torch.library.custom_op("ao::triton_f6_e3m2_to_scaled_bf16", mutates_args=()) -def triton_f6_e3m2_to_scaled_bf16( - x: torch.Tensor, - s_e8m0: torch.Tensor, - mx_block_size: int, -) -> torch.Tensor: - """ - Input: a tensor of packed fp6 values, and a scale in e8m0 format. The block - size is currently assumed to be 32. 
- Output: a tensor of bfloat16 values, multiplied by the encoded scale - """ - s_e8m0 = s_e8m0.view(torch.uint8) - - packed_mx_block_size = 3 * mx_block_size // 4 - - x = x.view(-1, packed_mx_block_size) - new_shape = (x.numel() // packed_mx_block_size, mx_block_size) - - output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) - - assert x.is_contiguous() - assert x.is_cuda and output.is_cuda - - n_mx_blocks = x.numel() // packed_mx_block_size - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - triton_f6_to_scaled_bf16_kernel[grid]( - x, - s_e8m0, - output, - n_mx_blocks, - mx_block_size, - packed_mx_block_size, - sign_mask_f6=SIGN_MASK_F6_E3M2, - mbits_f6=MBITS_F6_E3M2, - f6_exp_bias=F6_E3M2_EXP_BIAS, - mbits_f32=MBITS_F32, - f32_exp_bias=F32_EXP_BIAS, - e8m0_exponent_bias=E8M0_EXPONENT_BIAS, - e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL, - ) - return output - - -@triton_f6_e3m2_to_scaled_bf16.register_fake -def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) - - -@triton_f6_e2m3_to_scaled_bf16.register_fake -def _(x, s_e8m0, mx_block_size): - _padded_mx_block_size = 3 * mx_block_size // 4 - out_shape = (x.numel() // _padded_mx_block_size, mx_block_size) - return torch.empty(*out_shape, device=x.device, dtype=torch.bfloat16) - - # pack/unpack code copy-pasted from # https://github.com/pytorch-labs/ao/blob/main/torchao/dtypes/uint4.py @@ -761,69 +159,6 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor: return (uint8_data[::2] | uint8_data[1::2] << 4).view(down_size(shape)) -# PyTorch implementation of fp6 packing for reference purposes -def pack_uint6_pytorch(uint8_data: torch.Tensor) -> torch.Tensor: - # check shape is divisible by 4 along packing axis - shape = uint8_data.shape - assert shape[-1] % 4 == 0 - - packed_shape = [*shape[:-1], 3 * shape[-1] // 4] - - uint8_data = uint8_data.contiguous().view(-1) - - # pack 4 bits of each of 4 numbers into 2xuint8, remaining 2 bits into 1xuint8 - bits_packed_4_a = (uint8_data[1::4] >> 2) | ((uint8_data[::4] << 2) & 0xF0) - bits_packed_4_b = (uint8_data[2::4] >> 2) | ((uint8_data[3::4] << 2) & 0xF0) - bits_packed_2 = ( - (uint8_data[::4] << 6) - | ((uint8_data[1::4] << 4) & 0x30) - | ((uint8_data[3::4] << 2) & 0xC) - | (uint8_data[2::4] & 0x3) - ) - - return ( - torch.stack((bits_packed_4_a, bits_packed_4_b, bits_packed_2), dim=-1) - ).view(packed_shape) - - -@torch.library.custom_op("ao::pack_uint6", mutates_args=()) -def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: - # ensure input data is contiguous before passing to kernel - assert uint8_data.is_contiguous() - - # tensor should already be of shape [..., mx_block_size] - mx_block_size = uint8_data.shape[-1] - assert mx_block_size % 4 == 0 - - # effective mx block size since we're packing 2 fp4 into 1 uint8 - packed_mx_block_size = 3 * mx_block_size // 4 - packed_shape = [*uint8_data.shape[:-1], packed_mx_block_size] - n_mx_blocks = uint8_data.numel() // mx_block_size - - grid = lambda meta: (triton.cdiv(n_mx_blocks, meta["BLOCK_SIZE_IN"]),) - - # contiguous uint8 container in which we can store the unpacked tensor - packed_uint8_data = torch.empty( - packed_shape, dtype=torch.uint8, device=uint8_data.device - ) - - triton_pack_uint6_kernel[grid]( - uint8_data, - packed_uint8_data, - n_mx_blocks, - MX_BLOCK_SIZE=mx_block_size, - PACKED_MX_BLOCK_SIZE=packed_mx_block_size, 
- ) - - return packed_uint8_data - - -@pack_uint6.register_fake -def _(uint8_data): - out_shape = (*uint8_data.shape[:-1], 3 * uint8_data.shape[-1] // 4) - return torch.empty(*out_shape, device=uint8_data.device, dtype=torch.uint8) - - if torch_version_at_least("2.7.0") and has_triton(): import triton import triton.language as tl diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 9119ae4b24..7a1b5a160b 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -58,9 +58,6 @@ f32_to_f6_e2m3_unpacked, f32_to_f6_e3m2_unpacked, pack_uint4, - pack_uint6, - triton_f6_e2m3_to_scaled_bf16, - triton_f6_e3m2_to_scaled_bf16, unpack_uint4, ) from torchao.prototype.mx_formats.utils import ( @@ -91,7 +88,6 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): block_size: int = 32 scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED - pack_fp6: bool = False is_swizzled_scales: bool = False @@ -148,7 +144,6 @@ def to_mx( elem_dtype: Union[torch.dtype, str], block_size: int, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, - pack_fp6: bool = False, is_swizzled_scales: bool = False, ): """ @@ -311,16 +306,10 @@ def to_mx( data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E2M3: data_lp = f32_to_f6_e2m3_unpacked(data_lp) - if pack_fp6: - orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] - data_lp = pack_uint6(data_lp) # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E3M2: data_lp = f32_to_f6_e3m2_unpacked(data_lp) - if pack_fp6: - orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] - data_lp = pack_uint6(data_lp) # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == torch.float4_e2m1fn_x2: @@ -370,7 +359,6 @@ def to_dtype( elem_dtype, block_size, target_dtype, - pack_fp6: bool = False, ): orig_shape = data_lp.shape is_transposed = not data_lp.is_contiguous() @@ -385,33 +373,11 @@ def to_dtype( if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): data_hp = data_lp.to(target_dtype) elif elem_dtype == DTYPE_FP6_E2M3: - if pack_fp6: - orig_shape = (*orig_shape[:-1], 4 * orig_shape[-1] // 3) - data_hp_rescaled = triton_f6_e2m3_to_scaled_bf16( - data_lp, - scale_e8m0, - block_size, - ).reshape(orig_shape) - if is_transposed: - data_hp_rescaled = data_hp_rescaled.t() - return data_hp_rescaled.to(target_dtype) - else: - data_hp = f6_e2m3_unpacked_to_f32(data_lp) - data_hp = data_hp.to(target_dtype).reshape(orig_shape) + data_hp = f6_e2m3_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E3M2: - if pack_fp6: - orig_shape = (*orig_shape[:-1], 4 * orig_shape[-1] // 3) - data_hp_rescaled = triton_f6_e3m2_to_scaled_bf16( - data_lp, - scale_e8m0, - block_size, - ).reshape(orig_shape) - if is_transposed: - data_hp_rescaled = data_hp_rescaled.t() - return data_hp_rescaled.to(target_dtype) - else: - data_hp = f6_e3m2_unpacked_to_f32(data_lp) - data_hp = data_hp.to(target_dtype).reshape(orig_shape) + data_hp = f6_e3m2_unpacked_to_f32(data_lp) + data_hp = data_hp.to(target_dtype).reshape(orig_shape) elif elem_dtype == torch.float4_e2m1fn_x2: # fp4 f4_unpacked = unpack_uint4(data_lp) @@ -466,26 +432,6 @@ def tensor_size_fp4x2_to_hp(orig_size, is_contiguous): return new_size -# TODO(future PR): fix this 
function for rank 3 and add tests -def tensor_size_hpx3_to_fp6x4(orig_size, is_contiguous): - new_size = orig_size - if is_contiguous: - new_size = [*list(new_size[:-1]), 3 * new_size[-1] // 4] - else: - new_size = [3 * new_size[0] // 4, *list(new_size[1:])] - return new_size - - -# TODO(future PR): fix this function for rank 3 and add tests -def tensor_size_fp6x4_to_hpx3(orig_size, is_contiguous): - new_size = orig_size - if is_contiguous: - new_size = [*list(new_size[:-1]), 4 * new_size[-1] // 3] - else: - new_size = [4 * new_size[0] // 3, *list(new_size[1:])] - return new_size - - class MXTensor(TorchAOBaseTensor): tensor_data_names = ["qdata", "scale"] tensor_attribute_names = [ @@ -493,7 +439,6 @@ class MXTensor(TorchAOBaseTensor): "block_size", "_orig_dtype", "_gemm_kernel_choice", - "_pack_fp6", "act_quant_kwargs", "_is_swizzled_scales", ] @@ -506,7 +451,6 @@ def __new__( block_size, orig_dtype, gemm_kernel_choice, - pack_fp6, act_quant_kwargs, is_swizzled_scales, ): @@ -521,12 +465,6 @@ def __new__( new_size, qdata.is_contiguous(), ) - elif pack_fp6 and elem_dtype in [DTYPE_FP6_E2M3, DTYPE_FP6_E3M2]: - # set the tensor size to what it would be without fp6 packing - new_size = tensor_size_fp6x4_to_hpx3( - new_size, - qdata.is_contiguous(), - ) self = torch.Tensor._make_wrapper_subclass( cls, new_size, @@ -550,7 +488,6 @@ def __new__( self.block_size = block_size self._orig_dtype = orig_dtype self._gemm_kernel_choice = gemm_kernel_choice - self._pack_fp6 = pack_fp6 self.act_quant_kwargs = act_quant_kwargs self._is_swizzled_scales = is_swizzled_scales return self @@ -587,7 +524,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor self._elem_dtype, self.block_size, output_dtype, - self._pack_fp6, ) @staticmethod @@ -599,12 +535,11 @@ def to_mx( scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, # TODO(future PR): switch default gemm to cublas gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, - pack_fp6: bool = False, act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, is_swizzled_scales: bool = False, ): scale_e8m0_biased, data_lp = to_mx( - data_hp, elem_dtype, block_size, scaling_mode, pack_fp6, is_swizzled_scales + data_hp, elem_dtype, block_size, scaling_mode, is_swizzled_scales ) if isinstance(scale_e8m0_biased, DTensor): assert isinstance(data_lp, DTensor), "unsupported" @@ -617,7 +552,6 @@ def to_mx( block_size, data_hp.dtype, gemm_kernel_choice, - pack_fp6, act_quant_kwargs, is_swizzled_scales, ) @@ -636,7 +570,6 @@ def to_mx( block_size, data_hp.dtype, gemm_kernel_choice, - pack_fp6, act_quant_kwargs, is_swizzled_scales, ) @@ -688,7 +621,6 @@ def _addmm_mx_dispatch( k.block_size, k.scaling_mode, k.gemm_kernel_choice, - k.pack_fp6, k.is_swizzled_scales, ) @@ -807,7 +739,6 @@ def mx_t(func, types, args, kwargs): old.block_size, old._orig_dtype, old._gemm_kernel_choice, - old._pack_fp6, old.act_quant_kwargs, old._is_swizzled_scales, ) @@ -841,9 +772,6 @@ def mx_view_op(func, types, args, kwargs): if args[0]._elem_dtype == torch.float4_e2m1fn_x2: # special case fp4 as we pack two elements per byte new_size = tensor_size_hp_to_fp4x2(new_size, data.is_contiguous()) - elif args[0]._elem_dtype in [DTYPE_FP6_E3M2, DTYPE_FP6_E2M3] and args[0]._pack_fp6: - # special case fp6 as we pack 4 elements in 3 bytes - new_size = tensor_size_hpx3_to_fp6x4(new_size, data.is_contiguous()) new_data = func(data, new_size, *args[2:], **kwargs) return MXTensor( new_data, @@ -852,7 +780,6 @@ def mx_view_op(func, types, args, 
kwargs): args[0].block_size, args[0]._orig_dtype, args[0]._gemm_kernel_choice, - args[0]._pack_fp6, args[0].act_quant_kwargs, args[0]._is_swizzled_scales, ) @@ -878,7 +805,6 @@ def mx_slice(func, types, args, kwargs): x.block_size, x._orig_dtype, x._gemm_kernel_choice, - x._pack_fp6, x.act_quant_kwargs, x._is_swizzled_scales, ), @@ -913,7 +839,6 @@ def mx_select(func, types, args, kwargs): old_mx_tensor.block_size, old_mx_tensor._orig_dtype, old_mx_tensor._gemm_kernel_choice, - old_mx_tensor._pack_fp6, old_mx_tensor.act_quant_kwargs, old_mx_tensor._is_swizzled_scales, ) diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 72d8a47b81..9b6e878f7d 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -188,7 +188,6 @@ def _to_mxfp8_dim1_kernel_wrapper( block_size, hp_dtype, gemm_kernel_choice, - False, None, is_swizzled_scales, ) @@ -208,7 +207,6 @@ def _to_mxfp8_dim1_kernel_wrapper( block_size, hp_dtype, gemm_kernel_choice, - False, None, is_swizzled_scales, ) From d4f3b174bcdacaf4e6ae0b29852670b1a762cdac Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 09:34:54 -0800 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_kernels.py | 18 +++++++++--------- test/prototype/mx_formats/test_mx_linear.py | 7 ++++--- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/test/prototype/mx_formats/test_kernels.py b/test/prototype/mx_formats/test_kernels.py index 1729901933..4b6586b385 100644 --- a/test/prototype/mx_formats/test_kernels.py +++ b/test/prototype/mx_formats/test_kernels.py @@ -442,8 +442,8 @@ def triton_to_mxfp8_dim0_reference( not is_sm_at_least_89(), reason="float8 in triton requires CUDA capability 8.9 or greater", ) -@pytest.mark.parametrize("M", (256, 2048)) -@pytest.mark.parametrize("K", (256, 2048)) +@pytest.mark.parametrize("M", (128, 256)) +@pytest.mark.parametrize("K", (128, 256)) def test_triton_mxfp8_dim1_randn(M, K): x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32) @@ -457,8 +457,8 @@ def test_triton_mxfp8_dim1_randn(M, K): not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) -@pytest.mark.parametrize("M", (256, 2048, 131072)) -@pytest.mark.parametrize("K", (256, 5120, 7168)) +@pytest.mark.parametrize("M", (128, 256)) +@pytest.mark.parametrize("K", (128, 256)) def test_triton_mxfp8_dim0_randn(M, K): x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32) @@ -473,7 +473,7 @@ def test_triton_mxfp8_dim0_randn(M, K): reason="mxfp8 requires CUDA capability 10.0 or greater", ) def test_triton_mxfp8_dim0_zeros(): - x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device="cuda") + x = torch.zeros(128, 256, dtype=torch.bfloat16, device="cuda") x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32) x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32) assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs" @@ -486,8 +486,8 @@ def test_triton_mxfp8_dim0_zeros(): not is_sm_at_least_100(), reason="mxfp8 requires CUDA capability 10.0 or greater", ) -@pytest.mark.parametrize("M", (256, 2048, 131072)) -@pytest.mark.parametrize("K", (256, 5120, 7168)) +@pytest.mark.parametrize("M", (128, 256)) +@pytest.mark.parametrize("K", (128, 256)) @pytest.mark.parametrize("orig_dtype", (torch.float32, torch.bfloat16)) def 
test_triton_mxfp8_dequant_dim0(M, K, orig_dtype): x = torch.zeros(M, K, dtype=orig_dtype, device="cuda") @@ -529,8 +529,8 @@ def test_rearrange(shape): not is_sm_at_least_100(), reason="MXFP8 requires CUDA capability 10.0 or greater", ) -@pytest.mark.parametrize("M", (32, 64, 2048)) -@pytest.mark.parametrize("K", (32, 64, 2048)) +@pytest.mark.parametrize("M", (32, 256)) +@pytest.mark.parametrize("K", (32, 256)) @pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16)) @pytest.mark.parametrize( "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index c858657af6..49343c6608 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -238,9 +238,11 @@ def test_activation_checkpointing(): "recipe_name", [ "mxfp8_emulated", - "mxfp4_emulated", "mxfp8_cublas", - "mxfp4_cutlass", + # TODO(future PR): add mxfp4 back here, but ensure CI speed is not too + # slow + # "mxfp4_emulated", + # "mxfp4_cutlass", ], ) @pytest.mark.parametrize("bias", [False, True]) @@ -258,7 +260,6 @@ def test_activation_checkpointing(): "scale_calculation_mode", [ ScaleCalculationMode.FLOOR, - ScaleCalculationMode.CEIL, # even + compile does not work yet: # https://gist.github.com/vkuzo/1a04845cd503b1c75291aa1ea3bf79c4 # ScaleCalculationMode.EVEN, From 4d30fe98701da1c91597b613bcb64c2f337181ce Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Nov 2025 11:48:20 -0800 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- .../float8/float8_inference_roofline.py | 7 +-- .../mx_formats/test_inference_workflow.py | 16 ++--- .../mx_formats/test_mx_serialization.py | 6 +- test/prototype/mx_formats/test_mx_tensor.py | 4 +- .../moe_training/scaled_grouped_mm.py | 8 +-- torchao/prototype/mx_formats/README.md | 22 +++---- torchao/prototype/mx_formats/__init__.py | 2 - torchao/prototype/mx_formats/config.py | 62 +++++++------------ .../mx_formats/inference_workflow.py | 17 +++-- torchao/prototype/mx_formats/mx_linear.py | 28 ++++----- torchao/prototype/mx_formats/mx_tensor.py | 43 ++++++------- .../quantize_/common/kernel_preference.py | 8 +++ 12 files changed, 97 insertions(+), 126 deletions(-) diff --git a/benchmarks/float8/float8_inference_roofline.py b/benchmarks/float8/float8_inference_roofline.py index 675c7f166f..dc732dc77a 100644 --- a/benchmarks/float8/float8_inference_roofline.py +++ b/benchmarks/float8/float8_inference_roofline.py @@ -38,9 +38,6 @@ ) import torchao -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, @@ -439,13 +436,13 @@ def run( config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "mxfp4_cutlass": config = MXFPInferenceConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) elif recipe_name == "nvfp4": config = NVFP4InferenceConfig( diff --git a/test/prototype/mx_formats/test_inference_workflow.py b/test/prototype/mx_formats/test_inference_workflow.py index 2f6e411ff7..8dad950c4c 100644 --- a/test/prototype/mx_formats/test_inference_workflow.py +++ b/test/prototype/mx_formats/test_inference_workflow.py @@ -12,15 
+12,13 @@ import torch.nn as nn from torch.profiler import ProfilerActivity, profile -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) from torchao.quantization import quantize_ +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.utils import compute_error from torchao.testing.utils import TorchAOIntegrationTestCase, skip_if_rocm from torchao.utils import ( @@ -105,15 +103,13 @@ def test_inference_workflow_mx( m_mx = copy.deepcopy(m) if emulate: - kernel_choice = MXGemmKernelChoice.EMULATED - elif elem_dtype == torch.float4_e2m1fn_x2: - kernel_choice = MXGemmKernelChoice.CUTLASS + kernel_choice = KernelPreference.EMULATED else: - kernel_choice = MXGemmKernelChoice.CUBLAS + kernel_choice = KernelPreference.AUTO config = MXFPInferenceConfig( activation_dtype=elem_dtype, weight_dtype=elem_dtype, - gemm_kernel_choice=kernel_choice, + kernel_preference=kernel_choice, ) quantize_(m_mx, config=config) if compile: @@ -254,7 +250,7 @@ def test_slice_and_copy_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) self._test_slice_and_copy_similar_to_vllm(config) @@ -267,7 +263,7 @@ def test_narrow_similar_to_vllm(self): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) self._test_narrow_similar_to_vllm(config) diff --git a/test/prototype/mx_formats/test_mx_serialization.py b/test/prototype/mx_formats/test_mx_serialization.py index d04d23f46c..930dc1dfaa 100644 --- a/test/prototype/mx_formats/test_mx_serialization.py +++ b/test/prototype/mx_formats/test_mx_serialization.py @@ -12,15 +12,13 @@ import torch import torch.nn as nn -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) from torchao.quantization import quantize_ +from torchao.quantization.quantize_.common import KernelPreference from torchao.utils import ( is_sm_at_least_100, torch_version_at_least, @@ -46,7 +44,7 @@ def test_serialization(recipe_name): config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.EMULATED, + kernel_preference=KernelPreference.EMULATED, ) else: assert recipe_name == "nvfp4", "unsupported" diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 66f8998ea8..2b8c72ff91 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -12,7 +12,6 @@ from torch._inductor.utils import run_and_get_code from torch.testing import FileCheck -from torchao.prototype.mx_formats.config import MXGemmKernelChoice from torchao.prototype.mx_formats.constants import ( DTYPE_FP6_E2M3, DTYPE_FP6_E3M2, @@ -25,6 +24,7 @@ to_dtype, ) from torchao.prototype.mx_formats.utils import from_blocked, to_blocked +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.utils import compute_error from torchao.utils import ( is_sm_at_least_89, @@ -375,7 +375,7 @@ def 
test_exponent_nan_out(elem_dtype): elem_dtype, block_size, torch.float, - MXGemmKernelChoice.EMULATED, + KernelPreference.EMULATED, None, False, ) diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py index c7705fec18..3a4ad43b4f 100644 --- a/torchao/prototype/moe_training/scaled_grouped_mm.py +++ b/torchao/prototype/moe_training/scaled_grouped_mm.py @@ -27,12 +27,12 @@ ) from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, - MXGemmKernelChoice, ScaleCalculationMode, ) from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0 from torchao.prototype.mx_formats.mx_tensor import to_mx from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper +from torchao.quantization.quantize_.common import KernelPreference logger: logging.Logger = logging.getLogger(__name__) @@ -412,7 +412,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=grad_out.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -428,7 +428,7 @@ def backward(ctx, grad_out: torch.Tensor): block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=A.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scale_calculation_mode, ) @@ -475,7 +475,7 @@ def _to_mxfp8_dim1_3d( block_size, elem_dtype=torch.float8_e4m3fn, hp_dtype=B_reshaped.dtype, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, # Not used + kernel_preference=KernelPreference.AUTO, # Not used cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=scaling_mode, ) diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 8922be949b..6c36c2eaed 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -74,13 +74,13 @@ Below is a toy training loop. 
For an example real training loop, see our torchti import torch from torchao.quantization import quantize_ import torchao.prototype.mx_formats -from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice, ScaleCalculationMode +from torchao.prototype.mx_formats import MXLinearConfig, ScaleCalculationMode +from torchao.quantization.quantize_.common import KernelPreference -# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels -gemm_kernel_choice = MXGemmKernelChoice.CUBLAS -# gemm_kernel_choice = MXGemmKernelChoice.CUTLASS -# on older NVIDIA gpus, you can run training with emulated MX gemm -# gemm_kernel_choice = MXGemmKernelChoice.EMULATED +# low precision gemm, requires CUDA capability 10.0+ +kernel_preference = KernelPreference.AUTO +# or, emulated gemm +# kernel_preference = KernelPreference.EMULATED scale_calculation_mode = ScaleCalculationMode.FLOOR # other supported modes: RCEIL, CEIL, EVEN @@ -89,7 +89,7 @@ m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() config = MXLinearConfig( elem_dtype=torch.float8_e4m3fn, block_size=32, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scale_calculation_mode=scale_calculation_mode, ) quantize_(m, config) @@ -107,14 +107,12 @@ import torch import torch.nn as nn from torchao.quantization import quantize_ import torchao.prototype.mx_formats -from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.inference_workflow import ( MXFPInferenceConfig, NVFP4InferenceConfig, NVFP4MMConfig, ) +from torchao.quantization.quantize_.common import KernelPreference m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda") x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16) @@ -125,7 +123,7 @@ m_mxfp8 = copy.deepcopy(m) config = MXFPInferenceConfig( activation_dtype=torch.float8_e4m3fn, weight_dtype=torch.float8_e4m3fn, - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, ) quantize_(m_mxfp8, config=config) m_mxfp8 = torch.compile(m_mxfp8, fullgraph=True) @@ -137,7 +135,7 @@ m_mxfp4 = copy.deepcopy(m) config = MXFPInferenceConfig( activation_dtype=torch.float4_e2m1fn_x2, weight_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) quantize_(m_mxfp4, config=config) m_mxfp4 = torch.compile(m_mxfp4, fullgraph=True) diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index c7a4c47f9d..8d1455d6f3 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -1,5 +1,4 @@ from torchao.prototype.mx_formats.config import ( - MXGemmKernelChoice, MXLinearConfig, MXLinearRecipeName, ) @@ -16,7 +15,6 @@ import torchao.prototype.mx_formats.mx_linear # noqa: F401 __all__ = [ - "MXGemmKernelChoice", "MXLinearConfig", "MXLinearRecipeName", "MXFPInferenceConfig", diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 388af07874..d57b91b85f 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -15,20 +15,7 @@ DTYPE_TO_SHORT_STR, SUPPORTED_ELEM_DTYPES, ) - - -class MXGemmKernelChoice(Enum): - # always available - MX operands are dequantized and a high precision - # gemm is run - EMULATED = "emulated" - - # available only when CUDA capability is greater than or equal to 10.0 - CUTLASS = "cutlass" - - # available only when CUDA capability is greater than or 
equal to 10.0 - # available on recent versions of PyTorch nightly, with https://github.com/pytorch/pytorch/pull/147548 - # note: torch.compile does not work yet, see https://github.com/pytorch/pytorch/issues/147873 - CUBLAS = "cublas" +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference class MXFP8Dim1CastKernelChoice(Enum): @@ -85,22 +72,17 @@ def _validate_elem_dtype(elem_dtype): ) -def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype): - if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS: - assert block_size == 32, ( - f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}" - ) - valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] - assert elem_dtype in valid_dtypes, ( - f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" - ) - elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS: - assert block_size in [16, 32], ( - f"block_size must be in [16, 32] to use the cuBLAS MX gemm kernels, got {block_size}" - ) - valid_dtypes = [torch.float8_e4m3fn, torch.float4_e2m1fn_x2] - assert elem_dtype in valid_dtypes, ( - f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}" +def _validate_kernel_preference(kernel_preference, block_size, elem_dtype): + if kernel_preference == KernelPreference.AUTO: + if elem_dtype in (torch.float8_e4m3fn, torch.float4_e2m1fn_x2): + assert block_size == 32, f"block_size must be 32, got {block_size}" + else: + raise AssertionError( + f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}" + ) + else: + assert kernel_preference == KernelPreference.EMULATED, ( + f"unsupported {kernel_preference=}, {block_size=}, {elem_dtype=}" ) @@ -135,9 +117,9 @@ class MXLinearConfig(AOBaseConfig): elem_dtype_weight_override: Optional[Any] = None elem_dtype_grad_output_override: Optional[Any] = None - # defines the gemm kernel choice, if the chosen kernel is not supported + # defines the kernel preference, if the chosen kernel is not supported # on the given hardware an exception will be thrown - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + kernel_preference: KernelPreference = KernelPreference.EMULATED # define which kernel to use for mxfp8 casting # TODO(1945): remove this config option once torch.compile gives us @@ -150,15 +132,15 @@ class MXLinearConfig(AOBaseConfig): def __post_init__(self): _validate_elem_dtype(self.elem_dtype) - _validate_gemm_kernel_choice( - self.gemm_kernel_choice, self.block_size, self.elem_dtype + _validate_kernel_preference( + self.kernel_preference, self.block_size, self.elem_dtype ) if self.elem_dtype_weight_override is not None: _validate_elem_dtype(self.elem_dtype_weight_override) - assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported" + assert self.kernel_preference == KernelPreference.EMULATED, "unsupported" if self.elem_dtype_grad_output_override is not None: _validate_elem_dtype(self.elem_dtype_grad_output_override) - assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported" + assert self.kernel_preference == KernelPreference.EMULATED, "unsupported" _validate_mxfp8_cast_kernel_choice( self.mxfp8_cast_kernel_choice, self.scale_calculation_mode ) @@ -182,12 +164,12 @@ def from_recipe_name( return MXLinearConfig() elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS: return MXLinearConfig( - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, 
mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, ) elif recipe_name is MXLinearRecipeName.MXFP8_CUBLAS_RCEIL: return MXLinearConfig( - gemm_kernel_choice=MXGemmKernelChoice.CUBLAS, + kernel_preference=KernelPreference.AUTO, mxfp8_cast_kernel_choice=MXFP8Dim1CastKernelChoice.CUDA, scale_calculation_mode=ScaleCalculationMode.RCEIL, ) @@ -196,7 +178,7 @@ def from_recipe_name( elif recipe_name is MXLinearRecipeName.MXFP4_CUTLASS: return MXLinearConfig( elem_dtype=torch.float4_e2m1fn_x2, - gemm_kernel_choice=MXGemmKernelChoice.CUTLASS, + kernel_preference=KernelPreference.AUTO, ) else: raise AssertionError(f"unknown recipe_name {recipe_name}") @@ -212,7 +194,7 @@ def short_str(self) -> str: ) if self.elem_dtype_grad_output_override is not None: s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}" - s += f", kernel={self.gemm_kernel_choice.value}" + s += f", kernel={self.kernel_preference.value}" s += f", mxfp8_cast_kernel_choice={self.mxfp8_cast_kernel_choice.value}" if self.scale_calculation_mode != ScaleCalculationMode.FLOOR: s += f", scale_calculation_mode={self.scale_calculation_mode}" diff --git a/torchao/prototype/mx_formats/inference_workflow.py b/torchao/prototype/mx_formats/inference_workflow.py index 2ff4eedf5f..5991d8557e 100644 --- a/torchao/prototype/mx_formats/inference_workflow.py +++ b/torchao/prototype/mx_formats/inference_workflow.py @@ -10,12 +10,9 @@ import torch from torchao.core.config import AOBaseConfig -from torchao.prototype.mx_formats import ( - MXGemmKernelChoice, -) from torchao.prototype.mx_formats.config import ( _validate_elem_dtype, - _validate_gemm_kernel_choice, + _validate_kernel_preference, ) from torchao.prototype.mx_formats.mx_tensor import ( MXTensor, @@ -29,6 +26,7 @@ per_tensor_amax_to_scale, ) from torchao.quantization.quant_api import _quantization_type +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -80,7 +78,7 @@ class MXFPInferenceConfig(AOBaseConfig): weight_dtype: torch.dtype = torch.float8_e4m3fn # Which kernel to run for mm - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS + kernel_preference: KernelPreference = KernelPreference.AUTO def __post_init__(self): assert self.activation_dtype == self.weight_dtype, ( @@ -88,8 +86,8 @@ def __post_init__(self): ) _validate_elem_dtype(self.activation_dtype) _validate_elem_dtype(self.weight_dtype) - _validate_gemm_kernel_choice( - self.gemm_kernel_choice, self.block_size, self.weight_dtype + _validate_kernel_preference( + self.kernel_preference, self.block_size, self.weight_dtype ) @@ -109,7 +107,7 @@ def _mx_inference_linear_transform( act_quant_kwargs = QuantizeTensorToMXKwargs( elem_dtype=config.activation_dtype, block_size=config.block_size, - gemm_kernel_choice=config.gemm_kernel_choice, + kernel_preference=config.kernel_preference, is_swizzled_scales=True, ) @@ -118,7 +116,7 @@ def _mx_inference_linear_transform( weight, config.weight_dtype, block_size=config.block_size, - gemm_kernel_choice=config.gemm_kernel_choice, + kernel_preference=config.kernel_preference, act_quant_kwargs=act_quant_kwargs, is_swizzled_scales=True, ) @@ -211,7 +209,6 @@ def _nvfp4_inference_linear_transform( MXTensor, NVFP4Tensor, NVFP4MMConfig, - MXGemmKernelChoice, QuantizeTensorToMXKwargs, QuantizeTensorToNVFP4Kwargs, ScaleCalculationMode, diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py 
index 19d658a6fc..8b9d1576c4 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -14,12 +14,12 @@ from torchao.prototype.mx_formats.config import ( MXFP8Dim1CastKernelChoice, - MXGemmKernelChoice, MXLinearConfig, ScaleCalculationMode, ) from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.quantization.transform_module import ( register_quantize_module_handler, ) @@ -44,7 +44,7 @@ def forward( w_elem_dtype: Any, grad_elem_dtype: Any, block_size: int, - gemm_kernel_choice: MXGemmKernelChoice, + kernel_preference: KernelPreference, mxfp8_cast_kernel_choice: MXFP8Dim1CastKernelChoice, scale_calculation_mode: ScaleCalculationMode, ): @@ -53,7 +53,7 @@ def forward( ctx.w_elem_dtype = w_elem_dtype ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size - ctx.gemm_kernel_choice = gemm_kernel_choice + ctx.kernel_preference = kernel_preference ctx.mxfp8_cast_kernel_choice = mxfp8_cast_kernel_choice ctx.scale_calculation_mode = scale_calculation_mode @@ -65,14 +65,14 @@ def forward( input_hp_r, in_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) weight_mx_dim0 = MXTensor.to_mx( weight_hp, w_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) output = torch.mm(input_mx_r_dim0, weight_mx_dim0.t()) @@ -87,7 +87,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): w_elem_dtype = ctx.w_elem_dtype grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size - gemm_kernel_choice = ctx.gemm_kernel_choice + kernel_preference = ctx.kernel_preference mxfp8_cast_kernel_choice = ctx.mxfp8_cast_kernel_choice scale_calculation_mode = ctx.scale_calculation_mode @@ -102,7 +102,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_output_hp_r, grad_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) @@ -112,7 +112,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, w_elem_dtype, weight_hp.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -122,7 +122,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): weight_hp_t_c, w_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) @@ -137,7 +137,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, grad_elem_dtype, grad_output_hp_r.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -146,7 +146,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_output_hp_r.t().contiguous(), grad_elem_dtype, block_size, - gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) @@ -156,7 +156,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, in_elem_dtype, input_hp_r.dtype, - gemm_kernel_choice, + kernel_preference, mxfp8_cast_kernel_choice, scale_calculation_mode, ) @@ -166,7 +166,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r.t().contiguous(), in_elem_dtype, block_size, - 
gemm_kernel_choice=gemm_kernel_choice, + kernel_preference=kernel_preference, scaling_mode=scale_calculation_mode, ) input_t_mx_dim0 = input_t_mx_dim0_tmp.t() @@ -215,7 +215,7 @@ def forward(self, x): config.elem_dtype_weight_override or config.elem_dtype, config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, - config.gemm_kernel_choice, + config.kernel_preference, config.mxfp8_cast_kernel_choice, config.scale_calculation_mode, ) diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 7a1b5a160b..74f37bc2df 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -29,7 +29,7 @@ from torch.utils._pytree import tree_map import torchao.ops -from torchao.prototype.mx_formats.config import MXGemmKernelChoice, ScaleCalculationMode +from torchao.prototype.mx_formats.config import ScaleCalculationMode from torchao.prototype.mx_formats.constants import ( BLOCK_SIZE_DEFAULT, DTYPE_FP6_E2M3, @@ -69,6 +69,7 @@ from torchao.quantization.quantize_.common import ( QuantizeTensorKwargs, ) +from torchao.quantization.quantize_.common.kernel_preference import KernelPreference from torchao.utils import TorchAOBaseTensor, fill_defaults aten = torch.ops.aten @@ -87,7 +88,7 @@ class QuantizeTensorToMXKwargs(QuantizeTensorKwargs): elem_dtype: Union[torch.dtype, str] = torch.float8_e4m3fn block_size: int = 32 scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED + kernel_preference: KernelPreference = KernelPreference.EMULATED is_swizzled_scales: bool = False @@ -438,7 +439,7 @@ class MXTensor(TorchAOBaseTensor): "_elem_dtype", "block_size", "_orig_dtype", - "_gemm_kernel_choice", + "kernel_preference", "act_quant_kwargs", "_is_swizzled_scales", ] @@ -450,7 +451,7 @@ def __new__( elem_dtype, block_size, orig_dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ): @@ -487,7 +488,7 @@ def __new__( self._elem_dtype = elem_dtype self.block_size = block_size self._orig_dtype = orig_dtype - self._gemm_kernel_choice = gemm_kernel_choice + self.kernel_preference = kernel_preference self.act_quant_kwargs = act_quant_kwargs self._is_swizzled_scales = is_swizzled_scales return self @@ -497,7 +498,7 @@ def __repr__(self): return f"MXTensor: elem_dtype: {self._elem_dtype}, s_e8m0: {self.scale}, d: {self.qdata}, act_quant_kwargs: {self.act_quant_kwargs}, _is_swizzled_scales={self._is_swizzled_scales}" # noqa: E501 def _quantization_type(self): - return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self._gemm_kernel_choice=}, {self.act_quant_kwargs=}" + return f"{self._elem_dtype=}, {self.block_size=}, {self._orig_dtype=}, {self.kernel_preference=}, {self.act_quant_kwargs=}" def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: if output_dtype is None: @@ -534,7 +535,7 @@ def to_mx( block_size: int = BLOCK_SIZE_DEFAULT, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.FLOOR, # TODO(future PR): switch default gemm to cublas - gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED, + kernel_preference: KernelPreference = KernelPreference.EMULATED, act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None, is_swizzled_scales: bool = False, ): @@ -551,7 +552,7 @@ def to_mx( elem_dtype, block_size, data_hp.dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ) @@ -569,7 +570,7 @@ def to_mx( 
elem_dtype, block_size, data_hp.dtype, - gemm_kernel_choice, + kernel_preference, act_quant_kwargs, is_swizzled_scales, ) @@ -589,8 +590,8 @@ def _(func, types, args, kwargs): def _get_gemm_choice( - choice_a: Optional[MXGemmKernelChoice], choice_b: Optional[MXGemmKernelChoice] -) -> MXGemmKernelChoice: + choice_a: Optional[KernelPreference], choice_b: Optional[KernelPreference] +) -> KernelPreference: if choice_a is not None and choice_b is not None: assert choice_a == choice_b, ( "Both MXTensor inputs must have the same gemm config if specified" @@ -620,13 +621,13 @@ def _addmm_mx_dispatch( k.elem_dtype, k.block_size, k.scaling_mode, - k.gemm_kernel_choice, + k.kernel_preference, k.is_swizzled_scales, ) - gemm_choice = _get_gemm_choice(a._gemm_kernel_choice, b._gemm_kernel_choice) + gemm_choice = _get_gemm_choice(a.kernel_preference, b.kernel_preference) - if gemm_choice in (MXGemmKernelChoice.CUBLAS, MXGemmKernelChoice.CUTLASS): + if gemm_choice == KernelPreference.AUTO: # real MX gemm backed by torchao's CUTLASS kernels M, K, N = a.shape[0], a.shape[1], b.shape[1] assert a.qdata.is_contiguous() @@ -648,10 +649,6 @@ def _addmm_mx_dispatch( if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn - assert gemm_choice is MXGemmKernelChoice.CUBLAS, ( - "CUBLAS is the only supported kernel choice for MX FP8 operations" - ) - res = torch._scaled_mm( a.qdata, b.qdata, @@ -663,7 +660,6 @@ def _addmm_mx_dispatch( else: assert a._elem_dtype == torch.float4_e2m1fn_x2 assert b._elem_dtype == torch.float4_e2m1fn_x2 - assert gemm_choice is MXGemmKernelChoice.CUTLASS, "unsupported" # FP4 operations res = torchao.ops.mx_fp4_bf16( a.qdata, b.qdata, a_scale_block, b_scale_block @@ -673,6 +669,7 @@ def _addmm_mx_dispatch( res = res + bias else: + assert gemm_choice == KernelPreference.EMULATED, "unimplemented" # emulated MX gemm a_hp = a.dequantize(a._orig_dtype) b_hp = b.dequantize(b._orig_dtype) @@ -738,7 +735,7 @@ def mx_t(func, types, args, kwargs): old._elem_dtype, old.block_size, old._orig_dtype, - old._gemm_kernel_choice, + old.kernel_preference, old.act_quant_kwargs, old._is_swizzled_scales, ) @@ -779,7 +776,7 @@ def mx_view_op(func, types, args, kwargs): args[0]._elem_dtype, args[0].block_size, args[0]._orig_dtype, - args[0]._gemm_kernel_choice, + args[0].kernel_preference, args[0].act_quant_kwargs, args[0]._is_swizzled_scales, ) @@ -804,7 +801,7 @@ def mx_slice(func, types, args, kwargs): x._elem_dtype, x.block_size, x._orig_dtype, - x._gemm_kernel_choice, + x.kernel_preference, x.act_quant_kwargs, x._is_swizzled_scales, ), @@ -838,7 +835,7 @@ def mx_select(func, types, args, kwargs): old_mx_tensor._elem_dtype, old_mx_tensor.block_size, old_mx_tensor._orig_dtype, - old_mx_tensor._gemm_kernel_choice, + old_mx_tensor.kernel_preference, old_mx_tensor.act_quant_kwargs, old_mx_tensor._is_swizzled_scales, ) diff --git a/torchao/quantization/quantize_/common/kernel_preference.py b/torchao/quantization/quantize_/common/kernel_preference.py index 8f53f55c6a..45ae4d2ab6 100644 --- a/torchao/quantization/quantize_/common/kernel_preference.py +++ b/torchao/quantization/quantize_/common/kernel_preference.py @@ -30,5 +30,13 @@ class KernelPreference(str, Enum): """ FBGEMM = "fbgemm" + """Emulates gemm_lowp(A, B) with gemm_fp32(A.dequantize(), B.dequantize()). + Intended use cases are: + 1. Running CI for product logic on hardware which does not support the + actual lowp gemm. + 2. Debugging kernel numerics issues. 
+ """ + EMULATED = "emulated" + torch.serialization.add_safe_globals([KernelPreference])
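
A minimal usage sketch of the API after this patch series (illustrative only, not part of the patch itself): it quantizes a toy linear layer for MX inference using the new `kernel_preference` argument, with `KernelPreference.EMULATED` as described in the docstring above so it also runs on GPUs without native MX gemm support. All imports and config fields are taken from the diffs above; the module shapes are arbitrary and mirror the README example.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import MXFPInferenceConfig
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

# toy module, mirroring the README example touched in this patch
m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
m_mxfp8 = copy.deepcopy(m)

config = MXFPInferenceConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    # EMULATED dequantizes and runs a high precision gemm, intended for CI on
    # unsupported hardware and for numerics debugging; on CUDA capability
    # 10.0+ hardware, KernelPreference.AUTO selects a real MX gemm instead.
    kernel_preference=KernelPreference.EMULATED,
)
quantize_(m_mxfp8, config=config)

x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
y = m_mxfp8(x)
```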