Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@

from __future__ import annotations

import dataclasses

__all__ = ["InlinePass", "InlinePassResult"]

from collections import defaultdict
from typing import Iterable, List, Sequence, Tuple

import onnxscript.ir as ir
import onnxscript.ir.convenience as ir_convenience
import onnxscript.ir.convenience as _ir_convenience
from onnxscript import ir

# A replacement for a node specifies a list of nodes that replaces the original node,
# and a list of values that replaces the original node's outputs.
Expand All @@ -22,7 +26,7 @@
CallStack = List[CallSiteId]


def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument
"""Generate a unique name from a name, calling-context, and set of used names.

If there is a name clash, we add a numeric suffix to the name to make
Expand Down Expand Up @@ -188,6 +192,11 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
return {id: id_abbreviation(id) for id in function_ids}


@dataclasses.dataclass
class InlinePassResult(ir.passes.PassResult):
id_count: dict[ir.OperatorIdentifier, int]


class InlinePass(ir.passes.InPlacePass):
def __init__(self) -> None:
super().__init__()
Expand All @@ -206,11 +215,11 @@ def _reset(self, model: ir.Model) -> None:
self.used_node_names = set()
self.node_context = {}

def call(self, model: ir.Model) -> ir.passes.PassResult:
def call(self, model: ir.Model) -> InlinePassResult:
self._reset(model)
modified = self.inline_calls_in(model.graph)
id_count = self._inline_calls_in(model.graph)
model.functions.clear()
return ir.passes.PassResult(model, modified)
return InlinePassResult(model, modified=bool(id_count), id_count=id_count)

def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
id = node.op_identifier()
Expand All @@ -235,7 +244,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
if default_attr_values:
attributes = {**attributes, **default_attr_values}
if any(
attr.type == ir.AttributeType.GRAPH or attr.type == ir.AttributeType.GRAPHS
attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
for attr in attributes.values()
):
raise ValueError(
Expand Down Expand Up @@ -264,7 +273,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
output_values = [value_map[output] for output in function.outputs]
return nodes, output_values # type: ignore

def inline_calls_in(self, graph: ir.Graph) -> bool:
def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
for input in graph.inputs:
if input.name is not None:
self.used_value_names.add(input.name)
Expand Down Expand Up @@ -300,7 +309,7 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
self._function_id_abbreviations[id] + call_site_prefix
)
nodes, values = self._instantiate_call(node, call_site)
ir_convenience.replace_nodes_and_values(
_ir_convenience.replace_nodes_and_values(
graph,
insertion_point=node,
old_nodes=[node],
Expand All @@ -313,14 +322,8 @@ def inline_calls_in(self, graph: ir.Graph) -> bool:
if not isinstance(attr, ir.Attr):
continue
if attr.type == ir.AttributeType.GRAPH:
self.inline_calls_in(attr.as_graph())
self._inline_calls_in(attr.as_graph())
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.as_graphs():
self.inline_calls_in(graph)
return bool(id_count)


def inline(model: ir.Model) -> None:
"""Inline all function calls (recursively) in the model."""
if model.functions:
InlinePass()(model)
for g in attr.as_graphs():
self._inline_calls_in(g)
return id_count
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Tests for onnxscript.optimizer._inliner"""
"""Tests for the inliner pass."""

from __future__ import annotations

Expand All @@ -11,7 +11,7 @@
from onnx import parser

from onnxscript import ir
from onnxscript.optimizer._inliner import inline
from onnxscript.ir.passes.common import inliner


def _name_checker(renameable: Sequence[str] | None) -> Callable[[str, str], bool]:
Expand Down Expand Up @@ -46,7 +46,7 @@ def _check(
name_check = _name_checker(renameable)
model_proto = parser.parse_model(input_model)
model_ir = ir.serde.deserialize_model(model_proto)
inline(model_ir)
inliner.InlinePass()(model_ir)
proto = ir.serde.serialize_model(model_ir)
text = onnx.printer.to_text(proto)
print(text)
Expand All @@ -68,10 +68,7 @@ def _check(
self.assertTrue(isinstance(value, ir.Attr))
self.assertTrue(isinstance(expected_value, ir.Attr))
self.assertEqual(value.type, expected_value.type)
if (
value.type != ir.AttributeType.GRAPH
and value.type != ir.AttributeType.GRAPHS
):
if value.type not in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
self.assertEqual(value.value, expected_value.value)
else:
self.fail("Graph attributes are not supported yet")
Expand Down
8 changes: 7 additions & 1 deletion onnxscript/optimizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import onnx

import onnxscript.ir.passes.common.inliner
import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer._constant_folding as constant_folding
import onnxscript.optimizer._legacy._optimizer as legacy_optimizer
import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding
from onnxscript import ir
from onnxscript.optimizer._inliner import inline
from onnxscript.optimizer._optimizer import optimize_ir

basic_constant_propagation = constant_folding.basic_constant_propagation
Expand All @@ -35,6 +35,12 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model:
return legacy_optimizer.optimize(model, *args, **kwargs)


def inline(model: ir.Model) -> None:
"""Inline all function calls (recursively) in the model."""
if model.functions:
onnxscript.ir.passes.common.inliner.InlinePass()(model)


def fold_constants(
model: ir.Model | onnx.ModelProto, *args, **kwargs
) -> constant_folding.FoldConstantsResult | bool:
Expand Down
5 changes: 3 additions & 2 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import logging

import onnxscript.ir.passes.common.constant_manipulation
import onnxscript.ir.passes.common.inliner
import onnxscript.ir.passes.common.unused_removal
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
from onnxscript.optimizer import _constant_folding

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -35,7 +36,7 @@ def optimize_ir(
outer optimization loop if no change is detected in one iteration.
"""
optimizer_pass = ir.passes.Sequential(
_inliner.InlinePass(),
onnxscript.ir.passes.common.inliner.InlinePass(),
ir.passes.PassManager(
[
_constant_folding.FoldConstantsPass(
Expand Down
4 changes: 2 additions & 2 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
"convert_version",
]

import onnxscript.optimizer
from onnxscript import ir
from onnxscript.optimizer import _inliner
from onnxscript.version_converter import _version_converter


Expand All @@ -17,5 +17,5 @@ def convert_version(model: ir.Model, target_version: int) -> None:

# In functions, we can have attribute-parameters, which means we don't know the value of the attribute.
# Hence, we inline all the functions.
_inliner.inline(model)
onnxscript.optimizer.inline(model)
_version_converter.convert_version(model, target_version)
Loading