Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions onnxscript/optimizer/_inliner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class _CopyReplace:

def __init__(
self,
inliner: _Inliner,
inliner: InlinePass,
attr_map: dict[str, ir.Attr | ir.RefAttr],
value_map: dict[ir.Value, ir.Value | None],
metadata_props: dict[str, str],
Expand Down Expand Up @@ -188,15 +188,29 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
return {id: id_abbreviation(id) for id in function_ids}


class _Inliner:
def __init__(self, model: ir.Model) -> None:
self._functions = model.functions
self._function_id_abbreviations = _abbreviate(self._functions.keys())
self._opset_imports = model.opset_imports
class InlinePass(ir.passes.InPlacePass):
def __init__(self) -> None:
self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
self._opset_imports: dict[str, int] = {}
self.used_value_names: set[str] = set()
self.used_node_names: set[str] = set()
self.node_context: dict[ir.Node, CallStack] = {}

def _reset(self, model: ir.Model) -> None:
self._functions = model.functions
self._function_id_abbreviations = _abbreviate(self._functions.keys())
self._opset_imports = model.opset_imports
self.used_value_names = set()
self.used_node_names = set()
self.node_context = {}

def call(self, model: ir.Model) -> ir.passes.PassResult:
self._reset(model)
modified = self.inline_calls_in(model.graph)
model.functions.clear()
return ir.passes.PassResult(model, modified)

def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
id = node.op_identifier()
function = self._functions[id]
Expand Down Expand Up @@ -249,7 +263,7 @@ def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeRepl
output_values = [value_map[output] for output in function.outputs]
return nodes, output_values # type: ignore

def inline_calls_in(self, graph: ir.Graph) -> None:
def inline_calls_in(self, graph: ir.Graph) -> bool:
for input in graph.inputs:
if input.name is not None:
self.used_value_names.add(input.name)
Expand Down Expand Up @@ -302,11 +316,10 @@ def inline_calls_in(self, graph: ir.Graph) -> None:
elif attr.type == ir.AttributeType.GRAPHS:
for graph in attr.as_graphs():
self.inline_calls_in(graph)
return bool(id_count)


def inline(model: ir.Model) -> None:
"""Inline all function calls (recursively) in the model."""
if model.functions:
inliner = _Inliner(model)
inliner.inline_calls_in(model.graph)
model.functions.clear()
InlinePass()(model)
36 changes: 24 additions & 12 deletions onnxscript/optimizer/_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging

import onnxscript.ir.passes.common.unused_removal
import onnxscript.optimizer
from onnxscript import ir, rewriter
from onnxscript.optimizer import _constant_folding, _inliner
Expand Down Expand Up @@ -50,15 +51,26 @@ def optimize_ir(
stop_if_no_change: Not supported currently (has no effect). Meant to stop the
outer optimization loop if no change is detected in one iteration.
"""
del stop_if_no_change # Looks like rewriter doesn't support this yet.
# TODO(justinchuby): Update this to use a pass manager
_inliner.inline(model)
for _ in range(num_iterations):
_constant_folding.fold_constants(
model,
onnx_shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
)
rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES)
onnxscript.optimizer.remove_unused_nodes(model)
optimizer_pass = ir.passes.Sequential(
_inliner.InlinePass(),
ir.passes.PassManager(
[
_constant_folding.FoldConstantsPass(
external_data_folder="",
shape_inference=onnx_shape_inference,
input_size_limit=input_size_limit,
output_size_limit=output_size_limit,
),
rewriter.RewritePass(_DEFAULT_REWRITE_RULES),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(),
],
steps=num_iterations,
early_stop=stop_if_no_change,
),
onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(),
)
assert optimizer_pass.in_place
result = optimizer_pass(model)
assert result.model is model
35 changes: 24 additions & 11 deletions onnxscript/rewriter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,50 @@
from onnxscript.ir.passes.common import unused_removal
from onnxscript.rewriter import pattern

RewriteRuleSet = pattern.RewriteRuleSet
PatternRewriteRule = pattern.RewriteRule

ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model)


class RewritePass(ir.passes.InPlacePass):
def __init__(
self,
pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (),
) -> None:
if pattern_rewrite_rules:
if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet):
# Create a pattern rule-set using provided rules
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet)
self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules

def call(self, model: ir.Model) -> ir.passes.PassResult:
count = self.pattern_rewrite_rules.apply_to_model(model)
if count:
print(f"Applied {count} of general pattern rewrite rules.")
return ir.passes.PassResult(model, bool(count))


def rewrite(
model: ModelProtoOrIr,
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], RewriteRuleSet] = (),
pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (),
) -> ModelProtoOrIr:
if isinstance(model, onnx.ModelProto):
model_ir = ir.serde.deserialize_model(model)
proto = True
else:
model_ir = model
proto = False
if pattern_rewrite_rules:
if not isinstance(pattern_rewrite_rules, RewriteRuleSet):
# Create a pattern rule-set using provided rules
pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules)
count = pattern_rewrite_rules.apply_to_model(model_ir)
if count:
print(f"Applied {count} of general pattern rewrite rules.")
unused_remover = ir.passes.PassManager(

rewrite_pass = ir.passes.PassManager(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is causing a regression when no rewrite-rules are specified. I think we should retain an if pattern_rewrite_rules condition and avoid the RewritePass if no rules are specified.

(
RewritePass(pattern_rewrite_rules),
unused_removal.RemoveUnusedNodesPass(),
unused_removal.RemoveUnusedFunctionsPass(),
unused_removal.RemoveUnusedOpsetsPass(),
)
)
model_ir = unused_remover(model_ir).model
model_ir = rewrite_pass(model_ir).model
if proto:
return ir.serde.serialize_model(model_ir)
return model_ir # type: ignore[return-value]
Loading