From f9a464d4765be810b7daf83ab837de3d2bf867d5 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Apr 2025 17:27:31 -0700 Subject: [PATCH 1/6] [passes] Move inliner to common passes --- .../passes/common/inliner.py} | 27 ++++++++++--------- .../passes/common/inliner_test.py} | 0 onnxscript/optimizer/__init__.py | 8 +++++- 3 files changed, 22 insertions(+), 13 deletions(-) rename onnxscript/{optimizer/_inliner.py => ir/passes/common/inliner.py} (96%) rename onnxscript/{optimizer/_inliner_test.py => ir/passes/common/inliner_test.py} (100%) diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/ir/passes/common/inliner.py similarity index 96% rename from onnxscript/optimizer/_inliner.py rename to onnxscript/ir/passes/common/inliner.py index ac9bf71010..5535073c0d 100644 --- a/onnxscript/optimizer/_inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -3,6 +3,9 @@ """Implementation of an inliner for onnxscript.ir""" from __future__ import annotations +import dataclasses + +__all__ = ["InlinePass"] from collections import defaultdict from typing import Iterable, List, Sequence, Tuple @@ -188,6 +191,12 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str: return {id: id_abbreviation(id) for id in function_ids} +@dataclasses.dataclass +class InlinePassResult(ir.passes.PassResult): + id_count: dict[ir.OperatorIdentifier, int] + + + class InlinePass(ir.passes.InPlacePass): def __init__(self) -> None: super().__init__() @@ -208,9 +217,9 @@ def _reset(self, model: ir.Model) -> None: def call(self, model: ir.Model) -> ir.passes.PassResult: self._reset(model) - modified = self.inline_calls_in(model.graph) + id_count = self._inline_calls_in(model.graph) model.functions.clear() - return ir.passes.PassResult(model, modified) + return InlinePassResult(model, modified=bool(id_count), id_count=id_count) def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement: id = node.op_identifier() @@ -264,7 +273,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl output_values = [value_map[output] for output in function.outputs] return nodes, output_values # type: ignore - def inline_calls_in(self, graph: ir.Graph) -> bool: + def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]: for input in graph.inputs: if input.name is not None: self.used_value_names.add(input.name) @@ -313,14 +322,8 @@ def inline_calls_in(self, graph: ir.Graph) -> bool: if not isinstance(attr, ir.Attr): continue if attr.type == ir.AttributeType.GRAPH: - self.inline_calls_in(attr.as_graph()) + self._inline_calls_in(attr.as_graph()) elif attr.type == ir.AttributeType.GRAPHS: for graph in attr.as_graphs(): - self.inline_calls_in(graph) - return bool(id_count) - - -def inline(model: ir.Model) -> None: - """Inline all function calls (recursively) in the model.""" - if model.functions: - InlinePass()(model) + self._inline_calls_in(graph) + return id_count diff --git a/onnxscript/optimizer/_inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py similarity index 100% rename from onnxscript/optimizer/_inliner_test.py rename to onnxscript/ir/passes/common/inliner_test.py diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 3b25d2d3ee..3ea759db6c 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -15,11 +15,11 @@ import onnx import onnxscript.ir.passes.common.unused_removal +import onnxscript.ir.passes.common.inliner import onnxscript.optimizer._constant_folding as constant_folding import onnxscript.optimizer._legacy._optimizer as legacy_optimizer import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding from onnxscript import ir -from onnxscript.optimizer._inliner import inline from onnxscript.optimizer._optimizer import optimize_ir basic_constant_propagation = constant_folding.basic_constant_propagation @@ -35,6 +35,12 @@ def optimize(model: ir.Model, *args, **kwargs) -> ir.Model: return legacy_optimizer.optimize(model, *args, **kwargs) +def inline(model: ir.Model) -> None: + """Inline all function calls (recursively) in the model.""" + if model.functions: + onnxscript.ir.passes.common.inliner.InlinePass()(model) + + def fold_constants( model: ir.Model | onnx.ModelProto, *args, **kwargs ) -> constant_folding.FoldConstantsResult | bool: From 6f2116158966de102a1eb60fcdae6616fddf968e Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Apr 2025 17:29:13 -0700 Subject: [PATCH 2/6] rename --- onnxscript/ir/passes/common/inliner.py | 2 +- onnxscript/ir/passes/common/inliner_test.py | 4 ++-- onnxscript/optimizer/__init__.py | 2 +- onnxscript/optimizer/_optimizer.py | 5 +++-- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 5535073c0d..4c834abe38 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -3,6 +3,7 @@ """Implementation of an inliner for onnxscript.ir""" from __future__ import annotations + import dataclasses __all__ = ["InlinePass"] @@ -196,7 +197,6 @@ class InlinePassResult(ir.passes.PassResult): id_count: dict[ir.OperatorIdentifier, int] - class InlinePass(ir.passes.InPlacePass): def __init__(self) -> None: super().__init__() diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index e7e3bbadc1..8c4515fb7e 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -11,7 +11,7 @@ from onnx import parser from onnxscript import ir -from onnxscript.optimizer._inliner import inline +from onnxscript.ir.passes.common import inliner def _name_checker(renameable: Sequence[str] | None) -> Callable[[str, str], bool]: @@ -46,7 +46,7 @@ def _check( name_check = _name_checker(renameable) model_proto = parser.parse_model(input_model) model_ir = ir.serde.deserialize_model(model_proto) - inline(model_ir) + inliner.InlinePass()(model_ir) proto = ir.serde.serialize_model(model_ir) text = onnx.printer.to_text(proto) print(text) diff --git a/onnxscript/optimizer/__init__.py b/onnxscript/optimizer/__init__.py index 3ea759db6c..b073b3345e 100644 --- a/onnxscript/optimizer/__init__.py +++ b/onnxscript/optimizer/__init__.py @@ -14,8 +14,8 @@ import onnx -import onnxscript.ir.passes.common.unused_removal import onnxscript.ir.passes.common.inliner +import onnxscript.ir.passes.common.unused_removal import onnxscript.optimizer._constant_folding as constant_folding import onnxscript.optimizer._legacy._optimizer as legacy_optimizer import onnxscript.optimizer._legacy.constant_folding as legacy_constant_folding diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 9dfeb53da3..60bee72b92 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -5,9 +5,10 @@ import logging import onnxscript.ir.passes.common.constant_manipulation +import onnxscript.ir.passes.common.inliner import onnxscript.ir.passes.common.unused_removal from onnxscript import ir, rewriter -from onnxscript.optimizer import _constant_folding, _inliner +from onnxscript.optimizer import _constant_folding logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def optimize_ir( outer optimization loop if no change is detected in one iteration. """ optimizer_pass = ir.passes.Sequential( - _inliner.InlinePass(), + onnxscript.ir.passes.common.inliner.InlinePass(), ir.passes.PassManager( [ _constant_folding.FoldConstantsPass( From 887da54b315e507a6e900a1c77b8484c78b3d522 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Apr 2025 18:15:51 -0700 Subject: [PATCH 3/6] fix test --- onnxscript/ir/passes/common/inliner_test.py | 2 +- onnxscript/version_converter/__init__.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index 8c4515fb7e..7fef143e55 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Tests for onnxscript.optimizer._inliner""" +"""Tests for the inliner pass.""" from __future__ import annotations diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index 299373f9c0..20b7d9c24b 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -7,8 +7,8 @@ "convert_version", ] +import onnxscript.optimizer from onnxscript import ir -from onnxscript.optimizer import _inliner from onnxscript.version_converter import _version_converter @@ -17,5 +17,5 @@ def convert_version(model: ir.Model, target_version: int) -> None: # In functions, we can have attribute-parameters, which means we don't know the value of the attribute. # Hence, we inline all the functions. - _inliner.inline(model) + onnxscript.optimizer.inline(model) _version_converter.convert_version(model, target_version) From 6e51c84bc1020e4453deff522332598d689d0d8a Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 15 Apr 2025 18:26:42 -0700 Subject: [PATCH 4/6] lint --- onnxscript/ir/passes/common/inliner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 4c834abe38..cb67567a6f 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -6,7 +6,7 @@ import dataclasses -__all__ = ["InlinePass"] +__all__ = ["InlinePass", "InlinePassResult"] from collections import defaultdict from typing import Iterable, List, Sequence, Tuple @@ -215,7 +215,7 @@ def _reset(self, model: ir.Model) -> None: self.used_node_names = set() self.node_context = {} - def call(self, model: ir.Model) -> ir.passes.PassResult: + def call(self, model: ir.Model) -> InlinePassResult: self._reset(model) id_count = self._inline_calls_in(model.graph) model.functions.clear() From 15d35ca908c1a5a5dc5ef51777548e9d99485796 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 08:13:40 -0700 Subject: [PATCH 5/6] lint --- onnxscript/ir/passes/common/inliner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index cb67567a6f..66f4c7a74d 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -11,8 +11,8 @@ from collections import defaultdict from typing import Iterable, List, Sequence, Tuple -import onnxscript.ir as ir -import onnxscript.ir.convenience as ir_convenience +import onnxscript.ir.convenience as _ir_convenience +from onnxscript import ir # A replacement for a node specifies a list of nodes that replaces the original node, # and a list of values that replaces the original node's outputs. @@ -309,7 +309,7 @@ def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]: self._function_id_abbreviations[id] + call_site_prefix ) nodes, values = self._instantiate_call(node, call_site) - ir_convenience.replace_nodes_and_values( + _ir_convenience.replace_nodes_and_values( graph, insertion_point=node, old_nodes=[node], From a4768b532bbf0399dd8942a30202e4df8785d110 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Wed, 16 Apr 2025 08:23:38 -0700 Subject: [PATCH 6/6] lint --- onnxscript/ir/passes/common/inliner.py | 8 ++++---- onnxscript/ir/passes/common/inliner_test.py | 5 +---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/onnxscript/ir/passes/common/inliner.py b/onnxscript/ir/passes/common/inliner.py index 66f4c7a74d..5cefc94268 100644 --- a/onnxscript/ir/passes/common/inliner.py +++ b/onnxscript/ir/passes/common/inliner.py @@ -26,7 +26,7 @@ CallStack = List[CallSiteId] -def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: +def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument """Generate a unique name from a name, calling-context, and set of used names. If there is a name clash, we add a numeric suffix to the name to make @@ -244,7 +244,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl if default_attr_values: attributes = {**attributes, **default_attr_values} if any( - attr.type == ir.AttributeType.GRAPH or attr.type == ir.AttributeType.GRAPHS + attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS} for attr in attributes.values() ): raise ValueError( @@ -324,6 +324,6 @@ def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]: if attr.type == ir.AttributeType.GRAPH: self._inline_calls_in(attr.as_graph()) elif attr.type == ir.AttributeType.GRAPHS: - for graph in attr.as_graphs(): - self._inline_calls_in(graph) + for g in attr.as_graphs(): + self._inline_calls_in(g) return id_count diff --git a/onnxscript/ir/passes/common/inliner_test.py b/onnxscript/ir/passes/common/inliner_test.py index 7fef143e55..7a64a8d4b4 100644 --- a/onnxscript/ir/passes/common/inliner_test.py +++ b/onnxscript/ir/passes/common/inliner_test.py @@ -68,10 +68,7 @@ def _check( self.assertTrue(isinstance(value, ir.Attr)) self.assertTrue(isinstance(expected_value, ir.Attr)) self.assertEqual(value.type, expected_value.type) - if ( - value.type != ir.AttributeType.GRAPH - and value.type != ir.AttributeType.GRAPHS - ): + if value.type not in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS): self.assertEqual(value.value, expected_value.value) else: self.fail("Graph attributes are not supported yet")