Skip to content

Commit fd59bdf

Browse files
committed
Revert modification of nodes_are_quantized function
Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent 45aad6d commit fd59bdf

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

tests/unit/onnx/test_quantize_int8.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,13 @@
2929
from modelopt.onnx.utils import save_onnx
3030

3131

32-
def _assert_nodes_quantization(nodes, should_be_quantized=True):
32+
def _assert_nodes_are_quantized(nodes):
3333
for node in nodes:
3434
for inp_idx, inp in enumerate(node.inputs):
3535
if isinstance(inp, gs.Variable):
36-
if should_be_quantized:
37-
assert node.i(inp_idx).op == "DequantizeLinear", (
38-
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
39-
)
40-
else:
41-
assert node.i(inp_idx).op != "DequantizeLinear", (
42-
f"Input '{inp.name}' of node '{node.name}' is quantized but should not be!"
43-
)
36+
assert node.i(inp_idx).op == "DequantizeLinear", (
37+
f"Input '{inp.name}' of node '{node.name}' is not quantized but should be!"
38+
)
4439
return True
4540

4641

@@ -64,7 +59,7 @@ def test_int8(tmp_path, high_precision_dtype):
6459

6560
# Check that all MatMul nodes are quantized
6661
mm_nodes = [n for n in graph.nodes if n.op == "MatMul"]
67-
assert _assert_nodes_quantization(mm_nodes)
62+
assert _assert_nodes_are_quantized(mm_nodes)
6863

6964

7065
def test_convtranspose_conv_residual_int8(tmp_path):
@@ -85,7 +80,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8580

8681
# Check that Conv and ConvTransposed are quantized
8782
conv_nodes = [n for n in graph.nodes if "Conv" in n.op]
88-
assert _assert_nodes_quantization(conv_nodes)
83+
assert _assert_nodes_are_quantized(conv_nodes)
8984

9085
# Check that only 1 input of Add is quantized
9186
add_nodes = [n for n in graph.nodes if n.op == "Add"]

0 commit comments

Comments
 (0)