Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions onnxscript/rewriter/_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,42 @@
import onnxscript.rewriter._rewrite_rule as _rewrite_rule


class MatchFailureInfo:
    """Record of why a pattern match failed and which IR objects were involved.

    The constructor stores ``reason`` (the failure message) and
    ``failure_sources`` (the tuple of positional arguments given after the
    reason) on the instance.
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        self.reason = reason
        self.failure_sources: tuple[ir.Node | ir.Value, ...] = failure_source
        # Sanity check on the argument types; being an ``assert``, it is
        # skipped entirely when Python runs with optimizations (-O).
        for source in failure_source:
            assert isinstance(source, (ir.Node, ir.Value)), (
                f"All items in failure_source must be ir.Node or ir.Value, got {[type(item) for item in failure_source]}"
            )

    def __str__(self):
        reason, sources = self.reason, self.failure_sources
        return f"MatchFailureInfo(reason={reason!r}, failure_sources={sources!r})"

Check warning on line 34 in onnxscript/rewriter/_basics.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/_basics.py#L34

Added line #L34 was not covered by tests


class MatchFailureError(MatchFailureInfo, Exception):
    """Exception form of :class:`MatchFailureInfo`, raised when a match fails.

    Raising instead of returning a failure lets helper functions (for
    example, those used while checking a pattern's condition) abort the
    match without every caller having to test for and forward the failure.
    """

    def __init__(
        self,
        reason: str = "",
        *failure_source: ir.Node | ir.Value,
    ):
        # Both bases are initialised explicitly: MatchFailureInfo.__init__
        # does not chain to Exception.__init__ on its own.
        Exception.__init__(self, reason)
        MatchFailureInfo.__init__(self, reason, *failure_source)


class MatchResult:
"""The state object used by the pattern-matching algorithm.

Expand Down
19 changes: 19 additions & 0 deletions onnxscript/rewriter/_fusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import onnxscript.ir as ir
import onnxscript.ir.passes.common as common_passes
from onnxscript.rewriter import pattern
from onnxscript.rewriter._basics import MatchFailureError

Dim = Union[int, ir.SymbolicDim]

Expand All @@ -24,6 +25,24 @@
return True


def check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str]) -> None:
    """Check the shape of ``val`` against symbolic dimension names.

    Each entry of ``shape`` names one dimension of ``val``. A name that is not
    yet in ``bindings`` is bound to the corresponding actual dimension; a name
    that is already bound must equal the actual dimension.

    Args:
        bindings: Mapping from dimension name to dimension, updated in place.
            Dimensions checked before a failing one remain bound.
        val: The value whose shape is checked.
        shape: Expected dimension names, one per axis.

    Raises:
        MatchFailureError: If the shape of ``val`` is unknown, its rank differs
            from ``len(shape)``, or a dimension conflicts with an existing
            binding.
    """
    if val.shape is None:
        raise MatchFailureError(f"The shape of {val} is unknown.", val)
    rank = val.shape.rank()
    if rank != len(shape):
        raise MatchFailureError(
            f"The rank of {val} ({rank}) does not match the expected rank {len(shape)}.",
            val,
        )
    for i, (actual, expected) in enumerate(zip(val.shape, shape)):
        if expected not in bindings:
            # First occurrence of this name: record the actual dimension.
            bindings[expected] = actual  # type: ignore[assignment]
        elif actual != bindings[expected]:
            raise MatchFailureError(
                f"Dimension {i} of {val} ({actual}) does not have expected size ({bindings[expected]}).",
                val,
            )


def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
"""
Apply the given fusion rules to the model and return the number of fusions applied.
Expand Down
6 changes: 5 additions & 1 deletion onnxscript/rewriter/_rewrite_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def try_rewrite(
if var.name is not None:
if var.name not in match.bindings:
match.bind(var.name, None)
check_match_result = self._condition_function(context, **match.bindings)
try:
check_match_result = self._condition_function(context, **match.bindings)
except _basics.MatchFailureError as e:
check_match_result = _basics.MatchResult()
check_match_result.fail(e.reason, list(e.failure_sources))
if not check_match_result:
# If check function was provided, but it failed, return the reason for failure to the tracer.
if isinstance(check_match_result, _basics.MatchResult):
Expand Down
13 changes: 13 additions & 0 deletions onnxscript/rewriter/ort_fusions/gqa_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(self, *args, **kwargs):
"num_heads must be divisible by kv_num_heads"
)
self.num_groups = self.num_heads // self.kv_num_heads
self.total_seqlen = self.seqlen + self.past_seqlen

# Abbreviations
B = self.batchsize
Expand Down Expand Up @@ -311,12 +312,24 @@ def test_fusion(self):
onnx.TensorProto.FLOAT,
["B", self.seqlen, self.kv_num_heads, self.head_size],
)
key_transposed_value_info = onnx.helper.make_tensor_value_info(
"key_transposed",
onnx.TensorProto.FLOAT,
["B", self.num_heads, self.head_size, self.total_seqlen],
)
value_BHSDh_value_info = onnx.helper.make_tensor_value_info(
"value_BHSDh",
onnx.TensorProto.FLOAT,
["B", self.num_heads, self.total_seqlen, self.head_size],
)
source_model.graph.value_info.extend(
[
query_BHSDh_rope_value_info,
key_BHkvSDh_rope_value_info,
query_BSHDh_value_info,
key_BSHkvDh_value_info,
key_transposed_value_info,
value_BHSDh_value_info,
]
)

Expand Down
Loading
Loading