RFC: Add quantize_mx ATen Op for Microscaling Quantization #86

@avizon-aws

Description

Summary

Add a new ATen operator quantize_mx to support Microscaling (MX) format quantization in PyTorch. MX formats provide hardware-efficient quantization with shared exponents across blocks of values, enabling lower precision training and inference while maintaining model accuracy.

Motivation

Microscaling formats (MXFP8, MXFP4, MXINT8, NVFP4) are emerging quantization standards that:

  • Provide better accuracy-performance tradeoffs than traditional INT8/FP8 quantization
  • Use shared exponents per block (typically 32 elements)
  • Are being adopted by hardware vendors (Nvidia, AWS TRN, AMD)

Currently, PyTorch lacks native support for quantizing into MX formats. Implementing a native ATen operator would give users and libraries (e.g., TorchAO's to_mx functions) a cleaner path to enable MX quantization.

Proposed API

Core Operator

torch.quantize_mx(
    input: Tensor,
    dim: int,
    block_size: int,
    dtype: torch.dtype,  # fp8_e5m2, fp8_e4m3, fp6_e3m2, fp6_e2m3, fp4, int8
    scale_calculation_mode: int,  # {0: round down, 1: round up, 2: round to nearest, ...}
) -> tuple[Tensor, Tensor]

"""

    input: Tensor to be quantized
    dim: Dimension on perform quantization on i.e. reduction dimension e.g. 0,-1, has to be the innermost/outermost dimension
    dtype: data type of the quantized tensor
    scale_calculation_mode: the scale calculated can be rounded up/down/nearest etc.
    
"""

Storage Format

Tensor Layout for Input:
- input: [B, S, H]; 2D or 3D tensors are supported

Tensor Layout for Output:
- quantized_input: [B, S, H], same shape as input
- scales: [B, S, H // block_size]; the dimension selected by dim is divided by block_size
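To make the semantics and shapes concrete, below is a minimal pure-PyTorch reference sketch of one configuration (MXFP8 E4M3, round-down scale mode), using float32 scales rather than the E8M0 storage the MX spec uses. quantize_mx_reference is a hypothetical helper name, not part of the proposal:

import torch

def quantize_mx_reference(x, dim, block_size=32, dtype=torch.float8_e4m3fn):
    # Move the quantized dimension to the end so blocks are contiguous.
    x_t = x.movedim(dim, -1)
    *lead, H = x_t.shape
    assert H % block_size == 0, "quantized dim must be divisible by block_size"
    blocks = x_t.reshape(*lead, H // block_size, block_size)
    # Shared power-of-two scale per block, rounded down (mode 0):
    # floor(log2(amax)) minus E4M3's max exponent (max value 448 = 1.75 * 2**8).
    amax = blocks.abs().amax(dim=-1, keepdim=True)
    E4M3_EMAX = 8
    scale = torch.exp2(torch.floor(torch.log2(amax.clamp(min=2**-126))) - E4M3_EMAX)
    # Clamp to the E4M3 range before casting so overflow saturates instead of NaN.
    q = (blocks / scale).clamp(-448.0, 448.0).to(dtype)
    q = q.reshape(*lead, H).movedim(-1, dim)
    scales = scale.squeeze(-1).movedim(-1, dim)
    return q, scales

x = torch.randn(2, 128, 4096)
q, scales = quantize_mx_reference(x, dim=-1)
print(q.shape, scales.shape)  # [2, 128, 4096] and [2, 128, 128]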

ATen Operator Signature

# aten/src/ATen/native/native_functions.yaml

- func: quantize_mx(Tensor self, int dim, int block_size, ScalarType dtype, int scale_calculation_mode) -> (Tensor output, Tensor scale)
  variants: function, method
  dispatch:
    CPU: quantize_mx_cpu
    Meta: quantize_mx_meta
  tags: [core, pt2_compliant_tag]
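The Meta kernel only needs to produce correctly shaped outputs without touching data. Below is a hypothetical Python sketch of the shape logic quantize_mx_meta would implement (the actual kernel would live in C++):

def quantize_mx_meta_shapes(input_shape, dim, block_size):
    # The quantized output keeps the input shape; the scale tensor's
    # quantized dimension shrinks by a factor of block_size.
    dim = dim % len(input_shape)
    assert input_shape[dim] % block_size == 0
    scale_shape = list(input_shape)
    scale_shape[dim] //= block_size
    return tuple(input_shape), tuple(scale_shape)

# quantize_mx_meta_shapes((2, 128, 4096), dim=-1, block_size=32)
# -> ((2, 128, 4096), (2, 128, 128))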

Implementation Plan (High level summary)

  1. Implement the initial feature for MXFP8, with CPU golden references and tests (see the test sketch after this list)
  2. Integration: cover DTensor, torch.compile, and eager mode
  3. Additional dtypes and backends/accelerators can be added by other contributors
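As a sketch of what a CPU golden test from step 1 could look like, assuming the hypothetical quantize_mx_reference helper from the Storage Format section (not runnable until the op lands):

import torch

def test_quantize_mx_mxfp8_cpu():
    x = torch.randn(4, 64, 128)
    # Golden values from the pure-PyTorch reference sketch above.
    q_ref, s_ref = quantize_mx_reference(x, dim=-1, block_size=32)
    # The proposed ATen op, once implemented.
    q, s = torch.quantize_mx(x, -1, 32, torch.float8_e4m3fn, 0)
    torch.testing.assert_close(q.float(), q_ref.float())
    torch.testing.assert_close(s, s_ref)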
