Skip to content

Commit 52771de

Browse files
committed
Replaced SymbolicShapeInference with infer_shapes
Signed-off-by: gcunhase <4861122+gcunhase@users.noreply.github.com>
1 parent b442a28 commit 52771de

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from onnx_graphsurgeon.ir.node import Node
2828
from onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
2929
from onnxruntime.quantization.calibrate import CalibrationDataReader
30-
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
3130

3231
from modelopt.onnx.logging_config import logger
3332
from modelopt.onnx.op_types import copy_ops, is_copy_op, is_linear_op
@@ -36,6 +35,7 @@
3635
find_lowest_common_ancestor,
3736
get_child_nodes,
3837
get_parent_nodes,
38+
infer_shapes,
3939
parse_shapes_spec,
4040
save_onnx,
4141
)
@@ -962,11 +962,10 @@ def find_nodes_from_matmul_to_exclude(
962962
logger.debug("No MatMul nodes found in the model")
963963
return []
964964

965-
nodes_to_exclude = []
966965
logger.debug(f"Found {len(matmul_nodes)} MatMul nodes to analyze")
967966

968967
if calibration_shapes:
969-
nodes_to_exclude = _exclude_matmuls_by_symbolic_inference(
968+
nodes_to_exclude = _exclude_matmuls_by_shape_inference(
970969
model, matmul_nodes, calibration_shapes
971970
)
972971
else:
@@ -1058,7 +1057,7 @@ def find_nodes_from_convs_to_exclude(graph: Graph, quantize_mode: str = "int8"):
10581057
return unsupported_conv_nodes
10591058

10601059

1061-
def _exclude_matmuls_by_symbolic_inference(
1060+
def _exclude_matmuls_by_shape_inference(
10621061
model: onnx.ModelProto, matmul_nodes: list, calibration_shapes: str | dict | None = None
10631062
) -> list[str]:
10641063
"""Use symbolic shape inference to find MatMuls with dimension 1."""
@@ -1088,7 +1087,7 @@ def _exclude_matmuls_by_symbolic_inference(
10881087
dim.dim_value = new_dim_value
10891088

10901089
model.graph.ClearField("value_info")
1091-
model = SymbolicShapeInference.infer_shapes(model)
1090+
model = infer_shapes(model)
10921091
value_info_map = {vi.name: vi for vi in model.graph.value_info}
10931092

10941093
nodes_to_exclude = []

0 commit comments

Comments
 (0)