|
27 | 27 | from onnx_graphsurgeon.ir.node import Node |
28 | 28 | from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable |
29 | 29 | from onnxruntime.quantization.calibrate import CalibrationDataReader |
30 | | -from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference |
31 | 30 |
|
32 | 31 | from modelopt.onnx.logging_config import logger |
33 | 32 | from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op |
|
36 | 35 | find_lowest_common_ancestor, |
37 | 36 | get_child_nodes, |
38 | 37 | get_parent_nodes, |
| 38 | + infer_shapes, |
39 | 39 | parse_shapes_spec, |
40 | 40 | save_onnx, |
41 | 41 | ) |
@@ -962,11 +962,10 @@ def find_nodes_from_matmul_to_exclude( |
962 | 962 | logger.debug("No MatMul nodes found in the model") |
963 | 963 | return [] |
964 | 964 |
|
965 | | - nodes_to_exclude = [] |
966 | 965 | logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze") |
967 | 966 |
|
968 | 967 | if calibration_shapes: |
969 | | - nodes_to_exclude = _exclude_matmuls_by_symbolic_inference( |
| 968 | + nodes_to_exclude = _exclude_matmuls_by_shape_inference( |
970 | 969 | model, matmul_nodes, calibration_shapes |
971 | 970 | ) |
972 | 971 | else: |
@@ -1058,7 +1057,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"): |
1058 | 1057 | return unsupported_conv_nodes |
1059 | 1058 |
|
1060 | 1059 |
|
1061 | | -def _exclude_matmuls_by_symbolic_inference( |
| 1060 | +def _exclude_matmuls_by_shape_inference( |
1062 | 1061 | model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None |
1063 | 1062 | ) -> list[str]: |
1064 | 1063 | """Use symbolic shape inference to find MatMuls with dimension 1.""" |
@@ -1088,7 +1087,7 @@ def _exclude_matmuls_by_symbolic_inference( |
1088 | 1087 | dim.dim_value = new_dim_value |
1089 | 1088 |
|
1090 | 1089 | model.graph.ClearField("value_info") |
1091 | | - model = SymbolicShapeInference.infer_shapes(model) |
| 1090 | + model = infer_shapes(model) |
1092 | 1091 | value_info_map = {vi.name: vi for vi in model.graph.value_info} |
1093 | 1092 |
|
1094 | 1093 | nodes_to_exclude = [] |
|
0 commit comments