6 changes: 3 additions & 3 deletions benchmarks/float8/float8_inference_roofline.py
@@ -39,7 +39,7 @@

import torchao
from torchao.prototype.mx_formats.inference_workflow import (
- MXFPInferenceConfig,
+ MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -433,13 +433,13 @@ def run(
kernel_preference=KernelPreference.TORCH,
)
elif recipe_name == "mxfp8_cublas":
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.AUTO,
)
elif recipe_name == "mxfp4_cutlass":
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
kernel_preference=KernelPreference.AUTO,
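For readers skimming the rename: both recipes construct the same config class and differ only in element dtypes. Below is a minimal sketch of that dispatch under the new name; the helper `get_inference_config` is hypothetical, and `kernel_preference` is left at its default because the `KernelPreference` import sits outside this hunk.

```python
import torch

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)


# hypothetical helper, not part of this PR: mxfp8 and mxfp4 share one
# config class and differ only in activation/weight element dtypes
def get_inference_config(recipe_name: str) -> MXDynamicActivationMXWeightConfig:
    if recipe_name == "mxfp8_cublas":
        return MXDynamicActivationMXWeightConfig(
            activation_dtype=torch.float8_e4m3fn,
            weight_dtype=torch.float8_e4m3fn,
        )
    if recipe_name == "mxfp4_cutlass":
        return MXDynamicActivationMXWeightConfig(
            activation_dtype=torch.float4_e2m1fn_x2,
            weight_dtype=torch.float4_e2m1fn_x2,
        )
    raise ValueError(f"unknown recipe {recipe_name!r}")
```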
4 changes: 2 additions & 2 deletions test/integration/test_vllm.py
@@ -41,7 +41,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

- from torchao.prototype.mx_formats import MXFPInferenceConfig
+ from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
CutlassInt4PackedLayout,
@@ -70,7 +70,7 @@ def get_tests() -> List[TorchAoConfig]:
Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
)
]
- SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())]
+ SM100_TESTS = [TorchAoConfig(MXDynamicActivationMXWeightConfig())]

# Check CUDA availability first
if not torch.cuda.is_available():
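As the updated test shows, the vLLM integration just wraps the torchao config in Hugging Face's `TorchAoConfig`. A minimal sketch with a default-constructed config; the model id is a placeholder, not from this PR.

```python
from transformers import AutoModelForCausalLM, TorchAoConfig

from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

# default-constructed config, as in SM100_TESTS above
hf_quant_config = TorchAoConfig(MXDynamicActivationMXWeightConfig())

# hypothetical usage; "facebook/opt-125m" is a placeholder model id
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m", quantization_config=hf_quant_config
)
```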
8 changes: 4 additions & 4 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -13,7 +13,7 @@
from torch.profiler import ProfilerActivity, profile

from torchao.prototype.mx_formats.inference_workflow import (
- MXFPInferenceConfig,
+ MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -106,7 +106,7 @@ def test_inference_workflow_mx(
kernel_choice = KernelPreference.EMULATED
else:
kernel_choice = KernelPreference.AUTO
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=elem_dtype,
weight_dtype=elem_dtype,
kernel_preference=kernel_choice,
@@ -247,7 +247,7 @@ class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
reason="torch.compile requires PyTorch 2.8+",
)
def test_slice_and_copy_similar_to_vllm(self):
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
@@ -260,7 +260,7 @@ def test_narrow_similar_to_vllm(self):
reason="torch.compile requires PyTorch 2.8+",
)
def test_narrow_similar_to_vllm(self):
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_serialization.py
@@ -13,7 +13,7 @@
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
- MXFPInferenceConfig,
+ MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -41,7 +41,7 @@ def test_serialization(recipe_name):
fname = None
with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
if recipe_name == "mxfp8":
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
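The serialization test round-trips the config through a JSON temp file. A compact sketch of that round trip, assuming torchao's `config_to_dict`/`config_from_dict` helpers live in `torchao.core.config` (the exact path is not shown in this hunk); `kernel_preference` is omitted because the `KernelPreference` import is also outside it.

```python
import json

import torch

from torchao.core.config import config_from_dict, config_to_dict  # assumed path
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
)

# serialize to JSON and back, then check the class survives the round trip
payload = json.dumps(config_to_dict(config))
restored = config_from_dict(json.loads(payload))
assert isinstance(restored, MXDynamicActivationMXWeightConfig)
```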
6 changes: 3 additions & 3 deletions torchao/prototype/mx_formats/README.md
@@ -108,7 +108,7 @@ import torch.nn as nn
from torchao.quantization import quantize_
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats.inference_workflow import (
- MXFPInferenceConfig,
+ MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -120,7 +120,7 @@ x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
# mxfp8

m_mxfp8 = copy.deepcopy(m)
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.AUTO,
@@ -132,7 +132,7 @@ y_mxfp8 = m_mxfp8(x)
# mxfp4

m_mxfp4 = copy.deepcopy(m)
- config = MXFPInferenceConfig(
+ config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
kernel_preference=KernelPreference.AUTO,
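Stitching the README hunks together, here is an end-to-end sketch under the new name. `kernel_preference` is left at its default since the `KernelPreference` import sits outside these hunks, and SM100+ (Blackwell) hardware is assumed for non-emulated execution.

```python
import copy

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_

m = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)

# mxfp8: weights are quantized now, activations just-in-time at inference
m_mxfp8 = copy.deepcopy(m)
quantize_(
    m_mxfp8,
    config=MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
    ),
)
y_mxfp8 = m_mxfp8(x)
```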
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/__init__.py
@@ -5,7 +5,7 @@

# Note: Prototype and subject to change
from torchao.prototype.mx_formats.inference_workflow import (
- MXFPInferenceConfig,
+ MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -17,7 +17,7 @@
__all__ = [
"MXLinearConfig",
"MXLinearRecipeName",
"MXFPInferenceConfig",
"MXDynamicActivationMXWeightConfig",
"NVFP4InferenceConfig",
"NVFP4MMConfig",
]
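With both the import and `__all__` entry updated, the renamed config stays importable from the package root:

```python
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
```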
29 changes: 3 additions & 26 deletions torchao/prototype/mx_formats/inference_workflow.py
@@ -36,39 +36,16 @@
)


- # TODO The naming for these configs is a little weird, rename before moving to public API
# Note: This API is extra prototype and will change in the future
@dataclass
- class MXFPInferenceConfig(AOBaseConfig):
+ class MXDynamicActivationMXWeightConfig(AOBaseConfig):
"""
MX Format Inference Quantization

- This module provides support for running inference with float8 quantization using MX formats.

[Review comment from a Contributor, anchored on the line above] nit: this comment seems outdated; it supports both mxfp8 and mxfp4, right?
- The quantization flow works as follows:
-
- 1. Weight Quantization:
- - In _mx_inference_linear_transform(), the module's weight is converted to an MXTensor
- - The weight is quantized to the specified dtype (float8_e4m3fn by default)
- - This happens when quantize_() is called with an MXFPInferenceConfig
-
- 2. Activation Quantization:
- - A callable (_input_activation_quant_func_mxfp) is defined that will quantize
- activations during inference to the same dtype
- - This function is passed to to_linear_activation_quantized() along with the
- already-quantized weight
-
- 3. Runtime Flow:
- - When the quantized module is called, the input goes through the LinearActivationQuantizedTensor
- - The input (activation) is quantized just-in-time using the provided function
- - The MX quantized activation and MX weight are used together in F.linear
-
- Requirements:
- - NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
- - PyTorch 2.5+ for proper serialization support

See also:
- LinearActivationQuantizedTensor in torchao.quantization.quant_api
- MXTensor in torchao.prototype.mx_formats.mx_tensor
"""

block_size: int = 32
Expand All @@ -95,9 +72,9 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


- @register_quantize_module_handler(MXFPInferenceConfig)
+ @register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
def _mx_inference_linear_transform(
- module: torch.nn.Module, config: MXFPInferenceConfig
+ module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
):
weight = module.weight

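One consequence worth noting: this is a pure rename in a prototype namespace, so any downstream code pinned to `MXFPInferenceConfig` breaks at import time. A one-line compatibility shim, not part of this PR and shown only as a sketch, would keep old call sites working through a deprecation window:

```python
# hypothetical shim, not in this PR: alias the old name to the new class
from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)

MXFPInferenceConfig = MXDynamicActivationMXWeightConfig
```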