
Commit 93b28d0

[OMNIML-2244] Create the MXFP8 quant exporter (#634)
## What does this PR do?

**Type of change:** New feature

**Overview:**
- Implemented functions for the MXFP8 quant exporter
- Integrated autocast for converting the model to fp16
- Deprecated `quantize_weights_to_mxfp8`
- Updated tests

## Usage

```
python torch_quant_to_onnx.py --quantize_mode=mxfp8 --onnx_save_path=vit_base_patch16_224.mxfp8.onnx --calibration_data_size 64 --batch_size 128
```

## Testing

```
python evaluate.py --onnx_path=vit_base_patch16_224.mxfp8.onnx --model_name=vit_base_patch16_224 --results_path=./results.txt --batch_size 128
```

Accuracy and latency results:

```
The top1 accuracy of the model is 85.07%
The top5 accuracy of the model is 97.558%
Inference latency of the model is 6.65451 ms
```

## Before your PR is "*Ready for review*"

- **Make sure you read and follow the [Contributor guidelines](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: No; `quantize_weights_to_mxfp8` is deprecated.
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update the [Changelog](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
1 parent 3ef9e39 commit 93b28d0

File tree: 4 files changed, +178 −123 lines


modelopt/onnx/export/mxfp8_exporter.py
Lines changed: 162 additions & 4 deletions

@@ -15,27 +15,185 @@
 
 """MXFP8 quantization exporter."""
 
+import numpy as np
 import onnx
+from onnx import numpy_helper
+
+from modelopt.onnx.logging_config import logger
+from modelopt.onnx.quantization.graph_utils import get_tensor_producer_nodes
+from modelopt.onnx.quantization.qdq_utils import _cast_fp8, onnx_dtype_map
+from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax
+from modelopt.onnx.utils import get_attribute, has_attribute
 
 from .base_exporter import ONNXQuantExporter
 
+E8_M0_BIAS = 127
+DEFAULT_BLOCK_SIZE = 32
+DEFAULT_QUANT_AXIS = -1
+
+
+def _get_weight_dq_nodes(graph: onnx.GraphProto) -> list[onnx.NodeProto]:
+    """Get weight DequantizeLinear nodes from the graph."""
+    return [
+        node
+        for node in graph.node
+        if node.op_type == "TRT_MXFP8DequantizeLinear"
+        and any(".weight" in inp for inp in node.input)
+    ]
+
+
+def _get_quant_params(node: onnx.NodeProto) -> tuple[int, int]:
+    """Extract quantization axis and block size from a node."""
+    if has_attribute(node, "axis"):
+        quant_axis = int(get_attribute(node, "axis"))
+    else:
+        quant_axis = DEFAULT_QUANT_AXIS
+        logger.warning(
+            "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
+        )
+
+    if has_attribute(node, "block_size"):
+        block_size = int(get_attribute(node, "block_size"))
+    else:
+        block_size = DEFAULT_BLOCK_SIZE
+        logger.warning(
+            "block_size attribute not found for MXFP8DequantizeLinear node. "
+            "Setting block_size to 32"
+        )
+
+    return quant_axis, block_size
+
 
-# TODO: Implement the MXFP8QuantExporter
 class MXFP8QuantExporter(ONNXQuantExporter):
     """Exporter for MXFP8 quantization."""
 
     @staticmethod
     def pre_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
         """Pre-processes the ONNX model for MXFP8 quantization."""
+        return onnx_model
 
     @staticmethod
     def compute_scales(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Computes the scales for the weights in the ONNX model for MXFP8 quantization."""
+        """Computes the e8m0 scales for weights in the ONNX model for MXFP8 quantization."""
+        logger.info("Computing MXFP8 scales for weights")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+        tensor_producer_map = get_tensor_producer_nodes(graph)
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            logger.debug(f"Computing MXFP8 scale for weight {weight_name}")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Compute scales
+            amax = get_amax(weight, quant_axis, block_size)
+            se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
+            se8m0 = se8m0_fp32.astype(np.uint8)
+
+            # Remove scale producer if it's a Constant node
+            scale_name = node.input[1]
+            scale_producer = tensor_producer_map[scale_name]
+            if scale_producer.op_type == "Constant":
+                graph.node.remove(scale_producer)
+
+            # Create and add new scale tensor
+            scale_name_new = scale_name.replace("Constant_output_0", "scale")
+            scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name_new)
+            graph.initializer.append(scale_tensor)
+            node.input[1] = scale_name_new
+
+        return onnx_model
 
     @staticmethod
     def compress_weights(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Compresses the weights in the ONNX model for MXFP8 quantization."""
+        """Compresses the weights in the ONNX model to FP8 format for MXFP8 quantization."""
+        logger.info("Compressing weights to MXFP8 format")
+        graph = onnx_model.graph
+        initializer_map = {init.name: init for init in graph.initializer}
+
+        for node in _get_weight_dq_nodes(graph):
+            weight_name = node.input[0]
+            scale_name = node.input[1]
+            logger.debug(f"Compressing weight {weight_name} to MXFP8")
+
+            weight = numpy_helper.to_array(initializer_map[weight_name])
+            quant_axis, block_size = _get_quant_params(node)
+
+            # Get scale and convert back to fp32 for computation
+            se8m0 = numpy_helper.to_array(initializer_map[scale_name])
+            se8m0_fp32 = se8m0.astype(np.float32)
+
+            # Expand block array so that it can be broadcasted with weight
+            se8m0_fp32_expanded = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
+            scaled_weight = weight / np.exp2(se8m0_fp32_expanded - E8_M0_BIAS)
+
+            # Create FP8 weight tensor
+            weights_e4m3 = onnx.helper.make_tensor(
+                name=weight_name,
+                data_type=onnx_dtype_map["Float8"],
+                dims=[*scaled_weight.shape],
+                vals=_cast_fp8(scaled_weight).tobytes(),
+                raw=True,
+            )
+            initializer_map[weight_name].CopyFrom(weights_e4m3)
+            logger.debug(f"Converted {weight_name} to MXFP8")
+
+        return onnx_model
 
     @staticmethod
     def post_process(onnx_model: onnx.ModelProto) -> onnx.ModelProto:
-        """Post-processes the ONNX model for MXFP8 quantization."""
+        """Post-processes the ONNX model for MXFP8 quantization.
+
+        Sets DQ output type to FP16 and updates GELU nodes to use tanh approximation.
+        """
+        logger.info("Post-processing MXFP8 quantized model")
+        graph = onnx_model.graph
+
+        # Set output type of DQ to FP16
+        for node in graph.node:
+            if node.op_type == "TRT_MXFP8DequantizeLinear":
+                for attr in node.attribute:
+                    if attr.name == "output_dtype":
+                        attr.i = onnx_dtype_map["Half"]
+
+        # Currently only tanh approximation is supported for Gelu
+        for node in graph.node:
+            if node.op_type == "Gelu":
+                for attr in node.attribute:
+                    if attr.name == "approximate":
+                        attr.s = b"tanh"
+                        logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
+
+        # Insert cast to fp16 after Sqrt nodes
+        cast_nodes_to_insert = []
+        for idx, node in enumerate(graph.node):
+            if node.op_type == "Sqrt":
+                sqrt_output = node.output[0]
+                cast_output = f"{sqrt_output}_cast_fp16"
+
+                # Create Cast node
+                cast_node = onnx.helper.make_node(
+                    "Cast",
+                    inputs=[sqrt_output],
+                    outputs=[cast_output],
+                    to=onnx_dtype_map["Half"],
+                    name=f"{node.name}_cast_fp16",
+                )
+                cast_nodes_to_insert.append((idx + 1, cast_node))
+
+                # Update consumers to use cast output
+                for consumer in graph.node:
+                    if consumer == node:
+                        continue
+                    for i, inp in enumerate(consumer.input):
+                        if inp == sqrt_output:
+                            consumer.input[i] = cast_output
+
+        # Insert Cast nodes in order, shifting each position by the number already inserted
+        for offset, (pos, cast_node) in enumerate(cast_nodes_to_insert):
+            graph.node.insert(pos + offset, cast_node)
+            logger.debug(f"Inserted Cast to FP16 after {cast_node.input[0]}")
+
+        return onnx_model
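
A minimal sketch of how the four exporter stages above could be chained on a QDQ ONNX graph. The class, method names, and signatures come from the file shown in this diff; the surrounding load/save calls and file names are illustrative assumptions, not the library's actual export driver, which appears to route through `quantize_weights` in `torch_onnx.py` (see the third file below).

```python
import onnx

from modelopt.onnx.export.mxfp8_exporter import MXFP8QuantExporter

# Hypothetical file names; any QDQ-exported ONNX model containing
# TRT_MXFP8DequantizeLinear weight nodes would do.
model = onnx.load("vit_base_patch16_224.qdq.onnx")

model = MXFP8QuantExporter.pre_process(model)       # currently a pass-through
model = MXFP8QuantExporter.compute_scales(model)    # adds uint8 e8m0 scale initializers
model = MXFP8QuantExporter.compress_weights(model)  # rewrites weights as FP8 (E4M3) initializers
model = MXFP8QuantExporter.post_process(model)      # FP16 DQ outputs, tanh GELU, Cast after Sqrt

onnx.save(model, "vit_base_patch16_224.mxfp8.onnx")
```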

modelopt/onnx/quantization/qdq_utils.py
Lines changed: 1 addition & 100 deletions

@@ -31,8 +31,7 @@
     get_tensor_producer_nodes,
     remove_redundant_cast_nodes,
 )
-from modelopt.onnx.quantization.quant_utils import compute_e8m0, get_amax, get_num_bits
-from modelopt.onnx.utils import get_attribute, has_attribute
+from modelopt.onnx.quantization.quant_utils import get_num_bits
 
 QUANTIZE_NODE_NAME = "QuantizeLinear"
 DEQUANTIZE_NODE_NAME = "DequantizeLinear"
@@ -1036,101 +1035,3 @@ def cast_initializer_to_dtype(
     input_onnx = onnx.numpy_helper.from_array(input, input_name)
     input_onnx.data_type = onnx_dtype_map[dtype]
     initializer_map[input_name].CopyFrom(input_onnx)
-
-
-def quantize_weights_to_mxfp8(
-    onnx_model: onnx.ModelProto,
-) -> onnx.ModelProto:
-    """Converts the weights to FP8 precision using MXFP8 quantization.
-
-    For TRT_MXFP8DynamicQuantize, we update the output type to FP8.
-    For TRT_MXFP8DequantizeLinear, we compute the scales in e8m0 format and saves them as a new initializer.
-    We then expand the scale to the same shape as the weight and divide the weight by the scale to get the FP8 weights.
-
-    Args:
-        graph: ONNX model protobuf.
-
-    Returns:
-        ONNX model protobuf with weights quantized to FP8 precision using MXFP8 quantization.
-    """
-    logger.info("Converting weights to MXFP8 precision")
-    graph = onnx_model.graph
-    initializer_map = {initializer.name: initializer for initializer in graph.initializer}
-    tensor_producer_map = get_tensor_producer_nodes(graph)
-    e8_m0_bias = 127
-    weight_dq_nodes = [
-        node
-        for node in graph.node
-        if node.op_type == "TRT_MXFP8DequantizeLinear"
-        and any(".weight" in input for input in node.input)
-    ]
-    gelu_nodes = [node for node in graph.node if node.op_type == "Gelu"]
-    logger.debug(f"Found {len(weight_dq_nodes)} weight DQ nodes and {len(gelu_nodes)} GELU nodes")
-
-    for node in weight_dq_nodes:
-        # Get weights and node attributes
-        weight_name = node.input[0]
-        logger.debug(f"Processing MXFP8 conversion for weight {weight_name}")
-        weight = numpy_helper.to_array(initializer_map[weight_name])
-        if has_attribute(node, "axis"):
-            quant_axis = int(get_attribute(node, "axis"))
-        else:
-            quant_axis = -1
-            logger.warning(
-                "axis attribute not found for MXFP8DequantizeLinear node. Setting axis to -1"
-            )
-
-        if has_attribute(node, "block_size"):
-            block_size = int(get_attribute(node, "block_size"))
-        else:
-            block_size = 32
-            logger.warning(
-                "block_size attribute not found for MXFP8DequantizeLinear node. Setting block_size to 32"
-            )
-
-        # Compute and save scales as uint8
-        amax = get_amax(weight, quant_axis, block_size)
-        se8m0_fp32 = compute_e8m0(amax, weight.shape, quant_axis, block_size)
-        se8m0 = se8m0_fp32.astype(np.uint8)
-
-        # Remove scale producer if it's a Constant node
-        scale_name = node.input[1]
-        scale_producer = tensor_producer_map[scale_name]
-        if scale_producer.op_type == "Constant":
-            graph.node.remove(scale_producer)
-
-        # Create a new scale tensor
-        scale_name = scale_name.replace("Constant_output_0", "scale")
-        scale_tensor = onnx.numpy_helper.from_array(se8m0, scale_name)
-        graph.initializer.append(scale_tensor)
-        node.input[1] = scale_name
-
-        # Convert weights to FP8
-        # Expand block array so that it can be broadcasted with weight
-        se8m0_fp32 = np.repeat(se8m0_fp32, block_size, axis=quant_axis)
-        scaled_weight = weight / np.exp2(se8m0_fp32 - e8_m0_bias)
-        weights_e4m3 = onnx.helper.make_tensor(
-            name=weight_name,
-            data_type=onnx_dtype_map["Float8"],
-            dims=[*scaled_weight.shape],
-            vals=_cast_fp8(scaled_weight).tobytes(),
-            raw=True,
-        )
-        initializer_map[weight_name].CopyFrom(weights_e4m3)
-        logger.debug(f"Converted {weight_name} to MXFP8")
-
-    # set output type of DQ to FP16
-    for node in graph.node:
-        if node.op_type in ["TRT_MXFP8DequantizeLinear"]:
-            for attr in node.attribute:
-                if attr.name == "output_dtype":
-                    attr.i = onnx_dtype_map["Half"]
-
-    # Currently only tanh approximation is supported for Gelu
-    for node in gelu_nodes:
-        for attr in node.attribute:
-            if attr.name == "approximate":
-                attr.s = b"tanh"
-                logger.debug(f"Updated GELU node {node.name} to use tanh approximation")
-
-    return onnx_model
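
The removed function's docstring summarizes the MXFP8 weight scheme: per-block amax values become power-of-two e8m0 scales, the scales are expanded to the weight shape, and the weight is divided by them before the FP8 cast. The standalone numpy sketch below illustrates that arithmetic on a toy tensor; the e8m0 formula used here (ceil-log2 of amax over the E4M3 maximum of 448) is an assumption for illustration, since the actual `compute_e8m0` and `get_amax` implementations are not part of this diff.

```python
import numpy as np

E8_M0_BIAS = 127        # exponent bias used by the exporter above
BLOCK_SIZE = 32         # default MXFP8 block size
FP8_E4M3_MAX = 448.0    # largest finite E4M3 value

# Toy weight; the last axis is the quantization axis, matching the default (-1)
weight = np.random.randn(64, 128).astype(np.float32)

# Per-block absolute maximum along the quantization axis
blocks = weight.reshape(64, 128 // BLOCK_SIZE, BLOCK_SIZE)
amax = np.abs(blocks).max(axis=-1)                          # shape (64, 4)

# Power-of-two scale exponent per block, stored as a biased uint8 (e8m0).
# This formula is a plausible stand-in; modelopt's compute_e8m0 may differ.
exp = np.ceil(np.log2(np.maximum(amax, np.finfo(np.float32).tiny) / FP8_E4M3_MAX))
se8m0 = np.clip(exp + E8_M0_BIAS, 0, 255).astype(np.uint8)

# Mirror the exporter: repeat each block scale BLOCK_SIZE times along the
# quantization axis and divide the weight by 2^(e8m0 - bias) before the FP8 cast.
se8m0_fp32 = np.repeat(se8m0.astype(np.float32), BLOCK_SIZE, axis=-1)   # shape (64, 128)
scaled_weight = weight / np.exp2(se8m0_fp32 - E8_M0_BIAS)

print("max |scaled weight|:", float(np.abs(scaled_weight).max()), "vs E4M3 max", FP8_E4M3_MAX)
```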

modelopt/torch/_deploy/utils/torch_onnx.py
Lines changed: 8 additions & 10 deletions

@@ -40,11 +40,7 @@
     NVFP4QuantExporter,
     ONNXQuantExporter,
 )
-from modelopt.onnx.quantization.qdq_utils import (
-    qdq_to_dq,
-    quantize_weights_to_mxfp8,
-    replace_zero_scale_with_smallest_nonzero,
-)
+from modelopt.onnx.quantization.qdq_utils import qdq_to_dq, replace_zero_scale_with_smallest_nonzero
 from modelopt.onnx.utils import (
     get_input_names,
     get_input_shapes,
@@ -364,6 +360,11 @@ def is_fp8_quantized(model: nn.Module) -> bool:
             and hasattr(module, "input_quantizer")
             and module.weight_quantizer._num_bits == (4, 3)
             and module.input_quantizer._num_bits == (4, 3)
+            # Exclude MXFP8 which also uses (4,3) but has block_sizes with scale_bits
+            and not (
+                module.input_quantizer.block_sizes
+                and module.input_quantizer.block_sizes.get("scale_bits", None) == (8, 0)
+            )
         ):
             return True
     return False
@@ -560,11 +561,8 @@ def get_onnx_bytes_and_metadata(
 
     # Convert dummy TRT_FP4QDQ nodes to 2DQ format if the model is quantized in FP4 mode
     # Or convert weights to MXFP8 format if the model is quantized in MXFP8 mode
-    if is_int4_quantized(model) or is_fp4_quantized(model):
+    if is_int4_quantized(model) or is_fp4_quantized(model) or is_mxfp8_quantized(model):
         onnx_opt_graph = quantize_weights(model, onnx_opt_graph)
-    elif is_mxfp8_quantized(model):
-        # TODO: Implement the MXFP8QuantExporter
-        onnx_opt_graph = quantize_weights_to_mxfp8(onnx_opt_graph)
 
     if dq_only:
         onnx_opt_graph = qdq_to_dq(onnx_opt_graph)
@@ -575,7 +573,7 @@
     except StopIteration:
         param_dtype = torch.float32
     if weights_dtype in ["fp16", "bf16"] and param_dtype == torch.float32:
-        if is_mxfp8_quantized(model) or is_int4_quantized(model):
+        if is_int4_quantized(model) or is_mxfp8_quantized(model):
             assert weights_dtype == "fp16", "BF16 + MXFP8/INT4 mixed precision is not supported yet"
             onnx_opt_graph = convert_float_to_float16(
                 onnx_opt_graph,
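
The new condition in `is_fp8_quantized` encodes how MXFP8 is told apart from plain FP8: both use E4M3 `(4, 3)` num_bits on weights and activations, but MXFP8 quantizers additionally carry `block_sizes` with `scale_bits == (8, 0)` (e8m0 block scales). The sketch below restates that check as a standalone helper; it is illustrative only and is not the `is_mxfp8_quantized` helper the diff calls, whose implementation is not shown here.

```python
import torch.nn as nn


def _quantizer_looks_mxfp8(quantizer) -> bool:
    """Heuristic mirroring the exclusion added above: E4M3 num_bits plus e8m0 block scales."""
    if getattr(quantizer, "_num_bits", None) != (4, 3):
        return False
    block_sizes = getattr(quantizer, "block_sizes", None) or {}
    return block_sizes.get("scale_bits", None) == (8, 0)


def model_looks_mxfp8(model: nn.Module) -> bool:
    """Return True if any module's input quantizer is configured like MXFP8."""
    return any(
        _quantizer_looks_mxfp8(module.input_quantizer)
        for module in model.modules()
        if hasattr(module, "input_quantizer")
    )
```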
