Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions onnxscript/ir/passes/common/_c_api_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Utilities for interfacing with onnx C APIs."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable, TypeVar

from onnxscript import ir

if TYPE_CHECKING:
import onnx

Check warning on line 13 in onnxscript/ir/passes/common/_c_api_utils.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/_c_api_utils.py#L13

Added line #L13 was not covered by tests


logger = logging.getLogger(__name__)
# Temporarily remove initializers larger than this size to keep model size down
# for the onnx.shape_inference call because it needs to serialize the model
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
_R = TypeVar("_R")


def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
"""Call an ONNX C API function by temporarily removing initializers.

This is necessary because the ONNX C API does not support large models
with initializers that have large tensor values. The input model is left
unchanged no matter the call succeeds or not.

Args:
func: Partially applied function that takes a model proto and returns anything.
model: The IR model to pass to the API function.

Returns:
The resulting ModelProto that contains the result of the API call.
"""

# Store the original initializer values so they can be restored
initializer_values = tuple(model.graph.initializers.values())
tensors = {v.name: v.const_value for v in initializer_values}
original_inputs_len = len(model.graph.inputs)

# Turn the initializers into inputs and clear the initializers
# to limit the model size
for initializer in initializer_values:
# Make sure the initializer has its shape/type set
assert initializer.const_value is not None
if initializer.shape is None:
initializer.shape = initializer.const_value.shape # type: ignore[assignment]
if initializer.dtype is None:
initializer.dtype = initializer.const_value.dtype
if initializer not in model.graph.inputs:
model.graph.inputs.append(initializer)
if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
# Temporarily remove the initializer value to reduce model size
# for onnx.shape_inference
initializer.const_value = None
assert initializer.name is not None
model.graph.initializers.pop(initializer.name)

proto = ir.serde.serialize_model(model)

try:
# Call the ONNX C API function
result = func(proto)
finally:
# Restore the original initializer values so the model is unchanged
for initializer in initializer_values:
initializer.const_value = tensors[initializer.name]
model.graph.register_initializer(initializer)

# Restore the original inputs
inputs = model.graph.inputs[:original_inputs_len]
model.graph.inputs.clear()
model.graph.inputs.extend(inputs)

return result
53 changes: 53 additions & 0 deletions onnxscript/ir/passes/common/onnx_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Passes for debugging purposes."""

from __future__ import annotations

__all__ = [
"CheckerPass",
]

import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils


class CheckerPass(ir.passes.PassBase):
"""Run onnx checker on the model."""

@property
def in_place(self) -> bool:
return True

@property
def changes_input(self) -> bool:
return False

def __init__(
self,
full_check: bool = False,
skip_opset_compatibility_check: bool = False,
check_custom_domain: bool = False,
):
super().__init__()
self.full_check = full_check
self.skip_opset_compatibility_check = skip_opset_compatibility_check
self.check_custom_domain = check_custom_domain

def call(self, model: ir.Model) -> ir.passes.PassResult:
"""Run the onnx checker on the model."""

def _partial_check_model(proto: onnx.ModelProto) -> None:
"""Partial function to check the model."""
onnx.checker.check_model(
proto,
full_check=self.full_check,
skip_opset_compatibility_check=self.skip_opset_compatibility_check,
check_custom_domain=self.check_custom_domain,
)

_c_api_utils.call_onnx_api(func=_partial_check_model, model=model)
# The model is not modified
return ir.passes.PassResult(model, False)
79 changes: 79 additions & 0 deletions onnxscript/ir/passes/common/onnx_checker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

from onnxscript import ir
from onnxscript.ir.passes.common import onnx_checker


class TestCheckerPass(unittest.TestCase):
def test_pass_is_no_op(self):
checker_pass = onnx_checker.CheckerPass()
self.assertTrue(checker_pass.in_place)
self.assertFalse(checker_pass.changes_input)

def test_check_simple_model(self):
inputs = [
ir.Value(
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
),
ir.Value(
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
),
]

tape = ir.tape.Tape()

output = tape.op("Add", inputs=inputs)
output.shape = ir.Shape((1, 2))
output.dtype = ir.DataType.FLOAT

model = ir.Model(
ir.Graph(
inputs=inputs,
outputs=[output],
nodes=tape.nodes,
opset_imports={"": 20},
name="test_model",
),
ir_version=10,
)
# No exception should be raised
onnx_checker.CheckerPass()(model)

def test_check_invalid_model(self):
inputs = [
ir.Value(
name="input_a", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
),
ir.Value(
name="input_b", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape((1, 2))
),
]

tape = ir.tape.Tape()

output = tape.op("Add", inputs=inputs)
output.shape = ir.Shape((1, 2))
output.dtype = ir.DataType.FLOAT

model = ir.Model(
ir.Graph(
inputs=inputs,
outputs=[output],
nodes=tape.nodes,
opset_imports={"": 20},
),
ir_version=10,
)

with self.assertRaisesRegex(
Exception, "Field 'name' of 'graph' is required to be non-empty"
):
onnx_checker.CheckerPass()(model)


if __name__ == "__main__":
unittest.main()

Check warning on line 79 in onnxscript/ir/passes/common/onnx_checker_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/onnx_checker_test.py#L79

Added line #L79 was not covered by tests
113 changes: 45 additions & 68 deletions onnxscript/ir/passes/common/shape_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,52 @@
import onnx

from onnxscript import ir
from onnxscript.ir.passes.common import _c_api_utils

logger = logging.getLogger(__name__)

# Temporarily remove initializers larger than this size to keep model size down
# for the onnx.shape_inference call because it needs to serialize the model
_BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB

def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
"""Merge the shape inferred model with the original model.

class ShapeInferencePass(ir.passes.FunctionalPass):
Args:
model: The original IR model.
inferred_proto: The ONNX model with shapes and types inferred.

Returns:
A tuple containing the modified model and a boolean indicating whether the model was modified.
"""
inferred_model = ir.serde.deserialize_model(inferred_proto)
modified = False
for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
original_values = ir.convenience.create_value_mapping(original_graph)
inferred_values = ir.convenience.create_value_mapping(inferred_graph)
for name, value in original_values.items():
if name in inferred_values:
inferred_value = inferred_values[name]
if value.shape != inferred_value.shape and inferred_value.shape is not None:
value.shape = inferred_value.shape
modified = True
if value.dtype != inferred_value.dtype and inferred_value.dtype is not None:
value.dtype = inferred_value.dtype
modified = True
else:
logger.warning(

Check warning on line 47 in onnxscript/ir/passes/common/shape_inference.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/shape_inference.py#L47

Added line #L47 was not covered by tests
"Value %s not found in inferred graph %s", name, inferred_graph.name
)
return modified


class ShapeInferencePass(ir.passes.InPlacePass):
"""This pass performs shape inference on the graph."""

def __init__(
self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
) -> None:
"""Initialize the shape inference pass.

If inference fails, the model is left unchanged.

Args:
check_type: If True, check the types of the inputs and outputs.
strict_mode: If True, use strict mode for shape inference.
Expand All @@ -41,75 +71,22 @@
self.data_prop = data_prop

def call(self, model: ir.Model) -> ir.passes.PassResult:
# Store the original initializer values so they can be restored
initializer_values = tuple(model.graph.initializers.values())
tensors = {v.name: v.const_value for v in initializer_values}
original_inputs_len = len(model.graph.inputs)
initializer_names = {v.name for v in initializer_values}

# Turn the initializers into inputs and clear the initializers
# to limit the model size
for initializer in initializer_values:
# Make sure the initializer has its shape/type set
assert initializer.const_value is not None
if initializer.shape is None:
initializer.shape = initializer.const_value.shape # type: ignore[assignment]
if initializer.dtype is None:
initializer.dtype = initializer.const_value.dtype
if initializer not in model.graph.inputs:
model.graph.inputs.append(initializer)
if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
# Temporarily remove the initializer value to reduce model size
# for onnx.shape_inference
initializer.const_value = None
assert initializer.name is not None
model.graph.initializers.pop(initializer.name)

# Perform shape inference
try:
proto = ir.serde.serialize_model(model)
value_infos = {info.name: info for info in proto.graph.value_info}
inferred_proto = onnx.shape_inference.infer_shapes(
def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
return onnx.shape_inference.infer_shapes(
proto,
check_type=self.check_type,
strict_mode=self.strict_mode,
data_prop=self.data_prop,
)
inferred_value_infos = {
info.name: info for info in inferred_proto.graph.value_info
}
inferred_model = ir.serde.deserialize_model(inferred_proto)

except Exception: # pylint: disable=broad-exception-caught
logger.warning("Shape inference failed. The model is not modified", exc_info=True)
return ir.passes.PassResult(model, modified=False)
finally:
# Restore the original initializer values so the model is unchanged
for initializer in initializer_values:
if initializer.name in initializer_names:
initializer.const_value = tensors[initializer.name]
model.graph.register_initializer(initializer)

# Restore the original inputs
inputs = model.graph.inputs[:original_inputs_len]
model.graph.inputs.clear()
model.graph.inputs.extend(inputs)

# Add the original initializer tensors to the new (inferred) model
for new_input in inferred_model.graph.inputs:
# Assign the tensors back to the initializers
if new_input.name in initializer_names:
new_input.const_value = tensors[new_input.name]
inferred_model.graph.register_initializer(new_input)

# Remove the inputs that were added
new_inputs = inferred_model.graph.inputs[:original_inputs_len]
inferred_model.graph.inputs.clear()
inferred_model.graph.inputs.extend(new_inputs)

return ir.passes.PassResult(
inferred_model, modified=value_infos != inferred_value_infos
)

try:
inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
except Exception as e: # pylint: disable=broad-exception-caught
logger.warning("Shape inference failed: %s. Model is left unchanged", exc_info=e)
return ir.passes.PassResult(model, False)

Check warning on line 86 in onnxscript/ir/passes/common/shape_inference.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/ir/passes/common/shape_inference.py#L84-L86

Added lines #L84 - L86 were not covered by tests

modified = _merge_func(model, inferred_model_proto)
return ir.passes.PassResult(model, modified=modified)


def infer_shapes(
Expand Down
Loading
Loading