diff --git a/onnxscript/optimizer/_inliner.py b/onnxscript/optimizer/_inliner.py index 8936a8adbf..ac9bf71010 100644 --- a/onnxscript/optimizer/_inliner.py +++ b/onnxscript/optimizer/_inliner.py @@ -190,6 +190,7 @@ def id_abbreviation(id: ir.OperatorIdentifier) -> str: class InlinePass(ir.passes.InPlacePass): def __init__(self) -> None: + super().__init__() self._functions: dict[ir.OperatorIdentifier, ir.Function] = {} self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {} self._opset_imports: dict[str, int] = {} diff --git a/onnxscript/optimizer/_legacy/_optimizer.py b/onnxscript/optimizer/_legacy/_optimizer.py index eef56bdd33..829eb9c25f 100644 --- a/onnxscript/optimizer/_legacy/_optimizer.py +++ b/onnxscript/optimizer/_legacy/_optimizer.py @@ -15,7 +15,6 @@ inline_simple_functions, ) from onnxscript.optimizer._legacy.constant_folding import fold_constants -from onnxscript.optimizer._optimizer import _DEFAULT_REWRITE_RULES logger = logging.getLogger(__name__) @@ -75,7 +74,7 @@ def optimize( onnxscript.optimizer.remove_unused_functions(model) inline_functions_with_unused_outputs(model) # NOTE: This is general rewrite rules - model = rewriter.rewrite(model, pattern_rewrite_rules=_DEFAULT_REWRITE_RULES) + model = rewriter.rewrite(model) if stop_if_no_change and not modified: logger.debug("Stopping after %d iterations.", _) break diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index d3784ce40b..4b2ab2223f 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -5,29 +5,11 @@ import logging import onnxscript.ir.passes.common.unused_removal -import onnxscript.optimizer from onnxscript import ir, rewriter from onnxscript.optimizer import _constant_folding, _inliner -from onnxscript.rewriter import ( - broadcast_to_matmul, - cast_constant_of_shape, - collapse_slices, - gemm_to_matmul_add, - llama_rule_sets, - no_op, -) logger = logging.getLogger(__name__) -_DEFAULT_REWRITE_RULES: tuple[rewriter.pattern.RewriteRule, ...] = ( - *no_op.rules.rules, # TODO: merge this rule into constant folding? - *broadcast_to_matmul.rules.rules, - gemm_to_matmul_add.rule, # type: ignore[has-type] - *cast_constant_of_shape.rules.rules, - *collapse_slices.rules.rules, - *llama_rule_sets.llama_p0_rule_set().rules, -) - def optimize_ir( model: ir.Model, @@ -61,7 +43,7 @@ def optimize_ir( input_size_limit=input_size_limit, output_size_limit=output_size_limit, ), - rewriter.RewritePass(_DEFAULT_REWRITE_RULES), + rewriter.RewritePass(rewriter._DEFAULT_REWRITE_RULES), onnxscript.ir.passes.common.unused_removal.RemoveUnusedNodesPass(), onnxscript.ir.passes.common.unused_removal.RemoveUnusedFunctionsPass(), onnxscript.ir.passes.common.unused_removal.RemoveUnusedOpsetsPass(), diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index c43b3d875e..5efaf784b0 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -5,46 +5,81 @@ from typing import Sequence, TypeVar, Union __all__ = [ - # Modules "pattern", - # Functions "rewrite", + "RewritePass", ] import onnx from onnxscript import ir from onnxscript.ir.passes.common import unused_removal -from onnxscript.rewriter import pattern +from onnxscript.rewriter import ( + broadcast_to_matmul, + cast_constant_of_shape, + collapse_slices, + gemm_to_matmul_add, + llama_rule_sets, + no_op, + pattern, +) -PatternRewriteRule = pattern.RewriteRule - -ModelProtoOrIr = TypeVar("ModelProtoOrIr", onnx.ModelProto, ir.Model) +_ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) +_DEFAULT_REWRITE_RULES: tuple[pattern.RewriteRule, ...] = ( + *no_op.rules.rules, # TODO: merge this rule into constant folding? + *broadcast_to_matmul.rules.rules, + gemm_to_matmul_add.rule, # type: ignore[has-type] + *cast_constant_of_shape.rules.rules, + *collapse_slices.rules.rules, + *llama_rule_sets.llama_p0_rule_set().rules, +) class RewritePass(ir.passes.InPlacePass): def __init__( self, - pattern_rewrite_rules: Sequence[PatternRewriteRule] | pattern.RewriteRuleSet = (), + rules: Sequence[pattern.RewriteRule] | pattern.RewriteRuleSet, + /, ) -> None: - if pattern_rewrite_rules: - if not isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet): - # Create a pattern rule-set using provided rules - pattern_rewrite_rules = pattern.RewriteRuleSet(pattern_rewrite_rules) - assert isinstance(pattern_rewrite_rules, pattern.RewriteRuleSet) - self.pattern_rewrite_rules: pattern.RewriteRuleSet = pattern_rewrite_rules + super().__init__() + if isinstance(rules, Sequence): + if not rules: + raise ValueError("rules must not be empty") + # Create a pattern rule-set using provided rules + rules = pattern.RewriteRuleSet(rules) + assert isinstance(rules, pattern.RewriteRuleSet) + self.rules: pattern.RewriteRuleSet = rules def call(self, model: ir.Model) -> ir.passes.PassResult: - count = self.pattern_rewrite_rules.apply_to_model(model) + count = self.rules.apply_to_model(model) if count: print(f"Applied {count} of general pattern rewrite rules.") return ir.passes.PassResult(model, bool(count)) def rewrite( - model: ModelProtoOrIr, - pattern_rewrite_rules: Union[Sequence[PatternRewriteRule], pattern.RewriteRuleSet] = (), -) -> ModelProtoOrIr: + model: _ModelProtoOrIr, + pattern_rewrite_rules: Union[Sequence[pattern.RewriteRule], pattern.RewriteRuleSet] + | None = None, +) -> _ModelProtoOrIr: + """Rewrite the model using the provided pattern rewrite rules. + + Unused nodes, functions, and opsets will be removed after the rewrite. + + Args: + model: The model to be rewritten. Can be an ONNX ModelProto or an ir.Model. + pattern_rewrite_rules: A sequence of pattern rewrite rules or a RewriteRuleSet. + If not provided, default rules will be applied. If empty, no rules will be applied + and the original model will be returned. + + Returns: + The rewritten model as the same type as the input model. + """ + if pattern_rewrite_rules is None: + pattern_rewrite_rules = _DEFAULT_REWRITE_RULES + elif not pattern_rewrite_rules: + return model + if isinstance(model, onnx.ModelProto): model_ir = ir.serde.deserialize_model(model) proto = True diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py index 793675b4ab..907ebd0b88 100644 --- a/onnxscript/rewriter/pattern.py +++ b/onnxscript/rewriter/pattern.py @@ -1664,6 +1664,8 @@ def _get_new_overload(model: ir.Model, domain: str, name: str) -> str: class RewriteRuleSet: def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> None: + if not rules: + raise ValueError("rules must contain at least one rule") if commute: rules = list(itertools.chain.from_iterable([rule.commute() for rule in rules])) self.rules = rules @@ -1671,6 +1673,9 @@ def __init__(self, rules: Sequence[RewriteRule], *, commute: bool = False) -> No # NOT remove nodes (immediately when it is applied) self.remove_unused_nodes = any(not rule.remove_nodes for rule in rules) + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.rules})" + def _apply_to_graph_or_function( self, model: ir.Model,