Commit 42e1345
mx_formats: move inference to the quantize_ API (#1971)
* Update [ghstack-poisoned]
1 parent ce759d5 commit 42e1345
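
In short, this commit retires the standalone `swap_linear_with_mx_inference_linear` helper and routes MX inference module swaps through the generic `quantize_` API with a new `MXInferenceLinearConfig`. A minimal before/after sketch assembled from the diffs below (the toy model and dtype choices are illustrative):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXInferenceLinearConfig

# toy model; any stack of nn.Linear layers is handled the same way
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()

# before: swap_linear_with_mx_inference_linear(m, config=MXLinearConfig(...))
# after: the swap is a registered quantize_ handler keyed on the config type
config = MXInferenceLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config=config)
```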

File tree

5 files changed: +109 -104 lines changed


test/prototype/mx_formats/test_mx_linear.py

+9-8
@@ -11,6 +11,7 @@
 import torch.nn as nn
 
 from torchao.prototype.mx_formats.config import (
+    MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
 )
@@ -23,7 +24,6 @@
 from torchao.prototype.mx_formats.mx_linear import (
     MXInferenceLinear,
     MXLinear,
-    swap_linear_with_mx_inference_linear,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import compute_error
@@ -294,8 +294,8 @@ def test_inference_linear(elem_dtype, bias, input_shape):
     m = nn.Sequential(nn.Linear(4, 8, bias=bias, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    swap_linear_with_mx_inference_linear(m_mx, config=config)
+    config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
+    quantize_(m_mx, config=config)
 
     x = torch.randn(*input_shape, device="cuda", dtype=torch.bfloat16)
     y_ref = m(x)
@@ -319,8 +319,8 @@ def test_inference_compile_simple(elem_dtype):
     m = nn.Sequential(nn.Linear(4, 8, bias=False, dtype=torch.bfloat16))
     m = m.cuda()
     m_mx = copy.deepcopy(m)
-    config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype)
-    swap_linear_with_mx_inference_linear(m_mx, config=config)
+    config = MXInferenceLinearConfig(block_size=4, elem_dtype=elem_dtype)
+    quantize_(m_mx, config=config)
     m_mx = torch.compile(m_mx, fullgraph="true")
 
     x = torch.randn(2, 4, device="cuda", dtype=torch.bfloat16)
@@ -346,7 +346,8 @@ def test_filter_fn():
     assert type(m1[0]) == MXLinear
     assert type(m1[1]) == torch.nn.Linear
 
-    swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn)  # noqa: E501
+    config2 = MXInferenceLinearConfig(block_size=32)
+    quantize_(m2, config=config2, filter_fn=filter_fn)  # noqa: E501
     assert type(m2[0]) == MXInferenceLinear
     assert type(m2[1]) == torch.nn.Linear
 
@@ -362,8 +363,8 @@ def test_training_print_str():
 
 def test_inference_print_str():
     m = nn.Sequential(nn.Linear(32, 32))
-    config = MXLinearConfig()
-    swap_linear_with_mx_inference_linear(m, config=config)
+    config = MXInferenceLinearConfig()
+    quantize_(m, config=config)
     s = str(m)
     assert "bl_sz=32" in s
     assert "kernel=emulated" in s

torchao/prototype/mx_formats/README.md

+9-4
@@ -68,12 +68,17 @@ This is a module to do MX inference, weights are in MX and matmul is in high precision
 
 ```python
 import torch
-from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.quantization import quantize_
+from torchao.prototype.mx_formats import MXInferenceLinearConfig, MXGemmKernelChoice
 
 m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
-config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
-swap_linear_with_mx_inference_linear(m, config=config)
+gemm_kernel_choice = MXGemmKernelChoice.CUBLAS
+config = MXInferenceLinearConfig(
+    elem_dtype=torch.float8_e4m3fn,
+    block_size=32,
+    gemm_kernel_choice=gemm_kernel_choice,
+)
+quantize_(m, config=config)
 
 # do inference (not shown)
 ```

torchao/prototype/mx_formats/__init__.py

+3-1
@@ -1,5 +1,6 @@
 from torchao.prototype.mx_formats.config import (
     MXGemmKernelChoice,
+    MXInferenceLinearConfig,
     MXLinearConfig,
     MXLinearRecipeName,
 )
@@ -9,7 +10,8 @@
 import torchao.prototype.mx_formats.mx_linear  # noqa: F401
 
 __all__ = [
-    "MXLinearConfig",
     "MXGemmKernelChoice",
+    "MXInferenceLinearConfig",
+    "MXLinearConfig",
     "MXLinearRecipeName",
 ]

torchao/prototype/mx_formats/config.py

+78-45
@@ -13,6 +13,8 @@
 from torchao.core.config import AOBaseConfig
 from torchao.prototype.mx_formats.constants import (
     DTYPE_FP4,
+    DTYPE_FP6_E2M3,
+    DTYPE_FP6_E3M2,
     DTYPE_TO_SHORT_STR,
     SUPPORTED_ELEM_DTYPES,
 )
@@ -41,6 +43,31 @@ class MXLinearRecipeName(Enum):
     MXFP4_CUTLASS = "mxfp4_cutlass"
 
 
+def _validate_elem_dtype(elem_dtype):
+    assert (
+        elem_dtype in SUPPORTED_ELEM_DTYPES
+    ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {elem_dtype}"
+
+
+def _validate_gemm_kernel_choice(gemm_kernel_choice, block_size, elem_dtype):
+    if gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+    elif gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
+        assert (
+            block_size == 32
+        ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {block_size}"
+        valid_dtypes = [torch.float8_e4m3fn]
+        assert (
+            elem_dtype in valid_dtypes
+        ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {elem_dtype}"
+
+
 @dataclass
 class MXLinearConfig(AOBaseConfig):
     # block size for scaling, default is 32 to match
@@ -68,53 +95,17 @@ class MXLinearConfig(AOBaseConfig):
     # If True, uses a custom triton kernel for fp4 dequantize
     use_fp4_custom_triton_dequant_kernel: bool = False
 
-    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
-    # kernels (fused unpack/dequantize). Training not currently supported.
-    pack_fp6 = True if hasattr(torch.library, "custom_op") else False
-
     def __post_init__(self):
-        # validate elem_dtype and its overrides
-        assert (
-            self.elem_dtype in SUPPORTED_ELEM_DTYPES
-        ), f"elem_dtype: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+        _validate_elem_dtype(self.elem_dtype)
+        _validate_gemm_kernel_choice(
+            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        )
         if self.elem_dtype_weight_override is not None:
-            assert (
-                self.elem_dtype_weight_override in SUPPORTED_ELEM_DTYPES
-            ), f"elem_dtype_weight_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
+            _validate_elem_dtype(self.elem_dtype_weight_override)
+            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
         if self.elem_dtype_grad_output_override is not None:
-            assert (
-                self.elem_dtype_grad_output_override in SUPPORTED_ELEM_DTYPES
-            ), f"elem_dtype_grad_output_override: expected one of {SUPPORTED_ELEM_DTYPES}, got {self.elem_dtype}"
-
-        # validate that block size and elem_dtype matches kernel choice
-        if self.gemm_kernel_choice == MXGemmKernelChoice.CUTLASS:
-            assert (
-                self.block_size == 32
-            ), f"block_size must be 32 to use the CUTLASS MX gemm kernels, got {self.block_size}"
-            valid_dtypes = [torch.float8_e4m3fn, DTYPE_FP4]
-            assert (
-                self.elem_dtype in valid_dtypes
-            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
-            assert (
-                self.elem_dtype_weight_override is None
-            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
-            assert (
-                self.elem_dtype_grad_output_override is None
-            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
-        elif self.gemm_kernel_choice == MXGemmKernelChoice.CUBLAS:
-            assert (
-                self.block_size == 32
-            ), f"block_size must be 32 to use the cuBLAS MX gemm kernels, got {self.block_size}"
-            valid_dtypes = [torch.float8_e4m3fn]
-            assert (
-                self.elem_dtype in valid_dtypes
-            ), f"elem_dtype must be one of {valid_dtypes} to use the CUTLASS MX gemm kernels, got {self.elem_dtype}"
-            assert (
-                self.elem_dtype_weight_override is None
-            ), "elem_dtype_weight_override not supported for CUTLASS MX gemm kernels"
-            assert (
-                self.elem_dtype_grad_output_override is None
-            ), "elem_dtype_grad_output_override not supported for CUTLASS MX gemm kernels"
+            _validate_elem_dtype(self.elem_dtype_grad_output_override)
+            assert self.gemm_kernel_choice == MXGemmKernelChoice.EMULATED, "unsupported"
 
     @staticmethod
     def from_recipe_name(
@@ -162,5 +153,47 @@ def short_str(self) -> str:
             s += ", use_fp8_dim1_cast_triton_kernel=True"
         if self.use_fp4_custom_triton_dequant_kernel:
             s += ", use_fp4_custom_triton_dequant_kernel=True"
-        # TODO(future PR): split training from inference and add fp6 here
         return s
+
+
+@dataclass
+class MXInferenceLinearConfig(AOBaseConfig):
+    # block size for scaling, default is 32 to match
+    # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf,
+    # section 5.2
+    block_size: int = 32
+
+    # element dtype, used for activations, weights and gradients
+    elem_dtype: Any = torch.float8_e4m3fn
+    # TODO(future PR): support different elem_dtype for activations vs weights
+
+    # defines the gemm kernel choice, if the chosen kernel is not supported
+    # on the given hardware an exception will be thrown
+    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED
+
+    # If True, uses a custom triton kernel for fp4 dequantize
+    use_fp4_custom_triton_dequant_kernel: bool = False
+
+    # If True, packs 4xFP6 into 3xuint8 containers for inference, using custom triton
+    # kernels (fused unpack/dequantize).
+    pack_fp6: bool = True
+
+    def __post_init__(self):
+        _validate_elem_dtype(self.elem_dtype)
+        _validate_gemm_kernel_choice(
+            self.gemm_kernel_choice, self.block_size, self.elem_dtype
+        )
+
+    def short_str(self) -> str:
+        """
+        Returns a concise representation of the current config.
+        """
+        s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}"
+        s += f", kernel={self.gemm_kernel_choice.value}"
+        if self.use_fp4_custom_triton_dequant_kernel:
+            s += ", use_fp4_custom_triton_dequant_kernel=True"
+        if self.elem_dtype in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2) and self.pack_fp6:
+            s += ", pack_fp6=True"
+        return s
+
+
+# TODO(future PR): add a recipe to config API for inference
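
To illustrate the new config added above: `__post_init__` runs the shared validators, and `short_str()` produces the compact string the tests check for (`bl_sz=...`, `kernel=...`). A small sketch, assuming a build containing this change; the exact dtype abbreviation comes from `DTYPE_TO_SHORT_STR`:

```python
import torch
from torchao.prototype.mx_formats import MXGemmKernelChoice, MXInferenceLinearConfig

# defaults: block_size=32, elem_dtype=torch.float8_e4m3fn, emulated gemm kernel
cfg = MXInferenceLinearConfig()
print(cfg.short_str())  # e.g. "bl_sz=32, lp_dtype=<short dtype>, kernel=emulated"

# the shared validators reject unsupported combinations at construction time:
# cuBLAS MX kernels require block_size == 32
try:
    MXInferenceLinearConfig(block_size=16, gemm_kernel_choice=MXGemmKernelChoice.CUBLAS)
except AssertionError as e:
    print("rejected:", e)
```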

torchao/prototype/mx_formats/mx_linear.py

+10-46
@@ -13,7 +13,11 @@
 import torch
 import torch.nn.functional as F
 
-from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig
+from torchao.prototype.mx_formats.config import (
+    MXGemmKernelChoice,
+    MXInferenceLinearConfig,
+    MXLinearConfig,
+)
 from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.quantization.transform_module import (
@@ -234,7 +238,7 @@ class MXInferenceLinear(torch.nn.Linear):
     def from_float(
         cls,
         mod,
-        config: Optional[MXLinearConfig] = MXLinearConfig(),
+        config: Optional[MXInferenceLinearConfig] = MXInferenceLinearConfig(),
     ):
         with torch.device("meta"):
             super_kwargs = {
@@ -267,53 +271,13 @@ def extra_repr(self):
         return s
 
 
-def replace_with_custom_fn_if_matches_filter(
-    model, replacement_fn, filter_fn, cur_fqn=""
-) -> None:
-    """
-    For each `child` in `model`, replaces it with `replacement_fn(child)`
-    if `filter_fn(child)` is `True`
-    """
-    name_to_child = dict(model.named_children())
-    for name, child in name_to_child.items():
-        if cur_fqn == "":
-            new_fqn = name
-        else:
-            new_fqn = f"{cur_fqn}.{name}"
-        if filter_fn(child, new_fqn):
-            new_child = replacement_fn(child)
-            setattr(model, name, new_child)
-        else:
-            replace_with_custom_fn_if_matches_filter(
-                child, replacement_fn, filter_fn, new_fqn
-            )
-
-
-def _is_linear(mod, fqn):
-    return isinstance(mod, torch.nn.Linear)
-
-
 @register_quantize_module_handler(MXLinearConfig)
 def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig):
     return MXLinear.from_float(module, config=config)
 
 
-def swap_linear_with_mx_inference_linear(
-    model,
-    *,
-    config: Optional[MXLinearConfig] = None,
-    filter_fn=None,
+@register_quantize_module_handler(MXInferenceLinearConfig)
+def _mx_inference_linear_transform(
+    module: torch.nn.Module, config: MXInferenceLinearConfig
 ):
-    if filter_fn is None:
-        combined_filter_fn = _is_linear
-    else:
-
-        def __fn(mod, fqn):
-            return _is_linear(mod, fqn) and filter_fn(mod, fqn)
-
-        combined_filter_fn = __fn
-    replace_with_custom_fn_if_matches_filter(
-        model,
-        lambda mod: MXInferenceLinear.from_float(mod, config=config),
-        combined_filter_fn,
-    )
+    return MXInferenceLinear.from_float(module, config=config)
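
With the handler registered via `register_quantize_module_handler`, `quantize_` now performs the linear-module swap (including optional filtering) that the removed helper used to do. A sketch mirroring the updated `test_filter_fn`; the predicate below is illustrative:

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXInferenceLinearConfig
from torchao.prototype.mx_formats.mx_linear import MXInferenceLinear

m = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).to(torch.bfloat16).cuda()

# only swap modules whose fully-qualified name is not "1"
filter_fn = lambda mod, fqn: fqn != "1"

quantize_(m, config=MXInferenceLinearConfig(block_size=32), filter_fn=filter_fn)
assert type(m[0]) == MXInferenceLinear
assert type(m[1]) == nn.Linear
```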
