
Commit 4334f60

5/x mx cleanup: rename mx inf config to MXDynamicActivationMXWeightConfig
Summary:

Renames the MX inference config to `MXDynamicActivationMXWeightConfig` to match the naming of other AO inference workflow configs.

Test Plan:

```
pytest test/prototype/mx_formats/ -s -x
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: ffe82cc
ghstack-comment-id: 3572506339
Pull-Request: #3386
1 parent d9aa3fe commit 4334f60
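For downstream code this is a pure rename: constructor arguments and defaults are unchanged. A minimal before/after sketch of the migration (assuming the config defaults of mxfp8 element dtypes and `block_size=32`; actually running it requires SM100+ hardware per the config's docstring):

```python
import torch
import torch.nn as nn

from torchao.quantization import quantize_

# Before this commit:
#   from torchao.prototype.mx_formats import MXFPInferenceConfig
#   quantize_(m, MXFPInferenceConfig())
# After this commit, only the class name changes:
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

m = nn.Sequential(nn.Linear(32, 128, bias=False)).cuda().to(torch.bfloat16)
quantize_(m, MXDynamicActivationMXWeightConfig())
```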

7 files changed: +19 −42 lines


benchmarks/float8/float8_inference_roofline.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -39,7 +39,7 @@
 
 import torchao
 from torchao.prototype.mx_formats.inference_workflow import (
-    MXFPInferenceConfig,
+    MXDynamicActivationMXWeightConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
@@ -433,13 +433,13 @@ def run(
             kernel_preference=KernelPreference.TORCH,
         )
     elif recipe_name == "mxfp8_cublas":
-        config = MXFPInferenceConfig(
+        config = MXDynamicActivationMXWeightConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
             kernel_preference=KernelPreference.AUTO,
         )
     elif recipe_name == "mxfp4_cutlass":
-        config = MXFPInferenceConfig(
+        config = MXDynamicActivationMXWeightConfig(
             activation_dtype=torch.float4_e2m1fn_x2,
             weight_dtype=torch.float4_e2m1fn_x2,
             kernel_preference=KernelPreference.AUTO,
```

test/integration/test_vllm.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -41,7 +41,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from vllm import LLM, SamplingParams
 
-from torchao.prototype.mx_formats import MXFPInferenceConfig
+from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
 from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_api import (
     CutlassInt4PackedLayout,
@@ -70,7 +70,7 @@ def get_tests() -> List[TorchAoConfig]:
             Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
         )
     ]
-    SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())]
+    SM100_TESTS = [TorchAoConfig(MXDynamicActivationMXWeightConfig())]
 
     # Check CUDA availability first
     if not torch.cuda.is_available():
```
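The vLLM integration change is name-only: the torchao config is wrapped in transformers' `TorchAoConfig` exactly as before. A sketch of that wrapping pattern (the model id is a placeholder, and SM100+ hardware is assumed):

```python
from transformers import AutoModelForCausalLM, TorchAoConfig

from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

# Wrap the renamed torchao config so transformers applies it at load time;
# defaults select mxfp8 (float8_e4m3fn) activation and weight dtypes.
quant_config = TorchAoConfig(MXDynamicActivationMXWeightConfig())

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder checkpoint, not a real model id
    torch_dtype="bfloat16",
    device_map="cuda",
    quantization_config=quant_config,
)
```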

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -13,7 +13,7 @@
 from torch.profiler import ProfilerActivity, profile
 
 from torchao.prototype.mx_formats.inference_workflow import (
-    MXFPInferenceConfig,
+    MXDynamicActivationMXWeightConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
@@ -106,7 +106,7 @@ def test_inference_workflow_mx(
         kernel_choice = KernelPreference.EMULATED
     else:
         kernel_choice = KernelPreference.AUTO
-    config = MXFPInferenceConfig(
+    config = MXDynamicActivationMXWeightConfig(
         activation_dtype=elem_dtype,
         weight_dtype=elem_dtype,
         kernel_preference=kernel_choice,
@@ -247,7 +247,7 @@ class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
         reason="torch.compile requires PyTorch 2.8+",
     )
     def test_slice_and_copy_similar_to_vllm(self):
-        config = MXFPInferenceConfig(
+        config = MXDynamicActivationMXWeightConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
             kernel_preference=KernelPreference.EMULATED,
@@ -260,7 +260,7 @@ def test_narrow_similar_to_vllm(self):
         reason="torch.compile requires PyTorch 2.8+",
     )
     def test_narrow_similar_to_vllm(self):
-        config = MXFPInferenceConfig(
+        config = MXDynamicActivationMXWeightConfig(
             activation_dtype=torch.float8_e4m3fn,
             weight_dtype=torch.float8_e4m3fn,
             kernel_preference=KernelPreference.EMULATED,
```

test/prototype/mx_formats/test_mx_serialization.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -13,7 +13,7 @@
 import torch.nn as nn
 
 from torchao.prototype.mx_formats.inference_workflow import (
-    MXFPInferenceConfig,
+    MXDynamicActivationMXWeightConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
@@ -41,7 +41,7 @@ def test_serialization(recipe_name):
     fname = None
     with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
         if recipe_name == "mxfp8":
-            config = MXFPInferenceConfig(
+            config = MXDynamicActivationMXWeightConfig(
                 activation_dtype=torch.float8_e4m3fn,
                 weight_dtype=torch.float8_e4m3fn,
                 kernel_preference=KernelPreference.EMULATED,
```
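The serialization test round-trips the config through a temporary JSON file. A sketch of such a round trip, assuming torchao's `config_to_dict`/`config_from_dict` helpers in `torchao.core.config` and a `KernelPreference` import path that this diff does not show:

```python
import json
import tempfile

import torch

from torchao.core.config import config_from_dict, config_to_dict  # helpers assumed
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
from torchao.quantization.quantize_.common import KernelPreference  # import path assumed

config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    kernel_preference=KernelPreference.EMULATED,
)

# Write the config out as JSON, then restore it and check the type survives.
with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".json") as f:
    json.dump(config_to_dict(config), f)
    fname = f.name

with open(fname) as f:
    restored = config_from_dict(json.load(f))

assert isinstance(restored, MXDynamicActivationMXWeightConfig)
```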

torchao/prototype/mx_formats/README.md

Lines changed: 3 additions & 3 deletions
```diff
@@ -108,7 +108,7 @@ import torch.nn as nn
 from torchao.quantization import quantize_
 import torchao.prototype.mx_formats
 from torchao.prototype.mx_formats.inference_workflow import (
-    MXFPInferenceConfig,
+    MXDynamicActivationMXWeightConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
@@ -120,7 +120,7 @@ x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
 # mxfp8
 
 m_mxfp8 = copy.deepcopy(m)
-config = MXFPInferenceConfig(
+config = MXDynamicActivationMXWeightConfig(
     activation_dtype=torch.float8_e4m3fn,
     weight_dtype=torch.float8_e4m3fn,
     kernel_preference=KernelPreference.AUTO,
@@ -132,7 +132,7 @@ y_mxfp8 = m_mxfp8(x)
 # mxfp4
 
 m_mxfp4 = copy.deepcopy(m)
-config = MXFPInferenceConfig(
+config = MXDynamicActivationMXWeightConfig(
     activation_dtype=torch.float4_e2m1fn_x2,
     weight_dtype=torch.float4_e2m1fn_x2,
     kernel_preference=KernelPreference.AUTO,
```
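Stitched together from the README context above, the post-rename mxfp8 path end to end (a sketch: the model setup, `quantize_` call, and `KernelPreference` import follow the README's surrounding lines, which this diff only partially shows):

```python
import copy

import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization.quantize_.common import KernelPreference  # import path assumed

m = nn.Linear(32, 128, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)

# mxfp8
m_mxfp8 = copy.deepcopy(m)
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    kernel_preference=KernelPreference.AUTO,
)
quantize_(m_mxfp8, config=config)
y_mxfp8 = m_mxfp8(x)
```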

torchao/prototype/mx_formats/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 
 # Note: Prototype and subject to change
 from torchao.prototype.mx_formats.inference_workflow import (
-    MXFPInferenceConfig,
+    MXDynamicActivationMXWeightConfig,
     NVFP4InferenceConfig,
     NVFP4MMConfig,
 )
@@ -17,7 +17,7 @@
 __all__ = [
     "MXLinearConfig",
     "MXLinearRecipeName",
-    "MXFPInferenceConfig",
+    "MXDynamicActivationMXWeightConfig",
     "NVFP4InferenceConfig",
     "NVFP4MMConfig",
 ]
```

torchao/prototype/mx_formats/inference_workflow.py

Lines changed: 3 additions & 26 deletions
```diff
@@ -36,39 +36,16 @@
 )
 
 
-# TODO The naming for these configs is a little weird, rename before moving to public API
-# Note: This API is extra prototype and will change in the future
 @dataclass
-class MXFPInferenceConfig(AOBaseConfig):
+class MXDynamicActivationMXWeightConfig(AOBaseConfig):
     """
     MX Format Inference Quantization
 
     This module provides support for running inference with float8 quantization using MX formats.
-    The quantization flow works as follows:
-
-    1. Weight Quantization:
-    - In _mx_inference_linear_transform(), the module's weight is converted to an MXTensor
-    - The weight is quantized to the specified dtype (float8_e4m3fn by default)
-    - This happens when quantize_() is called with an MXFPInferenceConfig
-
-    2. Activation Quantization:
-    - A callable (_input_activation_quant_func_mxfp) is defined that will quantize
-      activations during inference to the same dtype
-    - This function is passed to to_linear_activation_quantized() along with the
-      already-quantized weight
-
-    3. Runtime Flow:
-    - When the quantized module is called, the input goes through the LinearActivationQuantizedTensor
-    - The input (activation) is quantized just-in-time using the provided function
-    - The MX quantized activation and MX weight are used together in F.linear
 
     Requirements:
     - NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
     - PyTorch 2.5+ for proper serialization support
-
-    See also:
-    - LinearActivationQuantizedTensor in torchao.quantization.quant_api
-    - MXTensor in torchao.prototype.mx_formats.mx_tensor
     """
 
     block_size: int = 32
@@ -95,9 +72,9 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"
 
 
-@register_quantize_module_handler(MXFPInferenceConfig)
+@register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
 def _mx_inference_linear_transform(
-    module: torch.nn.Module, config: MXFPInferenceConfig
+    module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
 ):
     weight = module.weight
 
```
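The deleted docstring described the mechanics in three steps: the weight is converted to an MXTensor once at `quantize_()` time, activations are quantized just-in-time by a callable handed to `to_linear_activation_quantized()`, and the quantized pair meets in `F.linear`. A self-contained conceptual sketch of that flow (hypothetical class and helper names; MX quantization is emulated here with a plain float8 round trip instead of real MXTensor block scaling):

```python
import torch
import torch.nn.functional as F


def fake_mx_quant(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for MXTensor conversion: a lossy float8_e4m3fn round trip.
    return t.to(torch.float8_e4m3fn).to(t.dtype)


class MXLinearSketch(torch.nn.Module):
    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        # 1. Weight quantization: happens once, when quantize_() transforms the module.
        self.weight = torch.nn.Parameter(fake_mx_quant(linear.weight.detach()))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # 2. Activation quantization: just-in-time, on every forward call.
        x_q = fake_mx_quant(x)
        # 3. Runtime flow: quantized activation and quantized weight meet in F.linear.
        return F.linear(x_q, self.weight)


lin = torch.nn.Linear(32, 128, bias=False, dtype=torch.bfloat16)
y = MXLinearSketch(lin)(torch.randn(4, 32, dtype=torch.bfloat16))
```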