
Commit a51b00d

fix qdq utils issues and remove global cast replacements

Signed-off-by: Luxiao Zheng <luxiaoz@nvidia.com>
1 parent ed58324

2 files changed: +323 -55 lines


modelopt/onnx/quantization/qdq_utils.py

Lines changed: 12 additions & 31 deletions
@@ -1111,17 +1111,20 @@ def quantize_weights_to_int4(
         scale_shape = [*weight_shape[:-1], weight_shape[-1] // block_size]
         scale = scale.reshape(scale_shape)
         reshape_child_nodes = [n for n in graph.node if reshape_node.output[0] in n.input]
-        assert len(reshape_child_nodes) == 1, f"Expected exactly one transpose node for {node.name}"
+        assert len(reshape_child_nodes) == 1, f"Expected exactly one child node for {node.name}"

-        # Remove unnecessary Cast node
-        cast_node = reshape_child_nodes[0]
-        assert cast_node.op_type == "Cast", f"Expected Cast node for {node.name}"
-        nodes_to_remove.append(cast_node.name)
-        cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+        # Check if there's an optional Cast node between Reshape and Transpose/MatMul/Gemm
+        next_node = reshape_child_nodes[0]
+        if next_node.op_type == "Cast":
+            # Remove unnecessary Cast node
+            cast_node = next_node
+            nodes_to_remove.append(cast_node.name)
+            cast_child_nodes = [n for n in graph.node if cast_node.output[0] in n.input]
+            next_node = cast_child_nodes[0]

         # Transpose weights and scales if present
-        if cast_child_nodes[0].op_type == "Transpose":
-            transpose_node = cast_child_nodes[0]
+        if next_node.op_type == "Transpose":
+            transpose_node = next_node
             nodes_to_remove.append(transpose_node.name)
             assert transpose_node.op_type == "Transpose", f"Expected Transpose node for {node.name}"
             perm = None
@@ -1138,7 +1141,7 @@ def quantize_weights_to_int4(
             )
             matmul_node = transpose_child_nodes[0]
         else:
-            matmul_node = cast_child_nodes[0]
+            matmul_node = next_node
         assert matmul_node.op_type in ["MatMul", "Gemm"], (
             f"Expected MatMul or Gemm node for {node.name}"
         )
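Note: the hunks above make the Cast between Reshape and the downstream Transpose/MatMul/Gemm optional instead of mandatory, so graphs without that Cast no longer trip the assert. A minimal sketch of the resulting traversal pattern, assuming an onnx.GraphProto `graph` and a single-consumer `reshape_node`; the helper name `skip_optional_cast` is illustrative, not from the commit:

import onnx

def skip_optional_cast(graph: onnx.GraphProto, reshape_node: onnx.NodeProto) -> onnx.NodeProto:
    """Return the first non-Cast consumer of reshape_node's output."""
    consumers = [n for n in graph.node if reshape_node.output[0] in n.input]
    assert len(consumers) == 1, "Expected exactly one child node"
    next_node = consumers[0]
    if next_node.op_type == "Cast":
        # The Cast is optional; step over it to reach Transpose/MatMul/Gemm
        consumers = [n for n in graph.node if next_node.output[0] in n.input]
        next_node = consumers[0]
    return next_node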
@@ -1189,21 +1192,6 @@ def is_pre_quant_scale_node(node: onnx.NodeProto) -> bool:
     del graph.node[:]
     graph.node.extend(new_nodes)

-    def is_fp32_cast(node: onnx.NodeProto) -> bool:
-        return any(
-            attr.name == "to" and attr.i == onnx.TensorProto.FLOAT for attr in node.attribute
-        )
-
-    # Change all Cast nodes that cast to float32 (TensorProto.FLOAT) to cast to float16 (TensorProto.FLOAT16)
-    for node in graph.node:
-        if node.op_type == "Cast":
-            # Skip Cast nodes that are part of normalization layers and outputs
-            if "norm/Cast" in node.name and is_fp32_cast(node):
-                continue
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx.TensorProto.FLOAT16
-
     # Cast bias to float16
     for node in graph.node:
         if node.op_type == "Add" and "proj/Add" in node.name:
@@ -1310,13 +1298,6 @@ def quantize_weights_to_mxfp8(
             if attr.name == "output_dtype":
                 attr.i = onnx_dtype_map["Half"]

-    # set Cast to FP16
-    for node in graph.node:
-        if node.op_type == "Cast":
-            for attr in node.attribute:
-                if attr.name == "to" and attr.i == onnx.TensorProto.FLOAT:
-                    attr.i = onnx_dtype_map["Half"]
-
     # Currently only tanh approximation is supported for Gelu
     for node in gelu_nodes:
         for attr in node.attribute:
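The MXFP8 path gets the same treatment: only output_dtype attributes are switched to Half, and the global Cast rewrite is dropped. For reference, the in-place AttributeProto mutation pattern these loops rely on, shown on a toy Cast node (illustrative, not part of the commit):

from onnx import TensorProto, helper

cast = helper.make_node("Cast", inputs=["x"], outputs=["y"], to=TensorProto.FLOAT)
for attr in cast.attribute:
    if attr.name == "to" and attr.i == TensorProto.FLOAT:
        attr.i = TensorProto.FLOAT16  # mutate the AttributeProto in place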
