diff --git a/onnxscript/ir/passes/common/_c_api_utils.py b/onnxscript/ir/passes/common/_c_api_utils.py
new file mode 100644
index 0000000000..bb2715c75c
--- /dev/null
+++ b/onnxscript/ir/passes/common/_c_api_utils.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Utilities for interfacing with onnx C APIs."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Callable, TypeVar
+
+from onnxscript import ir
+
+if TYPE_CHECKING:
+    import onnx
+
+
+logger = logging.getLogger(__name__)
+# Temporarily remove initializers larger than this size to keep model size down
+# for the ONNX C API calls because the model needs to be serialized for them
+_BIG_TENSOR_SIZE_LIMIT = 1000  # 1KB
+_R = TypeVar("_R")
+
+
+def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
+    """Call an ONNX C API function by temporarily removing initializers.
+
+    This is necessary because the ONNX C API does not support large models
+    with initializers that have large tensor values. The initializers and
+    graph inputs of the model are restored whether or not the call succeeds,
+    including when serializing the model fails. Initializers that have no
+    shape or dtype get them filled in from their tensor values and keep them.
+
+    Args:
+        func: Partially applied function that takes a model proto and returns anything.
+        model: The IR model to pass to the API function.
+
+    Returns:
+        The value returned by ``func``.
+    """
+
+    # Store the original initializer values so they can be restored
+    initializer_values = tuple(model.graph.initializers.values())
+    tensors = {v.name: v.const_value for v in initializer_values}
+    original_inputs_len = len(model.graph.inputs)
+
+    # Everything that mutates the model lives inside the try block so that the
+    # finally clause restores it even when the setup or serialization raises
+    try:
+        # Turn the initializers into inputs and clear the initializers
+        # to limit the model size
+        for initializer in initializer_values:
+            # Make sure the initializer has its shape/type set
+            assert initializer.const_value is not None
+            if initializer.shape is None:
+                initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
+            if initializer.dtype is None:
+                initializer.dtype = initializer.const_value.dtype
+            if initializer not in model.graph.inputs:
+                model.graph.inputs.append(initializer)
+            if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
+                # Temporarily remove the initializer value to reduce the
+                # size of the serialized model
+                initializer.const_value = None
+                assert initializer.name is not None
+                model.graph.initializers.pop(initializer.name)
+
+        proto = ir.serde.serialize_model(model)
+
+        # Call the ONNX C API function
+        return func(proto)
+    finally:
+        # Restore the original initializer values so the model is unchanged
+        for initializer in initializer_values:
+            initializer.const_value = tensors[initializer.name]
+            # Only the initializers popped above need to be registered again
+            if initializer.name not in model.graph.initializers:
+                model.graph.register_initializer(initializer)
+
+        # Restore the original inputs
+        inputs = model.graph.inputs[:original_inputs_len]
+        model.graph.inputs.clear()
+        model.graph.inputs.extend(inputs)
diff --git a/onnxscript/ir/passes/common/onnx_checker.py b/onnxscript/ir/passes/common/onnx_checker.py
new file mode 100644
index 0000000000..18a5c03c5e
--- /dev/null
+++ b/onnxscript/ir/passes/common/onnx_checker.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+"""Passes for debugging purposes."""
+
+from __future__ import annotations
+
+__all__ = [
+    "CheckerPass",
+]
+
+import onnx
+
+from onnxscript import ir
+from onnxscript.ir.passes.common import _c_api_utils
+
+
+class CheckerPass(ir.passes.PassBase):
+    """Pass that validates a model with ``onnx.checker.check_model``."""
+
+    def __init__(
+        self,
+        full_check: bool = False,
+        skip_opset_compatibility_check: bool = False,
+        check_custom_domain: bool = False,
+    ):
+        super().__init__()
+        self.full_check = full_check
+        self.skip_opset_compatibility_check = skip_opset_compatibility_check
+        self.check_custom_domain = check_custom_domain
+
+    @property
+    def in_place(self) -> bool:
+        return True
+
+    @property
+    def changes_input(self) -> bool:
+        return False
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        """Check ``model`` and return it unmodified; checker errors propagate."""
+
+        def _check(proto: onnx.ModelProto) -> None:
+            """Forward the configured options to the onnx checker."""
+            onnx.checker.check_model(
+                proto,
+                full_check=self.full_check,
+                skip_opset_compatibility_check=self.skip_opset_compatibility_check,
+                check_custom_domain=self.check_custom_domain,
+            )
+
+        _c_api_utils.call_onnx_api(_check, model)
+        # Checking never alters the model, so report it as unmodified
+        return ir.passes.PassResult(model, modified=False)
diff --git a/onnxscript/ir/passes/common/onnx_checker_test.py b/onnxscript/ir/passes/common/onnx_checker_test.py
new file mode 100644
index 0000000000..144225416d
--- /dev/null
+++ b/onnxscript/ir/passes/common/onnx_checker_test.py
@@ -0,0 +1,79 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+from onnxscript import ir
+from onnxscript.ir.passes.common import onnx_checker
+
+
+class TestCheckerPass(unittest.TestCase):
+    """Tests for ``onnx_checker.CheckerPass``."""
+
+    @staticmethod
+    def _float_inputs():
+        """Create the two FLOAT inputs of shape (1, 2) that feed the Add node."""
+        return [
+            ir.Value(name=name, type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2)))
+            for name in ("input_a", "input_b")
+        ]
+
+    @staticmethod
+    def _add_output(recorder, graph_inputs):
+        """Record an Add node over ``graph_inputs`` and annotate its output."""
+        summed = recorder.op("Add", inputs=graph_inputs)
+        summed.shape = ir.Shape((1, 2))
+        summed.dtype = ir.DataType.FLOAT
+        return summed
+
+    def test_pass_is_no_op(self):
+        """The checker runs in place and does not change its input."""
+        pass_under_test = onnx_checker.CheckerPass()
+        self.assertTrue(pass_under_test.in_place)
+        self.assertFalse(pass_under_test.changes_input)
+
+    def test_check_simple_model(self):
+        """A well-formed model passes the checker."""
+        graph_inputs = self._float_inputs()
+        recorder = ir.tape.Tape()
+        summed = self._add_output(recorder, graph_inputs)
+
+        model = ir.Model(
+            ir.Graph(
+                inputs=graph_inputs,
+                outputs=[summed],
+                nodes=recorder.nodes,
+                opset_imports={"": 20},
+                name="test_model",
+            ),
+            ir_version=10,
+        )
+        # No exception should be raised
+        onnx_checker.CheckerPass()(model)
+
+    def test_check_invalid_model(self):
+        """A graph without a name is rejected by the checker."""
+        graph_inputs = self._float_inputs()
+        recorder = ir.tape.Tape()
+        summed = self._add_output(recorder, graph_inputs)
+
+        # The graph name is deliberately omitted to trigger the checker error
+        model = ir.Model(
+            ir.Graph(
+                inputs=graph_inputs,
+                outputs=[summed],
+                nodes=recorder.nodes,
+                opset_imports={"": 20},
+            ),
+            ir_version=10,
+        )
+
+        with self.assertRaisesRegex(
+            Exception, "Field 'name' of 'graph' is required to be non-empty"
+        ):
+            onnx_checker.CheckerPass()(model)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/ir/passes/common/shape_inference.py b/onnxscript/ir/passes/common/shape_inference.py
index f6d88584e7..586fa5b417 100644
--- a/onnxscript/ir/passes/common/shape_inference.py
+++ b/onnxscript/ir/passes/common/shape_inference.py
@@ -14,15 +14,43 @@
 import onnx
 
 from onnxscript import
ir
+from onnxscript.ir.passes.common import _c_api_utils
 
 logger = logging.getLogger(__name__)
 
-# Temporarily remove initializers larger than this size to keep model size down
-# for the onnx.shape_inference call because it needs to serialize the model
-_BIG_TENSOR_SIZE_LIMIT = 1000  # 1KB
 
+def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
+    """Merge the shape inferred model with the original model.
 
-class ShapeInferencePass(ir.passes.FunctionalPass):
+    Args:
+        model: The original IR model.
+        inferred_proto: The ONNX model with shapes and types inferred.
+
+    Returns:
+        True if a shape or dtype of any value in the original model was updated.
+    """
+    inferred_model = ir.serde.deserialize_model(inferred_proto)
+    modified = False
+    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
+        original_values = ir.convenience.create_value_mapping(original_graph)
+        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
+        for name, value in original_values.items():
+            if name in inferred_values:
+                inferred_value = inferred_values[name]
+                if value.shape != inferred_value.shape and inferred_value.shape is not None:
+                    value.shape = inferred_value.shape
+                    modified = True
+                if value.dtype != inferred_value.dtype and inferred_value.dtype is not None:
+                    value.dtype = inferred_value.dtype
+                    modified = True
+            else:
+                logger.warning(
+                    "Value %s not found in inferred graph %s", name, inferred_graph.name
+                )
+    return modified
+
+
+class ShapeInferencePass(ir.passes.InPlacePass):
     """This pass performs shape inference on the graph."""
 
     def __init__(
@@ -30,6 +58,8 @@ def __init__(
     ) -> None:
         """Initialize the shape inference pass.
 
+        If inference fails, the model is left unchanged.
+
         Args:
             check_type: If True, check the types of the inputs and outputs.
             strict_mode: If True, use strict mode for shape inference.
@@ -41,75 +71,27 @@ def __init__(
         self.data_prop = data_prop
 
     def call(self, model: ir.Model) -> ir.passes.PassResult:
-        # Store the original initializer values so they can be restored
-        initializer_values = tuple(model.graph.initializers.values())
-        tensors = {v.name: v.const_value for v in initializer_values}
-        original_inputs_len = len(model.graph.inputs)
-        initializer_names = {v.name for v in initializer_values}
-
-        # Turn the initializers into inputs and clear the initializers
-        # to limit the model size
-        for initializer in initializer_values:
-            # Make sure the initializer has its shape/type set
-            assert initializer.const_value is not None
-            if initializer.shape is None:
-                initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
-            if initializer.dtype is None:
-                initializer.dtype = initializer.const_value.dtype
-            if initializer not in model.graph.inputs:
-                model.graph.inputs.append(initializer)
-            if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
-                # Temporarily remove the initializer value to reduce model size
-                # for onnx.shape_inference
-                initializer.const_value = None
-                assert initializer.name is not None
-                model.graph.initializers.pop(initializer.name)
-
-        # Perform shape inference
-        try:
-            proto = ir.serde.serialize_model(model)
-            value_infos = {info.name: info for info in proto.graph.value_info}
-            inferred_proto = onnx.shape_inference.infer_shapes(
+        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
+            """Run onnx shape inference on ``proto`` with this pass's options."""
+            return onnx.shape_inference.infer_shapes(
                 proto,
                 check_type=self.check_type,
                 strict_mode=self.strict_mode,
                 data_prop=self.data_prop,
             )
-            inferred_value_infos = {
-                info.name: info for info in inferred_proto.graph.value_info
-            }
-            inferred_model = ir.serde.deserialize_model(inferred_proto)
-
-        except Exception:  # pylint: disable=broad-exception-caught
-            logger.warning("Shape inference failed. The model is not modified", exc_info=True)
-            return ir.passes.PassResult(model, modified=False)
-        finally:
-            # Restore the original initializer values so the model is unchanged
-            for initializer in initializer_values:
-                if initializer.name in initializer_names:
-                    initializer.const_value = tensors[initializer.name]
-                    model.graph.register_initializer(initializer)
-
-            # Restore the original inputs
-            inputs = model.graph.inputs[:original_inputs_len]
-            model.graph.inputs.clear()
-            model.graph.inputs.extend(inputs)
-
-        # Add the original initializer tensors to the new (inferred) model
-        for new_input in inferred_model.graph.inputs:
-            # Assign the tensors back to the initializers
-            if new_input.name in initializer_names:
-                new_input.const_value = tensors[new_input.name]
-                inferred_model.graph.register_initializer(new_input)
-
-        # Remove the inputs that were added
-        new_inputs = inferred_model.graph.inputs[:original_inputs_len]
-        inferred_model.graph.inputs.clear()
-        inferred_model.graph.inputs.extend(new_inputs)
-
-        return ir.passes.PassResult(
-            inferred_model, modified=value_infos != inferred_value_infos
-        )
+
+        # Inference is best-effort: a failure is logged and the model returned as is
+        try:
+            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            logger.warning(
+                "Shape inference failed: %s. Model is left unchanged", e, exc_info=True
+            )
+            return ir.passes.PassResult(model, False)
+
+        # Copy the inferred shapes and dtypes onto the values of the original model
+        modified = _merge_func(model, inferred_model_proto)
+        return ir.passes.PassResult(model, modified=modified)
 
 
 def infer_shapes(
diff --git a/onnxscript/ir/passes/common/shape_inference_test.py b/onnxscript/ir/passes/common/shape_inference_test.py
index da67b4c1a7..5a2f02c64e 100644
--- a/onnxscript/ir/passes/common/shape_inference_test.py
+++ b/onnxscript/ir/passes/common/shape_inference_test.py
@@ -7,10 +7,13 @@
 import numpy as np
 
 from onnxscript import ir
-from onnxscript.ir.passes.common import shape_inference
+from onnxscript.ir.passes.common import _c_api_utils, shape_inference
 
 
 class TestShapeInferencePass(unittest.TestCase):
+    def test_pass_is_in_place(self):
+        self.assertTrue(shape_inference.ShapeInferencePass().in_place)
+
     def test_pass(self):
         # Create a simple ONNX model with shape inference
         # Define the model
@@ -51,7 +54,7 @@ def test_pass_with_initializers(self):
         # _BIG_TENSOR_SIZE_LIMIT is in bytes, but we create big_dim as size
         # of a tensor. This is fine as we just need to create a big tensor whose size
         # passes _BIG_TENSOR_SIZE_LIMIT
-        big_dim = shape_inference._BIG_TENSOR_SIZE_LIMIT * 2  # pylint: disable=protected-access
+        big_dim = _c_api_utils._BIG_TENSOR_SIZE_LIMIT * 2  # pylint: disable=protected-access
         inputs = [
             ir.Value(
                 name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
@@ -129,22 +132,6 @@ def test_pass_with_initializers(self):
             ir.DataType.FLOAT,
         )
 
-        # Check that the original model is not modified
-        self.assertIsNone(val_add.shape)
-        self.assertIsNone(val_add.dtype)
-        self.assertIsNone(val_mul.shape)
-        self.assertIsNone(val_mul.dtype)
-        self.assertEqual(len(model.graph.inputs), 2)
-        self.assertEqual(len(model.graph.initializers), 2)
-        self.assertIs(model.graph.initializers["input_b"].const_value, inputs[1].const_value)
-        self.assertEqual(len(model.graph.outputs), 1)
-        self.assertEqual(model.graph.outputs[0].shape, None)
-        self.assertEqual(model.graph.outputs[0].dtype, None)
-        # Check that the initializer is not modified
-        self.assertIs(
-            model.graph.initializers["initializer"].const_value, initializer.const_value
-        )
-
 
 if __name__ == "__main__":
     unittest.main()