diff --git a/modelopt/onnx/quantization/quantize.py b/modelopt/onnx/quantization/quantize.py index 9bc025e33..948bc28d1 100755 --- a/modelopt/onnx/quantization/quantize.py +++ b/modelopt/onnx/quantization/quantize.py @@ -133,16 +133,16 @@ def _preprocess_onnx( if simplify: logger.info("Attempting to simplify model") try: - import onnxsim + import onnxslim except ModuleNotFoundError as e: logger.warning( - "onnxsim is not installed. Please install it with 'pip install onnxsim'." + "onnxslim is not installed. Please install it with 'pip install onnxslim'." ) raise e try: - model_simp, check = onnxsim.simplify(onnx_model) - if check: + model_simp = onnxslim.slim(onnx_model) + if model_simp: onnx_model = model_simp onnx_path = os.path.join(output_dir, f"{model_name}_simp.onnx") save_onnx(onnx_model, onnx_path, use_external_data_format) diff --git a/setup.py b/setup.py index 67bf114ae..ee3488374 100644 --- a/setup.py +++ b/setup.py @@ -52,8 +52,8 @@ "onnxruntime-gpu~=1.22.0 ; platform_machine != 'aarch64' and platform_system != 'Darwin' and platform_system != 'Windows'", # noqa: E501 "onnxruntime-directml==1.20.0; platform_system == 'Windows'", "onnxscript", # For test_onnx_dynamo_export unit test - "onnxsim ; python_version < '3.12' and platform_machine != 'aarch64'", "polygraphy>=0.49.22", + "onnxslim", ], "hf": [ "accelerate>=1.0.0", diff --git a/tests/gpu/onnx/test_simplify.py b/tests/gpu/onnx/test_simplify.py index 689de27c9..765d8c5d2 100644 --- a/tests/gpu/onnx/test_simplify.py +++ b/tests/gpu/onnx/test_simplify.py @@ -61,10 +61,10 @@ def test_onnx_simplification(tmp_path): graph = gs.import_onnx(onnx.load(simplified_onnx_path)) identity_nodes = [n for n in graph.nodes if n.op == "Identity"] assert not identity_nodes, "Simplified ONNX model contains Identity nodes but it shouldn't." - assert len(graph.nodes) == 3, ( - f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 3." + assert len(graph.nodes) == 2, ( + f"Number of nodes doesn't match the expected: {len(graph.nodes)} vs 2." ) - assert all(n.op in ["Conv", "BatchNormalization", "Relu"] for n in graph.nodes), ( + assert all(n.op in ["Conv", "Relu"] for n in graph.nodes), ( "Graph contains more ops than expected." )