26 changes: 20 additions & 6 deletions onnxscript/rewriter/_pattern_ir.py
@@ -76,20 +76,33 @@ def __str__(self) -> str:
class AttrPattern(Pattern[ir.Attr]):
"""Base class for an attribute pattern. Matches any attribute value by default."""

def __init__(self, name: str | None):
def __init__(self, name: str | None, *, can_match_none: bool = False):
self._name = name
self._can_match_none = can_match_none

@property
def name(self) -> str | None:
return self._name

@property
def can_match_none(self) -> bool:
"""Indicates whether this pattern can match a None attribute."""
return self._can_match_none

def matches(self, attr: ir.Attr) -> bool:
return True

def __str__(self) -> str:
return self._name if self._name is not None else "anonymous:" + str(id(self))


class AttrVar(AttrPattern):
"""Represents a pattern variable used to match against attribute values."""

def __init__(self, name: str | None, *, can_match_none: bool = False):
super().__init__(name, can_match_none=can_match_none)


# TODO: Support tensors. Align with usage elsewhere.
SupportedAttrTypes = Union[
int,
@@ -129,11 +142,11 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
# annotations to distinguish between ValuePattern and AttrPattern, but forces users to
# use these type annotations.
# TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
if value.can_match_none or value.check_method is not None:
if value.check_method is not None:
raise ValueError(
"Pattern variables used in attributes must not have can_match_none or check_method set."
"Pattern variables used in attributes must not have check_method set."
)
return AttrPattern(value.name)
return AttrVar(value.name, can_match_none=value.can_match_none)
if isinstance(value, (int, float, str)):
return AttrConstantPattern(value)
if isinstance(value, Sequence):
@@ -493,8 +506,9 @@ def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchRes
for name, attr_pattern in self.attributes.items():
attr_value = node.attributes.get(name)
if attr_value is None:
return match.fail(f"Attribute {name} not found in node.", node)
if not attr_pattern.matches(attr_value):
if not attr_pattern.can_match_none:
return match.fail(f"Attribute {name} not found in node.", node)
elif not attr_pattern.matches(attr_value):
return match.fail(
f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.",
node,
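For reference, here is a condensed sketch (not part of this diff) of how the new `AttrVar(..., can_match_none=True)` binding is meant to be used from a rewrite rule, mirroring the `mha_bias.py` change further down. The rule is hypothetical and only illustrates the binding semantics: the pattern matches whether or not the node carries a `scale` attribute, and the bound value (or `None` when the attribute is absent) is passed to `rewrite()`.

```python
from onnxscript.rewriter import pattern


class ForwardOptionalScale(pattern.RewriteRuleClassBase):
    def pattern(self, op, query, key, value, num_heads):
        return op.MultiHeadAttention(
            query,
            key,
            value,
            num_heads=num_heads,
            # Bind "scale" if present; with can_match_none=True the match also
            # succeeds when the attribute is absent (scale is then bound to None).
            scale=pattern.AttrVar("scale", can_match_none=True),
            _domain="com.microsoft",
            _allow_other_inputs=True,
        )

    def rewrite(self, op, query, key, value, num_heads, scale, **_):
        # A deliberately trivial rewrite: re-emit the node, forwarding the
        # optional attribute unchanged (omitted when scale is None).
        return op.MultiHeadAttention(
            query,
            key,
            value,
            num_heads=num_heads,
            scale=scale,
            _domain="com.microsoft",
        )


forward_optional_scale_rules = pattern.RewriteRuleSet([ForwardOptionalScale.rule()])
```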
4 changes: 3 additions & 1 deletion onnxscript/rewriter/ort_fusions/_core.py
@@ -22,6 +22,7 @@
from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
from onnxscript.rewriter.ort_fusions.mha_bias import fuse_mha_bias
from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
fuse_partial_rotary_embedding,
@@ -82,6 +83,7 @@ def fuse(func, **kwargs):
fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
common_passes.CommonSubexpressionEliminationPass()(model)
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)

# We apply shape inference after the SDPA fusion as new nodes are added
@@ -90,9 +92,9 @@ def fuse(func, **kwargs):

fusion_count["gqa"] = fuse(fuse_gqa)
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)

fusion_count["mha1"] = fuse(fuse_mha1)
fusion_count["mha2"] = fuse(fuse_mha2)
fusion_count["mha_scale"] = fuse(fuse_mha_scale)
if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
fusion_count["mha_bias"] = 0
fusion_count["attention"] = 0
11 changes: 7 additions & 4 deletions onnxscript/rewriter/ort_fusions/attention.py
@@ -111,7 +111,7 @@ def pattern(
num_heads=num_heads,
# scale=scale,
_domain="com.microsoft",
_outputs=3,
_outputs=["mha_output", "present_key", "present_value"],
)
# Concat present_key and present_value to form present
present_key = op.Unsqueeze(present_key, [0])
@@ -132,7 +132,7 @@ def pattern(
num_heads=num_heads,
# scale=scale,
_domain="com.microsoft",
_outputs=1,
_outputs=["mha_output"],
)
return attention

@@ -260,6 +260,7 @@ def rewrite(
attention_bias,
num_heads,
# scale,
mha_output,
q_mul=None,
k_mul=None,
v_mul=None,
@@ -274,6 +275,8 @@
if self._no_slice:
qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1)

scale = mha_output.producer().attributes.get_float("scale", None)

if self._has_past:
attention, present = op.Attention(
input,
@@ -285,7 +288,7 @@
# past_sequence_length
num_heads=num_heads,
qkv_hidden_sizes=qkv_hidden_sizes,
# scale=scale,
scale=scale,
_domain="com.microsoft",
_outputs=2,
)
@@ -302,7 +305,7 @@
None, # past_sequence_length
num_heads=num_heads,
qkv_hidden_sizes=qkv_hidden_sizes,
# scale=scale,
scale=scale,
_domain="com.microsoft",
_outputs=1,
)
2 changes: 2 additions & 0 deletions onnxscript/rewriter/ort_fusions/attention_test.py
@@ -176,6 +176,8 @@ def test_whisper_encoder(self):
mha_count = xformers.fuse_mha1(model)
mha_count += xformers.fuse_mha2(model)
self.assertGreater(mha_count, 0)
mha_scale_count = xformers.fuse_mha_scale(model)
self.assertGreater(mha_scale_count, 0)
fused_mha_bias_count = xformers.fuse_mha_bias(model)
self.assertGreater(fused_mha_bias_count, 0)
# TODO: Enable once source of discrepancy is found
41 changes: 5 additions & 36 deletions onnxscript/rewriter/ort_fusions/mha.py
@@ -38,16 +38,12 @@ def __init__(
name,
*,
double_transpose: bool,
transpose_4d: bool,
pre_scale_q: bool,
is_rotary: bool,
has_past_present: bool,
is_cross_attention: bool,
):
super().__init__(name)
self._double_transpose = double_transpose
self._transpose_4d = transpose_4d
self._pre_scale_q = pre_scale_q
self._is_rotary = is_rotary
self._has_past_present = has_past_present
self._is_cross_attention = is_cross_attention
@@ -63,12 +59,9 @@ def pattern(
position_ids,
cos,
sin,
q_scale,
):
# First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)

if self._pre_scale_q:
query_BSD = op.Mul(query_BSD, q_scale)
# Reshape from (B, S, D) to (B, S, H, D/H)
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
@@ -93,24 +86,12 @@
value_BHSDh = value

if self._is_rotary:
# This is workaround for examples where there is a duplication of Unsqueeze op
# to generate a 2D positions-ids from a 1D position-ids. This can be eliminated
# if we have CSE-optimization to eliminate the duplicate Unsqueeze ops.
# For now, same flag (transpose_4d) controls this variation. A different flag
# can be added if we see instances that mix the two.
if self._transpose_4d:
position_ids_q = op.Unsqueeze(position_ids, [0])
position_ids_k = op.Unsqueeze(position_ids, [0])
else:
position_ids_q = position_ids
position_ids_k = position_ids

query_BHSDh_emb = op.RotaryEmbedding(
query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
)
if not self._is_cross_attention:
key_BHSDh_emb = op.RotaryEmbedding(
key, position_ids_k, cos, sin, _domain="com.microsoft"
key, position_ids, cos, sin, _domain="com.microsoft"
)
else:
key_BHSDh_emb = key
@@ -289,6 +270,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
else:
self._use_mask_broadcast = False

self._scale = sdpa_node.attributes.get_float("scale", None)
# TODO: verify Reshapes:
# eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
# and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
@@ -307,20 +289,14 @@ def rewrite(
position_ids,
cos,
sin,
q_scale=None,
**_,
):
scale = _ir_utils.get_singleton_value(q_scale)
num_heads = _ir_utils.get_dim(query_BSHDh, 2)
if not isinstance(num_heads, int):
return None

# TODO: forward other attributes

if self._transpose_4d:
zero_1d = op.Constant(value_ints=[0])
position_ids = op.Unsqueeze(position_ids, zero_1d)

if self._is_rotary:
query_BSD_emb = op.RotaryEmbedding(
query_BSD, position_ids, cos, sin, _domain="com.microsoft"
@@ -360,27 +336,21 @@ def rewrite(
past_key,
past_value,
num_heads=num_heads,
scale=scale,
_domain="com.microsoft",
_outputs=num_outputs,
scale=self._scale,
)


def _make_rule_set(has_past_present: bool):
parameter_combinations = [
{
"double_transpose": double_transpose,
"transpose_4d": transpose_4d,
"pre_scale_q": pre_scale_q,
"is_rotary": is_rotary,
"has_past_present": has_past_present,
"is_cross_attention": is_cross_attention,
}
for double_transpose in [False, True]
for transpose_4d in (
[False, True] if double_transpose else [False]
) # Only generate patterns when double_transpose is True
for pre_scale_q in [True, False]
for is_rotary in [False, True]
for is_cross_attention in ([False] if has_past_present else [False, True])
]
@@ -389,9 +359,8 @@ def _make_rule_set(has_past_present: bool):
mha_rules = pattern.RewriteRuleSet(
[
MultiHeadAttention.rule(
f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
f"MHA"
f"{'_Twice' if params['double_transpose'] else ''}"
f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
f"{'_Rotary' if params['is_rotary'] else ''}"
f"{'_Past' if params['has_past_present'] else ''}"
f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
7 changes: 3 additions & 4 deletions onnxscript/rewriter/ort_fusions/mha_bias.py
@@ -28,7 +28,6 @@ def pattern(
past_key,
past_value,
num_heads,
# scale,
):
query_BSD = pattern.OrValue(
[op.Add(query_matmul, q_bias), query_matmul],
@@ -56,7 +55,7 @@
pattern.Var("past_key", can_match_none=True),
pattern.Var("past_value", can_match_none=True),
num_heads=num_heads,
# scale=scale,
scale=pattern.AttrVar("scale", can_match_none=True),
_domain="com.microsoft",
)

@@ -132,7 +131,7 @@ def rewrite(
past_key,
past_value,
num_heads,
# scale,
scale,
**_,
):
if q_bias is None:
@@ -158,7 +157,7 @@
past_key,
past_value,
num_heads=num_heads,
# scale=scale,
scale=scale,
_domain="com.microsoft",
)

68 changes: 68 additions & 0 deletions onnxscript/rewriter/ort_fusions/mha_scale.py
@@ -0,0 +1,68 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import math

from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
Multi-Head Attention (MHA) pre-scaling fusion patterns.

This module contains rewrite rules for fusing scale operations that occur before
Multi-Head Attention operations. The fusion optimizes patterns where a query tensor
is scaled before being passed to MHA by incorporating the scaling directly into
the MHA operation.

Example pattern:
query -> Mul(scale) -> MultiHeadAttention -> output

Gets rewritten to:
query -> MultiHeadAttention(with integrated scaling) -> output
"""


class FuseMHAScale(pattern.RewriteRuleClassBase):
def pattern(self, op, query, scale):
scaled_query = op.Mul(query, scale)
mha_output = op.MultiHeadAttention(
scaled_query,
_allow_other_inputs=True,
_domain="com.microsoft",
_outputs=["mha_output"],
)
return mha_output

def check(self, context, scale, **_):
scale_value = _ir_utils.get_singleton_value(scale)
if scale_value is None or not isinstance(scale_value, (int, float)):
return pattern.MatchResult().fail("Scale must be a constant numeric value.", scale)
self._scale = scale_value
return True

def rewrite(self, op, query, mha_output, **_):
# Integrate the scale into the MHA operation
mha_node = mha_output.producer()
assert mha_node is not None
# Compute original scale factor for MHA:
attributes = mha_node.attributes
original_scale = attributes.get_float("scale", None)
if original_scale is None:
num_heads = attributes.get_int("num_heads", None)
if num_heads is None:
return None
head_size = query.shape[-1] // num_heads
original_scale = 1.0 / math.sqrt(head_size)
self._scale *= original_scale
inputs = list(mha_node.inputs)
inputs[0] = query
attributes = dict(attributes)
attributes["scale"] = self._scale
return op.MultiHeadAttention(
*inputs, **attributes, _domain="com.microsoft", _outputs=1
)


_mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()])

fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
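A minimal usage sketch (not part of this diff) for the new fusion, following its use in `attention_test.py` above: `fuse_mha_scale` is produced by `_fusion_utils.apply_fusion_rules` and returns the number of sites rewritten. The model paths below are hypothetical, and the `ir.load`/`ir.save` round trip is an assumption about how the model is brought into the ONNX IR.

```python
from onnxscript import ir
from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale

model = ir.load("model.onnx")  # hypothetical input path
count = fuse_mha_scale(model)  # folds a constant pre-scaling Mul into MultiHeadAttention's scale attribute
print(f"mha_scale fusions applied: {count}")
ir.save(model, "model_fused.onnx")  # hypothetical output path
```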
2 changes: 2 additions & 0 deletions onnxscript/rewriter/pattern.py
@@ -7,6 +7,7 @@
from onnxscript.rewriter._matcher import PatternMatcher, SimplePatternMatcher
from onnxscript.rewriter._pattern_ir import (
ANY_VALUE,
AttrVar,
Constant,
OpsetPatternBuilder,
OrValue,
@@ -26,6 +27,7 @@

__all__ = [
"ANY_VALUE",
"AttrVar",
"OrValue",
"Constant",
"OpsetPatternBuilder",