From 8d37ebd895c62f8e2915dd357c8c996b0e1d225c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 11:38:12 -0700 Subject: [PATCH 01/15] [passes] Implement checker pass and refactor shape inference --- onnxscript/ir/passes/common/_c_api_utils.py | 81 ++++++++++++ onnxscript/ir/passes/common/debugging.py | 30 +++++ .../ir/passes/common/shape_inference.py | 118 +++++++----------- 3 files changed, 157 insertions(+), 72 deletions(-) create mode 100644 onnxscript/ir/passes/common/_c_api_utils.py create mode 100644 onnxscript/ir/passes/common/debugging.py diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py new file mode 100644 index 0000000000..2bf5ef77d8 --- /dev/null +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -0,0 +1,81 @@ +"""Utilities for interfacing with onnx C APIs.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Callable + +from onnxscript import ir + +if TYPE_CHECKING: + import onnx + + +logger = logging.getLogger(__name__) +# Temporarily remove initializers larger than this size to keep model size down +# for the onnx.shape_inference call because it needs to serialize the model +_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB + + +def call_onnx_api( + func: Callable[[onnx.ModelProto], onnx.ModelProto], + model: ir.Model, + merge_func: Callable[[ir.Model, onnx.ModelProto], tuple[ir.Model, bool]], +) -> tuple[ir.Model, bool]: + """Call an ONNX C API function by temporarily removing initializers. + This is necessary because the ONNX C API does not support large models + with initializers that have large tensor values. + + Args: + func: Partially applied function that takes a model proto and returns a model proto. + model: The IR model to pass to the API function. + merge_func: Function that merges IR model with information from the model proto. + + Returns: + A tuple containing the modified model and a boolean indicating whether the model was modified. + """ + + # Store the original initializer values so they can be restored + initializer_values = tuple(model.graph.initializers.values()) + tensors = {v.name: v.const_value for v in initializer_values} + original_inputs_len = len(model.graph.inputs) + initializer_names = {v.name for v in initializer_values} + + # Turn the initializers into inputs and clear the initializers + # to limit the model size + for initializer in initializer_values: + # Make sure the initializer has its shape/type set + assert initializer.const_value is not None + if initializer.shape is None: + initializer.shape = initializer.const_value.shape # type: ignore[assignment] + if initializer.dtype is None: + initializer.dtype = initializer.const_value.dtype + if initializer not in model.graph.inputs: + model.graph.inputs.append(initializer) + if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT: + # Temporarily remove the initializer value to reduce model size + # for onnx.shape_inference + initializer.const_value = None + assert initializer.name is not None + model.graph.initializers.pop(initializer.name) + + try: + proto = ir.serde.serialize_model(model) + result_proto = func(proto) + except Exception: # pylint: disable=broad-exception-caught + logger.warning("Call to %s failed. The model is not modified", func, exc_info=True) + return (model, False) + finally: + # Restore the original initializer values so the model is unchanged + for initializer in initializer_values: + if initializer.name in initializer_names: + initializer.const_value = tensors[initializer.name] + model.graph.register_initializer(initializer) + + # Restore the original inputs + inputs = model.graph.inputs[:original_inputs_len] + model.graph.inputs.clear() + model.graph.inputs.extend(inputs) + + # Merge the result with the original model + return merge_func(model, result_proto) diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/debugging.py new file mode 100644 index 0000000000..4c1e5b1212 --- /dev/null +++ b/onnxscript/ir/passes/common/debugging.py @@ -0,0 +1,30 @@ +"""Passes for debugging purposes.""" + +from __future__ import annotations + +__all__ = [ + "LiftConstantsToInitializersPass", +] + +import logging + +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +class CheckerPass(ir.passes.PassBase): + """Run onnx checker on the model.""" + + @property + def in_place(self) -> bool: + return True + + @property + def changes_input(self) -> bool: + return False + + def __init__(self, lift_all_constants: bool = False, size_limit: int = 16): + super().__init__() + self.lift_all_constants = lift_all_constants + self.size_limit = size_limit diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index f6d88584e7..580079ec56 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -9,20 +9,49 @@ "infer_shapes", ] +import functools import logging import onnx from onnxscript import ir +from onnxscript.ir.passes.common import _c_api_utils logger = logging.getLogger(__name__) -# Temporarily remove initializers larger than this size to keep model size down -# for the onnx.shape_inference call because it needs to serialize the model -_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB +def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Model, bool]: + """Merge the inferred model with the original model. -class ShapeInferencePass(ir.passes.FunctionalPass): + Args: + model: The original IR model. + inferred_proto: The inferred ONNX model. + + Returns: + A tuple containing the modified model and a boolean indicating whether the model was modified. + """ + inferred_model = ir.serde.deserialize_model(inferred_proto) + modified = False + for original_graph, inferred_graph in (model.graphs(), inferred_model.graphs()): + original_values = ir.convenience.create_value_mapping(original_graph) + inferred_values = ir.convenience.create_value_mapping(inferred_graph) + for name, value in original_values.items(): + if name in inferred_values: + inferred_value = inferred_values[name] + if value.shape != inferred_value.shape and inferred_value.shape is not None: + value.shape = inferred_value.shape + modified = True + if value.dtype != inferred_value.dtype and inferred_value.dtype is not None: + value.dtype = inferred_value.dtype + modified = True + else: + logger.warning( + "Value %s not found in inferred graph %s", name, inferred_graph.name + ) + return model, modified + + +class ShapeInferencePass(ir.passes.InPlacePass): """This pass performs shape inference on the graph.""" def __init__( @@ -41,76 +70,21 @@ def __init__( self.data_prop = data_prop def call(self, model: ir.Model) -> ir.passes.PassResult: - # Store the original initializer values so they can be restored - initializer_values = tuple(model.graph.initializers.values()) - tensors = {v.name: v.const_value for v in initializer_values} - original_inputs_len = len(model.graph.inputs) - initializer_names = {v.name for v in initializer_values} - - # Turn the initializers into inputs and clear the initializers - # to limit the model size - for initializer in initializer_values: - # Make sure the initializer has its shape/type set - assert initializer.const_value is not None - if initializer.shape is None: - initializer.shape = initializer.const_value.shape # type: ignore[assignment] - if initializer.dtype is None: - initializer.dtype = initializer.const_value.dtype - if initializer not in model.graph.inputs: - model.graph.inputs.append(initializer) - if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT: - # Temporarily remove the initializer value to reduce model size - # for onnx.shape_inference - initializer.const_value = None - assert initializer.name is not None - model.graph.initializers.pop(initializer.name) - - # Perform shape inference - try: - proto = ir.serde.serialize_model(model) - value_infos = {info.name: info for info in proto.graph.value_info} - inferred_proto = onnx.shape_inference.infer_shapes( - proto, - check_type=self.check_type, - strict_mode=self.strict_mode, - data_prop=self.data_prop, - ) - inferred_value_infos = { - info.name: info for info in inferred_proto.graph.value_info - } - inferred_model = ir.serde.deserialize_model(inferred_proto) - - except Exception: # pylint: disable=broad-exception-caught - logger.warning("Shape inference failed. The model is not modified", exc_info=True) - return ir.passes.PassResult(model, modified=False) - finally: - # Restore the original initializer values so the model is unchanged - for initializer in initializer_values: - if initializer.name in initializer_names: - initializer.const_value = tensors[initializer.name] - model.graph.register_initializer(initializer) - - # Restore the original inputs - inputs = model.graph.inputs[:original_inputs_len] - model.graph.inputs.clear() - model.graph.inputs.extend(inputs) - - # Add the original initializer tensors to the new (inferred) model - for new_input in inferred_model.graph.inputs: - # Assign the tensors back to the initializers - if new_input.name in initializer_names: - new_input.const_value = tensors[new_input.name] - inferred_model.graph.register_initializer(new_input) - - # Remove the inputs that were added - new_inputs = inferred_model.graph.inputs[:original_inputs_len] - inferred_model.graph.inputs.clear() - inferred_model.graph.inputs.extend(new_inputs) - - return ir.passes.PassResult( - inferred_model, modified=value_infos != inferred_value_infos + onnx_infer_shapes = functools.partial( + onnx.shape_inference.infer_shapes, + check_type=self.check_type, + strict_mode=self.strict_mode, + data_prop=self.data_prop, ) + inferred_model, modified = _c_api_utils.call_onnx_api( + onnx_infer_shapes, + model, + merge_func=_merge_func, + ) + + return ir.passes.PassResult(inferred_model, modified=modified) + def infer_shapes( model: ir.Model, From 88ccecb43ef18151c98678815cff529828fae84b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 11:59:26 -0700 Subject: [PATCH 02/15] checker pass --- onnxscript/ir/passes/common/debugging.py | 39 +++++++++++++++++++++--- 1 file changed, 35 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/debugging.py index 4c1e5b1212..41762ca4c2 100644 --- a/onnxscript/ir/passes/common/debugging.py +++ b/onnxscript/ir/passes/common/debugging.py @@ -3,12 +3,15 @@ from __future__ import annotations __all__ = [ - "LiftConstantsToInitializersPass", + "CheckerPass", ] import logging +import onnx + from onnxscript import ir +from onnxscript.ir.passes.common import _c_api_utils logger = logging.getLogger(__name__) @@ -24,7 +27,35 @@ def in_place(self) -> bool: def changes_input(self) -> bool: return False - def __init__(self, lift_all_constants: bool = False, size_limit: int = 16): + def __init__( + self, + full_check: bool = False, + skip_opset_compatibility_check: bool = False, + check_custom_domain: bool = False, + ): super().__init__() - self.lift_all_constants = lift_all_constants - self.size_limit = size_limit + self.full_check = full_check + self.skip_opset_compatibility_check = skip_opset_compatibility_check + self.check_custom_domain = check_custom_domain + + def call(self, model: ir.Model) -> ir.Model: + """Run the onnx checker on the model.""" + + def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: + """Partial function to check the model.""" + onnx.checker.check_model( + proto, + full_check=self.full_check, + skip_opset_compatibility_check=self.skip_opset_compatibility_check, + check_custom_domain=self.check_custom_domain, + ) + return proto + + _c_api_utils.call_onnx_api( + func=_partial_check_model, + model=model, + # Since we do not modify the model. merge_func is not used but provided for completeness + merge_func=lambda m, proto: (m, False), + ) + # The model is not modified + return model From cf0a793f413e39021b67edb2aed2cb77112f48bb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 11:59:52 -0700 Subject: [PATCH 03/15] return --- onnxscript/ir/passes/common/debugging.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/debugging.py index 41762ca4c2..70a1787a4d 100644 --- a/onnxscript/ir/passes/common/debugging.py +++ b/onnxscript/ir/passes/common/debugging.py @@ -38,7 +38,7 @@ def __init__( self.skip_opset_compatibility_check = skip_opset_compatibility_check self.check_custom_domain = check_custom_domain - def call(self, model: ir.Model) -> ir.Model: + def call(self, model: ir.Model) -> ir.passes.PassResult: """Run the onnx checker on the model.""" def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: @@ -58,4 +58,4 @@ def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: merge_func=lambda m, proto: (m, False), ) # The model is not modified - return model + return ir.passes.PassResult(model, False) From 934f4c59815eca23a90d53ab64607d76cdbcdb4a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 12:00:16 -0700 Subject: [PATCH 04/15] header --- onnxscript/ir/passes/common/_c_api_utils.py | 2 ++ onnxscript/ir/passes/common/debugging.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py index 2bf5ef77d8..e76731ac90 100644 --- a/onnxscript/ir/passes/common/_c_api_utils.py +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Utilities for interfacing with onnx C APIs.""" from __future__ import annotations diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/debugging.py index 70a1787a4d..d3761cec0b 100644 --- a/onnxscript/ir/passes/common/debugging.py +++ b/onnxscript/ir/passes/common/debugging.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """Passes for debugging purposes.""" from __future__ import annotations From ab362ccff6bb5612deb9d4b155eb56eb7a8501cb Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 12:15:49 -0700 Subject: [PATCH 05/15] lint --- onnxscript/ir/passes/common/debugging.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/debugging.py index d3761cec0b..24c65db623 100644 --- a/onnxscript/ir/passes/common/debugging.py +++ b/onnxscript/ir/passes/common/debugging.py @@ -8,15 +8,11 @@ "CheckerPass", ] -import logging - import onnx from onnxscript import ir from onnxscript.ir.passes.common import _c_api_utils -logger = logging.getLogger(__name__) - class CheckerPass(ir.passes.PassBase): """Run onnx checker on the model.""" From 22f5877605789766248330b0ad4755b5008f82f4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Mon, 14 Apr 2025 19:05:01 -0700 Subject: [PATCH 06/15] Update onnxscript/ir/passes/common/shape_inference.py --- onnxscript/ir/passes/common/shape_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index 580079ec56..65c409f2f3 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -32,7 +32,7 @@ def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Mo """ inferred_model = ir.serde.deserialize_model(inferred_proto) modified = False - for original_graph, inferred_graph in (model.graphs(), inferred_model.graphs()): + for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()): original_values = ir.convenience.create_value_mapping(original_graph) inferred_values = ir.convenience.create_value_mapping(inferred_graph) for name, value in original_values.items(): From 318025b593cb73c1f53a74e33abcc0167c80bcd6 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Apr 2025 10:55:17 -0700 Subject: [PATCH 07/15] Fix test --- .../ir/passes/common/shape_inference_test.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py index da67b4c1a7..5a2f02c64e 100644 --- a/onnxscript/ir/passes/common/shape_inference_test.py +++ b/onnxscript/ir/passes/common/shape_inference_test.py @@ -7,10 +7,13 @@ import numpy as np from onnxscript import ir -from onnxscript.ir.passes.common import shape_inference +from onnxscript.ir.passes.common import _c_api_utils, shape_inference class TestShapeInferencePass(unittest.TestCase): + def test_pass_is_in_place(self): + self.assertTrue(shape_inference.ShapeInferencePass().in_place) + def test_pass(self): # Create a simple ONNX model with shape inference # Define the model @@ -51,7 +54,7 @@ def test_pass_with_initializers(self): # _BIG_TENSOR_SIZE_LIMIT is in bytes, but we create big_dim as size # of a tensor. This is fine as we just need to create a big tensor whose size # passes _BIG_TENSOR_SIZE_LIMIT - big_dim = shape_inference._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access + big_dim = _c_api_utils._BIG_TENSOR_SIZE_LIMIT * 2 # pylint: disable=protected-access inputs = [ ir.Value( name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) @@ -129,22 +132,6 @@ def test_pass_with_initializers(self): ir.DataType.FLOAT, ) - # Check that the original model is not modified - self.assertIsNone(val_add.shape) - self.assertIsNone(val_add.dtype) - self.assertIsNone(val_mul.shape) - self.assertIsNone(val_mul.dtype) - self.assertEqual(len(model.graph.inputs), 2) - self.assertEqual(len(model.graph.initializers), 2) - self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value) - self.assertEqual(len(model.graph.outputs), 1) - self.assertEqual(model.graph.outputs[0].shape, None) - self.assertEqual(model.graph.outputs[0].dtype, None) - # Check that the initializer is not modified - self.assertIs( - model.graph.initializers["initializer"].const_value, initializer.const_value - ) - if __name__ == "__main__": unittest.main() From f1cfcef17bff5bd84c5280cdfec30153c9f8bca0 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 17:01:09 -0700 Subject: [PATCH 08/15] Update --- onnxscript/ir/passes/common/_c_api_utils.py | 25 +++++++----------- .../common/{debugging.py => onnx_checker.py} | 7 +---- .../ir/passes/common/shape_inference.py | 26 ++++++++++--------- 3 files changed, 24 insertions(+), 34 deletions(-) rename onnxscript/ir/passes/common/{debugging.py => onnx_checker.py} (84%) diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py index e76731ac90..c45b8ab561 100644 --- a/onnxscript/ir/passes/common/_c_api_utils.py +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -20,28 +20,26 @@ def call_onnx_api( - func: Callable[[onnx.ModelProto], onnx.ModelProto], - model: ir.Model, - merge_func: Callable[[ir.Model, onnx.ModelProto], tuple[ir.Model, bool]], -) -> tuple[ir.Model, bool]: + func: Callable[[onnx.ModelProto], onnx.ModelProto], model: ir.Model +) -> onnx.ModelProto: """Call an ONNX C API function by temporarily removing initializers. + This is necessary because the ONNX C API does not support large models - with initializers that have large tensor values. + with initializers that have large tensor values. The input model is left + unchanged no matter the call succeeds or not. Args: func: Partially applied function that takes a model proto and returns a model proto. model: The IR model to pass to the API function. - merge_func: Function that merges IR model with information from the model proto. Returns: - A tuple containing the modified model and a boolean indicating whether the model was modified. + The resulting ModelProto that contains the result of the API call. """ # Store the original initializer values so they can be restored initializer_values = tuple(model.graph.initializers.values()) tensors = {v.name: v.const_value for v in initializer_values} original_inputs_len = len(model.graph.inputs) - initializer_names = {v.name for v in initializer_values} # Turn the initializers into inputs and clear the initializers # to limit the model size @@ -64,20 +62,15 @@ def call_onnx_api( try: proto = ir.serde.serialize_model(model) result_proto = func(proto) - except Exception: # pylint: disable=broad-exception-caught - logger.warning("Call to %s failed. The model is not modified", func, exc_info=True) - return (model, False) finally: # Restore the original initializer values so the model is unchanged for initializer in initializer_values: - if initializer.name in initializer_names: - initializer.const_value = tensors[initializer.name] - model.graph.register_initializer(initializer) + initializer.const_value = tensors[initializer.name] + model.graph.register_initializer(initializer) # Restore the original inputs inputs = model.graph.inputs[:original_inputs_len] model.graph.inputs.clear() model.graph.inputs.extend(inputs) - # Merge the result with the original model - return merge_func(model, result_proto) + return result_proto diff --git a/onnxscript/ir/passes/common/debugging.py b/onnxscript/ir/passes/common/onnx_checker.py similarity index 84% rename from onnxscript/ir/passes/common/debugging.py rename to onnxscript/ir/passes/common/onnx_checker.py index 24c65db623..b90aa586cc 100644 --- a/onnxscript/ir/passes/common/debugging.py +++ b/onnxscript/ir/passes/common/onnx_checker.py @@ -49,11 +49,6 @@ def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: ) return proto - _c_api_utils.call_onnx_api( - func=_partial_check_model, - model=model, - # Since we do not modify the model. merge_func is not used but provided for completeness - merge_func=lambda m, proto: (m, False), - ) + _c_api_utils.call_onnx_api(func=_partial_check_model, model=model) # The model is not modified return ir.passes.PassResult(model, False) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index 65c409f2f3..0308a1c1ed 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -20,12 +20,12 @@ logger = logging.getLogger(__name__) -def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Model, bool]: - """Merge the inferred model with the original model. +def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool: + """Merge the shape inferred model with the original model. Args: model: The original IR model. - inferred_proto: The inferred ONNX model. + inferred_proto: The ONNX model with shapes and types inferred. Returns: A tuple containing the modified model and a boolean indicating whether the model was modified. @@ -48,7 +48,7 @@ def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> tuple[ir.Mo logger.warning( "Value %s not found in inferred graph %s", name, inferred_graph.name ) - return model, modified + return modified class ShapeInferencePass(ir.passes.InPlacePass): @@ -59,6 +59,8 @@ def __init__( ) -> None: """Initialize the shape inference pass. + If inference fails, the model is left unchanged. + Args: check_type: If True, check the types of the inputs and outputs. strict_mode: If True, use strict mode for shape inference. @@ -76,14 +78,14 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: strict_mode=self.strict_mode, data_prop=self.data_prop, ) - - inferred_model, modified = _c_api_utils.call_onnx_api( - onnx_infer_shapes, - model, - merge_func=_merge_func, - ) - - return ir.passes.PassResult(inferred_model, modified=modified) + try: + inferred_model_proto = _c_api_utils.call_onnx_api(onnx_infer_shapes, model) + except Exception as e: + logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e) + return ir.passes.PassResult(model, False) + + modified = _merge_func(model, inferred_model_proto) + return ir.passes.PassResult(model, modified=modified) def infer_shapes( From aad848c4da53f05c61957f24c22ad3956f2fd3a3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 17:07:53 -0700 Subject: [PATCH 09/15] Add tests --- .../ir/passes/common/onnx_checker_test.py | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 onnxscript/ir/passes/common/onnx_checker_test.py diff --git a/onnxscript/ir/passes/common/onnx_checker_test.py b/onnxscript/ir/passes/common/onnx_checker_test.py new file mode 100644 index 0000000000..313e687f87 --- /dev/null +++ b/onnxscript/ir/passes/common/onnx_checker_test.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest + +import numpy as np + +from onnxscript import ir +from onnxscript.ir.passes.common import onnx_checker + + +class TestCheckerPass(unittest.TestCase): + def test_pass_is_no_op(self): + checker_pass = onnx_checker.CheckerPass() + self.assertTrue(checker_pass.in_place) + self.assertFalse(checker_pass.changes_input) + + def test_check_simple_model(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ] + + tape = ir.tape.Tape() + + output = tape.op("Add", inputs=inputs) + output.shape = ir.Shape((1, 2)) + output.dtype = ir.DataType.FLOAT + + model = ir.Model( + ir.Graph( + inputs=inputs, + outputs=[output], + nodes=tape.nodes, + opset_imports={"": 20}, + name="test_model", + ), + ir_version=10, + ) + # No exception should be raised + onnx_checker.CheckerPass()(model) + + def test_check_invalid_model(self): + inputs = [ + ir.Value( + name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ir.Value( + name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)) + ), + ] + + tape = ir.tape.Tape() + + output = tape.op("Add", inputs=inputs) + output.shape = ir.Shape((1, 2)) + output.dtype = ir.DataType.FLOAT + + model = ir.Model( + ir.Graph( + inputs=inputs, + outputs=[output], + nodes=tape.nodes, + opset_imports={"": 20}, + ), + ir_version=10, + ) + + with self.assertRaisesRegex(Exception, "Field 'name' of 'graph' is required to be non-empty"): + onnx_checker.CheckerPass()(model) + + +if __name__ == "__main__": + unittest.main() From 17022bb8f568772f6423906d7c30355afc768fbd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 17:08:10 -0700 Subject: [PATCH 10/15] lint --- onnxscript/ir/passes/common/onnx_checker_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxscript/ir/passes/common/onnx_checker_test.py b/onnxscript/ir/passes/common/onnx_checker_test.py index 313e687f87..5a88b67825 100644 --- a/onnxscript/ir/passes/common/onnx_checker_test.py +++ b/onnxscript/ir/passes/common/onnx_checker_test.py @@ -4,8 +4,6 @@ import unittest -import numpy as np - from onnxscript import ir from onnxscript.ir.passes.common import onnx_checker From a056869a005ee91bb53afeafd785e61180ee1dc5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 17:12:11 -0700 Subject: [PATCH 11/15] def --- .../ir/passes/common/onnx_checker_test.py | 4 +++- onnxscript/ir/passes/common/shape_inference.py | 17 +++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/onnxscript/ir/passes/common/onnx_checker_test.py b/onnxscript/ir/passes/common/onnx_checker_test.py index 5a88b67825..144225416d 100644 --- a/onnxscript/ir/passes/common/onnx_checker_test.py +++ b/onnxscript/ir/passes/common/onnx_checker_test.py @@ -69,7 +69,9 @@ def test_check_invalid_model(self): ir_version=10, ) - with self.assertRaisesRegex(Exception, "Field 'name' of 'graph' is required to be non-empty"): + with self.assertRaisesRegex( + Exception, "Field 'name' of 'graph' is required to be non-empty" + ): onnx_checker.CheckerPass()(model) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index 0308a1c1ed..f8ab148097 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -9,7 +9,6 @@ "infer_shapes", ] -import functools import logging import onnx @@ -72,14 +71,16 @@ def __init__( self.data_prop = data_prop def call(self, model: ir.Model) -> ir.passes.PassResult: - onnx_infer_shapes = functools.partial( - onnx.shape_inference.infer_shapes, - check_type=self.check_type, - strict_mode=self.strict_mode, - data_prop=self.data_prop, - ) + def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto: + onnx.shape_inference.infer_shapes( + proto, + check_type=self.check_type, + strict_mode=self.strict_mode, + data_prop=self.data_prop, + ) + try: - inferred_model_proto = _c_api_utils.call_onnx_api(onnx_infer_shapes, model) + inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model) except Exception as e: logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e) return ir.passes.PassResult(model, False) From 6a3cb3c6abbbc268b43f1d65e9c45abb5cce977c Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 17:12:26 -0700 Subject: [PATCH 12/15] return --- onnxscript/ir/passes/common/shape_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index f8ab148097..b3ac4b64ce 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -72,7 +72,7 @@ def __init__( def call(self, model: ir.Model) -> ir.passes.PassResult: def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto: - onnx.shape_inference.infer_shapes( + return onnx.shape_inference.infer_shapes( proto, check_type=self.check_type, strict_mode=self.strict_mode, From a38ddf41be155afa054c041894e6a6ced787449b Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 18:27:26 -0700 Subject: [PATCH 13/15] update typing --- onnxscript/ir/passes/common/_c_api_utils.py | 13 ++++++------- onnxscript/ir/passes/common/onnx_checker.py | 3 +-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py index c45b8ab561..3d5994eb99 100644 --- a/onnxscript/ir/passes/common/_c_api_utils.py +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, TypeVar from onnxscript import ir @@ -17,11 +17,10 @@ # Temporarily remove initializers larger than this size to keep model size down # for the onnx.shape_inference call because it needs to serialize the model _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB +_R = TypeVar("_R") -def call_onnx_api( - func: Callable[[onnx.ModelProto], onnx.ModelProto], model: ir.Model -) -> onnx.ModelProto: +def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: """Call an ONNX C API function by temporarily removing initializers. This is necessary because the ONNX C API does not support large models @@ -29,7 +28,7 @@ def call_onnx_api( unchanged no matter the call succeeds or not. Args: - func: Partially applied function that takes a model proto and returns a model proto. + func: Partially applied function that takes a model proto and returns anything. model: The IR model to pass to the API function. Returns: @@ -61,7 +60,7 @@ def call_onnx_api( try: proto = ir.serde.serialize_model(model) - result_proto = func(proto) + result = func(proto) finally: # Restore the original initializer values so the model is unchanged for initializer in initializer_values: @@ -73,4 +72,4 @@ def call_onnx_api( model.graph.inputs.clear() model.graph.inputs.extend(inputs) - return result_proto + return result diff --git a/onnxscript/ir/passes/common/onnx_checker.py b/onnxscript/ir/passes/common/onnx_checker.py index b90aa586cc..18a5c03c5e 100644 --- a/onnxscript/ir/passes/common/onnx_checker.py +++ b/onnxscript/ir/passes/common/onnx_checker.py @@ -39,7 +39,7 @@ def __init__( def call(self, model: ir.Model) -> ir.passes.PassResult: """Run the onnx checker on the model.""" - def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: + def _partial_check_model(proto: onnx.ModelProto) -> None: """Partial function to check the model.""" onnx.checker.check_model( proto, @@ -47,7 +47,6 @@ def _partial_check_model(proto: onnx.ModelProto) -> onnx.ModelProto: skip_opset_compatibility_check=self.skip_opset_compatibility_check, check_custom_domain=self.check_custom_domain, ) - return proto _c_api_utils.call_onnx_api(func=_partial_check_model, model=model) # The model is not modified From 076233cf2b415010820be4595c13f17b629ec1a1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 17 Apr 2025 14:15:42 -0700 Subject: [PATCH 14/15] lint --- onnxscript/ir/passes/common/shape_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py index b3ac4b64ce..586fa5b417 100644 --- a/onnxscript/ir/passes/common/shape_inference.py +++ b/onnxscript/ir/passes/common/shape_inference.py @@ -81,7 +81,7 @@ def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto: try: inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e) return ir.passes.PassResult(model, False) From a19f0ada1e1dadcf0f17e7b62d0046c30f8258e1 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 17 Apr 2025 14:22:05 -0700 Subject: [PATCH 15/15] update --- onnxscript/ir/passes/common/_c_api_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py index 3d5994eb99..bb2715c75c 100644 --- a/onnxscript/ir/passes/common/_c_api_utils.py +++ b/onnxscript/ir/passes/common/_c_api_utils.py @@ -58,8 +58,10 @@ def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R: assert initializer.name is not None model.graph.initializers.pop(initializer.name) + proto = ir.serde.serialize_model(model) + try: - proto = ir.serde.serialize_model(model) + # Call the ONNX C API function result = func(proto) finally: # Restore the original initializer values so the model is unchanged