Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import onnx.reference.ops

import onnxscript.ir as ir
import onnxscript.rewriter.pattern as orp
import onnxscript.ir._tape as _tape
import onnxscript.utils.utils as utils

DEFAULT_CONSTANT_FOLD_INPUT_SIZE_LIMIT = 1024
Expand Down Expand Up @@ -202,10 +202,9 @@ def get_shape_value(self, value: ir.Value | None) -> ir.Shape | None:
# the ir.Value or ir.Values to replace the output values of the node, when the new nodes
# can be inferred from the RewriterContext used to build the new nodes.

RewriterContext = _tape.Builder
ReturnValue = Union[Replacement, Sequence[ir.Value], ir.Value, None]
PartialEvaluatorFunction = Callable[
[ir.Node, orp.RewriterContext, OptimizerState], ReturnValue
]
PartialEvaluatorFunction = Callable[[ir.Node, RewriterContext, OptimizerState], ReturnValue]


@dataclasses.dataclass
Expand Down Expand Up @@ -991,7 +990,7 @@ def process_node(self, node: ir.Node) -> Replacement | None:
op_optimizers = registry.lookup_evaluators(node.domain, node.op_type, version)
for optimizer in op_optimizers:
assert optimizer
context = orp.RewriterContext()
context = RewriterContext()
output = optimizer(node, context, self._state)
if output is not None:
if isinstance(output, Replacement):
Expand Down
7 changes: 4 additions & 3 deletions onnxscript/version_converter/_version_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import logging
from typing import Callable, Sequence, Union

import onnxscript.ir._tape as _tape
import onnxscript.ir.convenience as ir_convenience
import onnxscript.rewriter.pattern as orp
from onnxscript import ir

logger = logging.getLogger(__name__)
Expand All @@ -35,8 +35,9 @@ class Replacement:
# A version-adapter function takes a node, a RewriterContext and returns
# a Replacement for the node or None (if no replacement is needed).

RewriterContext = _tape.Builder
ReturnValue = Union[Sequence[ir.Value], ir.Value, None]
AdapterFunction = Callable[[ir.Node, orp.RewriterContext], ReturnValue]
AdapterFunction = Callable[[ir.Node, RewriterContext], ReturnValue]


def version_supported(model: ir.Model, target_version: int) -> bool:
Expand Down Expand Up @@ -236,7 +237,7 @@ def process_node(
)
if adapter is None:
return None
context = orp.RewriterContext()
context = RewriterContext()
output = adapter(node, context)
if output is not None:
if isinstance(output, ir.Value):
Expand Down
Loading