Summary
Add a new ATen operator quantize_mx to support Microscaling (MX) format quantization in PyTorch. MX formats provide hardware-efficient quantization with shared exponents across blocks of values, enabling lower precision training and inference while maintaining model accuracy.
Motivation
Microscaling formats (MXFP8, MXFP4, MXINT8, NVFP4) are emerging quantization standards that:
- Provide better accuracy-performance tradeoffs than traditional INT8/FP8 quantization
- Use shared exponents per block (typically 32 elements)
- Are being adopted by hardware vendors (Nvidia, AWS TRN, AMD)
Currently, PyTorch lacks native support for quantizing into MX formats; implementing a native ATen operator would give users and libraries (e.g. TorchAO's to_mx functions) a cleaner path to enable quantization.
Proposed API
Core Operator
```python
torch.quantize_mx(
    input: Tensor,
    dim: int,
    block_size: int,
    dtype: torch.dtype,  # "fp8_e5m2", "fp8_e4m3", "fp6_e3m2", "fp6_e2m3", "fp4", "int8"
    scale_calculation_mode: int,  # {0: round down, 1: round up, 2: round to nearest, ...}
) -> tuple[Tensor, Tensor]
"""
input: tensor to be quantized
dim: dimension to quantize along, i.e. the reduction dimension (e.g. 0 or -1); must be the innermost or outermost dimension
block_size: number of consecutive elements that share one scale (typically 32)
dtype: data type of the quantized tensor
scale_calculation_mode: how the shared scale is rounded (down, up, to nearest, etc.)
"""
```
Storage Format
Tensor Layout for Input:
- input: [B, S, H]; 2D or 3D tensors are supported
Tensor Layout for Output:
- quantized_input: [B, S, H], same shape as input
- scales: [B, S, H // block_size] when dim = -1; the quantized dimension is divided by block_size
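To make the shape contract and scale semantics concrete, here is a rough reference of the intended behavior for MXFP8 (e4m3) with dim=-1 and power-of-two shared scales in "round down" mode, loosely following TorchAO's to_mx; quantize_mx_ref is an illustrative name, not the proposed kernel:

```python
import math
import torch

def quantize_mx_ref(x: torch.Tensor, block_size: int = 32):
    # View the innermost dimension as [num_blocks, block_size].
    shape = x.shape
    blocks = x.reshape(*shape[:-1], shape[-1] // block_size, block_size)

    # Per-block amax, clamped so all-zero blocks do not hit log2(0) = -inf.
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(torch.finfo(x.dtype).tiny)

    # Shared power-of-two scale: 2^(floor(log2(amax)) - floor(log2(fp8_max))).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = torch.exp2(torch.floor(torch.log2(amax)) - math.floor(math.log2(fp8_max)))

    # Scale, saturate, and cast to the element dtype.
    q = (blocks / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # The MX spec stores scales as E8M0; a float tensor is kept here for simplicity.
    return q.reshape(shape), scale.squeeze(-1)
```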
ATen Operator Signature
```yaml
# aten/src/ATen/native/native_functions.yaml
- func: quantize_mx(Tensor self, int dim, int block_size, ScalarType dtype, int scale_calculation_mode) -> (Tensor output, Tensor scale)
  variants: function, method
  dispatch:
    CPU: quantize_mx_cpu
    Meta: quantize_mx_meta
  tags: [core, pt2_compliant_tag]
```
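A shape-propagation sketch for the Meta kernel, which torch.compile needs to trace the op; it assumes float32 scales and an evenly divisible dimension, and the actual registration would live in C++ alongside the CPU kernel:

```python
import torch

def quantize_mx_meta(self, dim, block_size, dtype, scale_calculation_mode):
    # The quantized output keeps the input's shape; the scale tensor shrinks
    # the reduced dimension by block_size.
    scale_shape = list(self.shape)
    scale_shape[dim] = scale_shape[dim] // block_size
    output = self.new_empty(self.shape, dtype=dtype)
    scale = self.new_empty(scale_shape, dtype=torch.float32)  # E8M0 in a real impl
    return output, scale
```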
Implementation Plan (High-Level Summary)
- Implement a reference MXFP8 path with CPU goldens and tests
- Integration: will cover DTensor, torch.compile, and eager mode (see the sketch after this list)
- Additional dtypes and backends/accelerators can be added by other contributors
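For the torch.compile integration, the expectation is that the op traces through via its Meta kernel; a hypothetical sketch of the intended usage once the op lands:

```python
import torch

@torch.compile
def quantize_block(x):
    return torch.quantize_mx(
        x, dim=-1, block_size=32,
        dtype=torch.float8_e4m3fn, scale_calculation_mode=0,
    )
```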