2929from modelopt .onnx .utils import save_onnx
3030
3131
32- def _assert_nodes_quantization (nodes , should_be_quantized = True ):
32+ def _assert_nodes_are_quantized (nodes ):
3333 for node in nodes :
3434 for inp_idx , inp in enumerate (node .inputs ):
3535 if isinstance (inp , gs .Variable ):
36- if should_be_quantized :
37- assert node .i (inp_idx ).op == "DequantizeLinear" , (
38- f"Input '{ inp .name } ' of node '{ node .name } ' is not quantized but should be!"
39- )
40- else :
41- assert node .i (inp_idx ).op != "DequantizeLinear" , (
42- f"Input '{ inp .name } ' of node '{ node .name } ' is quantized but should not be!"
43- )
36+ assert node .i (inp_idx ).op == "DequantizeLinear" , (
37+ f"Input '{ inp .name } ' of node '{ node .name } ' is not quantized but should be!"
38+ )
4439 return True
4540
4641
@@ -64,7 +59,7 @@ def test_int8(tmp_path, high_precision_dtype):
6459
6560 # Check that all MatMul nodes are quantized
6661 mm_nodes = [n for n in graph .nodes if n .op == "MatMul" ]
67- assert _assert_nodes_quantization (mm_nodes )
62+ assert _assert_nodes_are_quantized (mm_nodes )
6863
6964
7065def test_convtranspose_conv_residual_int8 (tmp_path ):
@@ -85,7 +80,7 @@ def test_convtranspose_conv_residual_int8(tmp_path):
8580
8681 # Check that Conv and ConvTransposed are quantized
8782 conv_nodes = [n for n in graph .nodes if "Conv" in n .op ]
88- assert _assert_nodes_quantization (conv_nodes )
83+ assert _assert_nodes_are_quantized (conv_nodes )
8984
9085 # Check that only 1 input of Add is quantized
9186 add_nodes = [n for n in graph .nodes if n .op == "Add" ]
0 commit comments