diff --git a/onnxscript/rewriter/_pattern_ir.py b/onnxscript/rewriter/_pattern_ir.py
index 1687897737..f64d3fca3c 100644
--- a/onnxscript/rewriter/_pattern_ir.py
+++ b/onnxscript/rewriter/_pattern_ir.py
@@ -76,13 +76,19 @@ def __str__(self) -> str:
 class AttrPattern(Pattern[ir.Attr]):
     """Base class for an attribute pattern. Matches any attribute value by default."""
 
-    def __init__(self, name: str | None):
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
         self._name = name
+        self._can_match_none = can_match_none
 
     @property
     def name(self) -> str | None:
         return self._name
 
+    @property
+    def can_match_none(self) -> bool:
+        """Indicates whether this pattern can match a None attribute."""
+        return self._can_match_none
+
     def matches(self, attr: ir.Attr) -> bool:
         return True
 
@@ -90,6 +96,13 @@ def __str__(self) -> str:
         return self._name if self._name is not None else "anonymous:" + str(id(self))
 
 
+class AttrVar(AttrPattern):
+    """Represents a pattern variable used to match against attribute values."""
+
+    def __init__(self, name: str | None, *, can_match_none: bool = False):
+        super().__init__(name, can_match_none=can_match_none)
+
+
 # TODO: Support tensors. Align with usage elsewhere.
 SupportedAttrTypes = Union[
     int,
@@ -129,11 +142,11 @@ def _to_attr_pattern(value: AttrPattern | ValuePattern | SupportedAttrTypes) ->
         # annotations to distinguish between ValuePattern and AttrPattern, but forces users to
         # use these type annotations.
        # TODO: check for misuse at rule-creation time. (Currently will be caught by matcher at match-time.)
-        if value.can_match_none or value.check_method is not None:
+        if value.check_method is not None:
             raise ValueError(
-                "Pattern variables used in attributes must not have can_match_none or check_method set."
+                "Pattern variables used in attributes must not have check_method set."
             )
-        return AttrPattern(value.name)
+        return AttrVar(value.name, can_match_none=value.can_match_none)
     if isinstance(value, (int, float, str)):
         return AttrConstantPattern(value)
     if isinstance(value, Sequence):
@@ -493,8 +506,9 @@ def matches(self, node: ir.Node, match: _basics.MatchResult) -> _basics.MatchRes
         for name, attr_pattern in self.attributes.items():
             attr_value = node.attributes.get(name)
             if attr_value is None:
-                return match.fail(f"Attribute {name} not found in node.", node)
-            if not attr_pattern.matches(attr_value):
+                if not attr_pattern.can_match_none:
+                    return match.fail(f"Attribute {name} not found in node.", node)
+            elif not attr_pattern.matches(attr_value):
                 return match.fail(
                     f"Attribute {name} mismatch: expected {attr_pattern}, got {attr_value}.",
                     node,
diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py
index 5657f1d30a..ed33807db9 100644
--- a/onnxscript/rewriter/ort_fusions/_core.py
+++ b/onnxscript/rewriter/ort_fusions/_core.py
@@ -22,6 +22,7 @@
 from onnxscript.rewriter.ort_fusions.gqa_packed_qkv import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.mha import fuse_mha1, fuse_mha2
 from onnxscript.rewriter.ort_fusions.mha_bias import fuse_mha_bias
+from onnxscript.rewriter.ort_fusions.mha_scale import fuse_mha_scale
 from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
 from onnxscript.rewriter.ort_fusions.rotary_embedding import (
     fuse_partial_rotary_embedding,
@@ -82,6 +83,7 @@ def fuse(func, **kwargs):
     fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
     fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
     fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
+    common_passes.CommonSubexpressionEliminationPass()(model)
     fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
 
     # We apply shape inference after the SDPA fusion as new nodes are added
@@ -90,9 +92,9 @@ def fuse(func, **kwargs):
 
     fusion_count["gqa"] = fuse(fuse_gqa)
     fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
-
     fusion_count["mha1"] = fuse(fuse_mha1)
     fusion_count["mha2"] = fuse(fuse_mha2)
+    fusion_count["mha_scale"] = fuse(fuse_mha_scale)
     if (fusion_count["mha1"] == 0) and (fusion_count["mha2"] == 0):
         fusion_count["mha_bias"] = 0
         fusion_count["attention"] = 0
diff --git a/onnxscript/rewriter/ort_fusions/attention.py b/onnxscript/rewriter/ort_fusions/attention.py
index ffbe131233..4a4cd0ad8e 100644
--- a/onnxscript/rewriter/ort_fusions/attention.py
+++ b/onnxscript/rewriter/ort_fusions/attention.py
@@ -111,7 +111,7 @@ def pattern(
                 num_heads=num_heads,
                 # scale=scale,
                 _domain="com.microsoft",
-                _outputs=3,
+                _outputs=["mha_output", "present_key", "present_value"],
             )
             # Concat present_key and present_value to form present
             present_key = op.Unsqueeze(present_key, [0])
@@ -132,7 +132,7 @@ def pattern(
                 num_heads=num_heads,
                 # scale=scale,
                 _domain="com.microsoft",
-                _outputs=1,
+                _outputs=["mha_output"],
             )
         return attention
 
@@ -260,6 +260,7 @@ def rewrite(
         attention_bias,
         num_heads,
         # scale,
+        mha_output,
         q_mul=None,
         k_mul=None,
         v_mul=None,
@@ -274,6 +275,8 @@ def rewrite(
         if self._no_slice:
             qkv_weight = op.Concat(q_mul, k_mul, v_mul, axis=1)
 
+        scale = mha_output.producer().attributes.get_float("scale", None)
+
         if self._has_past:
             attention, present = op.Attention(
                 input,
@@ -285,7 +288,7 @@ def rewrite(
                 # past_sequence_length
                 num_heads=num_heads,
                 qkv_hidden_sizes=qkv_hidden_sizes,
-                # scale=scale,
+                scale=scale,
                 _domain="com.microsoft",
                 _outputs=2,
             )
@@ -302,7 +305,7 @@ def rewrite(
                 None,  # past_sequence_length
                 num_heads=num_heads,
                 qkv_hidden_sizes=qkv_hidden_sizes,
-                # scale=scale,
+                scale=scale,
                 _domain="com.microsoft",
                 _outputs=1,
             )
diff --git a/onnxscript/rewriter/ort_fusions/attention_test.py b/onnxscript/rewriter/ort_fusions/attention_test.py
index d4e485428b..4559bc205c 100644
--- a/onnxscript/rewriter/ort_fusions/attention_test.py
+++ b/onnxscript/rewriter/ort_fusions/attention_test.py
@@ -176,6 +176,8 @@ def test_whisper_encoder(self):
         mha_count = xformers.fuse_mha1(model)
         mha_count += xformers.fuse_mha2(model)
         self.assertGreater(mha_count, 0)
+        mha_scale_count = xformers.fuse_mha_scale(model)
+        self.assertGreater(mha_scale_count, 0)
         fused_mha_bias_count = xformers.fuse_mha_bias(model)
         self.assertGreater(fused_mha_bias_count, 0)
         # TODO: Enable once source of discrepancy is found
diff --git a/onnxscript/rewriter/ort_fusions/mha.py b/onnxscript/rewriter/ort_fusions/mha.py
index e9f752acca..433c10e504 100644
--- a/onnxscript/rewriter/ort_fusions/mha.py
+++ b/onnxscript/rewriter/ort_fusions/mha.py
@@ -38,16 +38,12 @@ def __init__(
         name,
         *,
         double_transpose: bool,
-        transpose_4d: bool,
-        pre_scale_q: bool,
         is_rotary: bool,
         has_past_present: bool,
         is_cross_attention: bool,
     ):
         super().__init__(name)
         self._double_transpose = double_transpose
-        self._transpose_4d = transpose_4d
-        self._pre_scale_q = pre_scale_q
         self._is_rotary = is_rotary
         self._has_past_present = has_past_present
         self._is_cross_attention = is_cross_attention
@@ -63,12 +59,9 @@ def pattern(
         position_ids,
         cos,
         sin,
-        q_scale,
     ):
         # First, query, key, and value are reshaped+transposed from (B, S, D) to (B, H, S, D/H)
 
-        if self._pre_scale_q:
-            query_BSD = op.Mul(query_BSD, q_scale)
         # Reshape from (B, S, D) to (B, S, H, D/H)
         query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
         # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
@@ -93,24 +86,12 @@ def pattern(
             value_BHSDh = value
 
         if self._is_rotary:
-            # This is workaround for examples where there is a duplication of Unsqueeze op
-            # to generate a 2D positions-ids from a 1D position-ids. This can be eliminated
-            # if we have CSE-optimization to eliminate the duplicate Unsqueeze ops.
-            # For now, same flag (transpose_4d) controls this variation. A different flag
-            # can be added if we see instances that mix the two.
-            if self._transpose_4d:
-                position_ids_q = op.Unsqueeze(position_ids, [0])
-                position_ids_k = op.Unsqueeze(position_ids, [0])
-            else:
-                position_ids_q = position_ids
-                position_ids_k = position_ids
-
             query_BHSDh_emb = op.RotaryEmbedding(
-                query_BHSDh, position_ids_q, cos, sin, _domain="com.microsoft"
+                query_BHSDh, position_ids, cos, sin, _domain="com.microsoft"
             )
             if not self._is_cross_attention:
                 key_BHSDh_emb = op.RotaryEmbedding(
-                    key, position_ids_k, cos, sin, _domain="com.microsoft"
+                    key, position_ids, cos, sin, _domain="com.microsoft"
                 )
             else:
                 key_BHSDh_emb = key
@@ -289,6 +270,7 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
         else:
             self._use_mask_broadcast = False
 
+        self._scale = sdpa_node.attributes.get_float("scale", None)
         # TODO: verify Reshapes:
         # eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
         # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
@@ -307,20 +289,14 @@ def rewrite(
         position_ids,
         cos,
         sin,
-        q_scale=None,
         **_,
     ):
-        scale = _ir_utils.get_singleton_value(q_scale)
         num_heads = _ir_utils.get_dim(query_BSHDh, 2)
         if not isinstance(num_heads, int):
             return None
 
         # TODO: forward other attributes
-        if self._transpose_4d:
-            zero_1d = op.Constant(value_ints=[0])
-            position_ids = op.Unsqueeze(position_ids, zero_1d)
-
         if self._is_rotary:
             query_BSD_emb = op.RotaryEmbedding(
                 query_BSD, position_ids, cos, sin, _domain="com.microsoft"
             )
@@ -360,9 +336,9 @@ def rewrite(
             past_key,
             past_value,
             num_heads=num_heads,
-            scale=scale,
             _domain="com.microsoft",
             _outputs=num_outputs,
+            scale=self._scale,
         )
 
 
@@ -370,17 +346,11 @@ def _make_rule_set(has_past_present: bool):
     parameter_combinations = [
         {
             "double_transpose": double_transpose,
-            "transpose_4d": transpose_4d,
-            "pre_scale_q": pre_scale_q,
             "is_rotary": is_rotary,
             "has_past_present": has_past_present,
             "is_cross_attention": is_cross_attention,
         }
         for double_transpose in [False, True]
-        for transpose_4d in (
-            [False, True] if double_transpose else [False]
-        )  # Only generate patterns when double_transpose is True
-        for pre_scale_q in [True, False]
         for is_rotary in [False, True]
         for is_cross_attention in ([False] if has_past_present else [False, True])
     ]
@@ -389,9 +359,8 @@ def _make_rule_set(has_past_present: bool):
     mha_rules = pattern.RewriteRuleSet(
         [
             MultiHeadAttention.rule(
-                f"MHA_{'4D' if params['transpose_4d'] else '3D'}_Transpose"
+                f"MHA"
                 f"{'_Twice' if params['double_transpose'] else ''}"
-                f"{'_PreScaleQ' if params['pre_scale_q'] else ''}"
                 f"{'_Rotary' if params['is_rotary'] else ''}"
                 f"{'_Past' if params['has_past_present'] else ''}"
                 f"{'_CrossAttention' if params['is_cross_attention'] else ''}",
diff --git a/onnxscript/rewriter/ort_fusions/mha_bias.py b/onnxscript/rewriter/ort_fusions/mha_bias.py
index 775386484f..28b9646ddc 100644
--- a/onnxscript/rewriter/ort_fusions/mha_bias.py
+++ b/onnxscript/rewriter/ort_fusions/mha_bias.py
@@ -28,7 +28,6 @@ def pattern(
         past_key,
         past_value,
         num_heads,
-        # scale,
     ):
         query_BSD = pattern.OrValue(
             [op.Add(query_matmul, q_bias), query_matmul],
@@ -56,7 +55,7 @@ def pattern(
             pattern.Var("past_key", can_match_none=True),
             pattern.Var("past_value", can_match_none=True),
             num_heads=num_heads,
-            # scale=scale,
+            scale=pattern.AttrVar("scale", can_match_none=True),
             _domain="com.microsoft",
         )
 
@@ -132,7 +131,7 @@ def rewrite(
         past_key,
         past_value,
         num_heads,
-        # scale,
+        scale,
         **_,
     ):
         if q_bias is None:
@@ -158,7 +157,7 @@ def rewrite(
             past_key,
             past_value,
             num_heads=num_heads,
-            # scale=scale,
+            scale=scale,
             _domain="com.microsoft",
         )
 
diff --git a/onnxscript/rewriter/ort_fusions/mha_scale.py b/onnxscript/rewriter/ort_fusions/mha_scale.py
new file mode 100644
index 0000000000..e02e6c49e3
--- /dev/null
+++ b/onnxscript/rewriter/ort_fusions/mha_scale.py
@@ -0,0 +1,68 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import math
+
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+
+"""
+Multi-Head Attention (MHA) pre-scaling fusion patterns.
+
+This module contains rewrite rules for fusing scale operations that occur before
+Multi-Head Attention operations. The fusion optimizes patterns where a query tensor
+is scaled before being passed to MHA by incorporating the scaling directly into
+the MHA operation.
+
+Example pattern:
+    query -> Mul(scale) -> MultiHeadAttention -> output
+
+Gets rewritten to:
+    query -> MultiHeadAttention(with integrated scaling) -> output
+"""
+
+
+class FuseMHAScale(pattern.RewriteRuleClassBase):
+    def pattern(self, op, query, scale):
+        scaled_query = op.Mul(query, scale)
+        mha_output = op.MultiHeadAttention(
+            scaled_query,
+            _allow_other_inputs=True,
+            _domain="com.microsoft",
+            _outputs=["mha_output"],
+        )
+        return mha_output
+
+    def check(self, context, scale, **_):
+        scale_value = _ir_utils.get_singleton_value(scale)
+        if scale_value is None or not isinstance(scale_value, (int, float)):
+            return pattern.MatchResult().fail("Scale must be a constant numeric value.", scale)
+        self._scale = scale_value
+        return True
+
+    def rewrite(self, op, query, mha_output, **_):
+        # Integrate the scale into the MHA operation
+        mha_node = mha_output.producer()
+        assert mha_node is not None
+        # Compute original scale factor for MHA:
+        attributes = mha_node.attributes
+        original_scale = attributes.get_float("scale", None)
+        if original_scale is None:
+            num_heads = attributes.get_int("num_heads", None)
+            if num_heads is None:
+                return None
+            head_size = query.shape[-1] // num_heads
+            original_scale = 1.0 / math.sqrt(head_size)
+        self._scale *= original_scale
+        inputs = list(mha_node.inputs)
+        inputs[0] = query
+        attributes = dict(attributes)
+        attributes["scale"] = self._scale
+        return op.MultiHeadAttention(
+            *inputs, **attributes, _domain="com.microsoft", _outputs=1
+        )
+
+
+_mha_scale_rules = pattern.RewriteRuleSet([FuseMHAScale.rule()])
+
+fuse_mha_scale = _fusion_utils.apply_fusion_rules(_mha_scale_rules)
diff --git a/onnxscript/rewriter/pattern.py b/onnxscript/rewriter/pattern.py
index 68c1654f5c..c4fd6e9161 100644
--- a/onnxscript/rewriter/pattern.py
+++ b/onnxscript/rewriter/pattern.py
@@ -7,6 +7,7 @@
 from onnxscript.rewriter._matcher import PatternMatcher, SimplePatternMatcher
 from onnxscript.rewriter._pattern_ir import (
     ANY_VALUE,
+    AttrVar,
     Constant,
     OpsetPatternBuilder,
     OrValue,
@@ -26,6 +27,7 @@
 
 __all__ = [
     "ANY_VALUE",
+    "AttrVar",
     "OrValue",
     "Constant",
     "OpsetPatternBuilder",
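
Note (illustrative, not part of the patch): FuseMHAScale folds a pre-multiplication of the query into the MHA scale attribute. The identity it relies on is that attention scores are scale * (Q @ K^T), so scaling Q by a constant s is equivalent to multiplying scale by s; when no scale attribute is present, the default 1 / sqrt(head_size) is assumed, matching the rewrite above. A minimal NumPy sketch of that equivalence (shapes and values are arbitrary assumptions):

    import math
    import numpy as np

    rng = np.random.default_rng(0)
    batch, heads, seq, head_size = 1, 2, 4, 8   # arbitrary illustrative shapes
    q = rng.standard_normal((batch, heads, seq, head_size))
    k = rng.standard_normal((batch, heads, seq, head_size))

    s = 0.125                                   # the constant folded away by FuseMHAScale
    default_scale = 1.0 / math.sqrt(head_size)  # MHA default when no scale attribute is set

    # Original graph: Mul(query, s) feeding MultiHeadAttention with the default scale.
    scores_prescaled = default_scale * ((s * q) @ k.transpose(0, 1, 3, 2))
    # Fused graph: unscaled query, scale attribute set to s * default_scale.
    scores_fused = (s * default_scale) * (q @ k.transpose(0, 1, 3, 2))

    np.testing.assert_allclose(scores_prescaled, scores_fused, rtol=1e-12)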
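Note (illustrative, not part of the patch): the new pattern.AttrVar(..., can_match_none=True) lets a rule match a node whether or not the named attribute is present; when it is absent, the binding arrives as None in rewrite, which is how mha_bias.py forwards the optional scale above. A hypothetical rule sketch, with SomeOp, OtherOp, and the "custom" domain as placeholder names:

    from onnxscript.rewriter import pattern


    class ForwardOptionalAlpha(pattern.RewriteRuleClassBase):
        """Hypothetical rule: forwards an optional 'alpha' attribute, present or not."""

        def pattern(self, op, x):
            return op.SomeOp(x, alpha=pattern.AttrVar("alpha", can_match_none=True), _domain="custom")

        def rewrite(self, op, x, alpha, **_):
            # 'alpha' is None when the matched node carries no such attribute.
            return op.OtherOp(x, alpha=alpha, _domain="custom")


    rules = pattern.RewriteRuleSet([ForwardOptionalAlpha.rule()])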