Add GQA fusion to ONNX fusions #2524
Merged

Commits (9):
- bfa1068 Run lint (gramalingam)
- 7cce2f5 Add unit test (gramalingam)
- 1ae7270 Merge branch 'main' into rama/onnxgqa2 (justinchuby)
- 738c869 Merge with main (gramalingam)
- 0331251 Run lint (gramalingam)
- 9280dd6 Remove empty file (gramalingam)
- ef13d35 Fix version check (gramalingam)
- fab3cb8 Remove debugging code (gramalingam)
- d839bc1 Merge branch 'main' into rama/onnxgqa2 (gramalingam)
onnxscript/rewriter/rules/fusion/_gqa.py (new file, +113 lines):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Union

import onnx_ir as ir

import onnxscript.rewriter._fusion_utils as _fusion_utils
from onnxscript.rewriter import _basics, pattern

Dim = Union[int, ir.SymbolicDim]
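
# Shape-suffix naming convention (mirroring the shape checks below): query_BHSD
# is [Batch, Heads, Seq, head Dim]; Hkv is the number of key/value heads; Sp/P
# is the past (cached) sequence length; St is the total sequence length S + P.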


class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase):
    def __init__(self):
        super().__init__("ONNXGQA", remove_nodes=False)

    def pattern(
        self,
        op,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
    ):
        # Concatenate past_key cache and current key, expand across heads
        # that share key/value.
        present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
        present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)
        present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE)
        present_key_BHStD = op.Reshape(
            present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
        )

        # Concatenate past_value cache and current value, expand across heads
        # that share key/value.
        present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
        present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)
        present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE)
        present_value_BHStD = op.Reshape(
            present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
        )

        attention_BHSDh = op.Attention(
            query_BHSD,
            present_key_BHStD,
            present_value_BHStD,
            pattern.Var("mask", can_match_none=True),
            _outputs=["attention_BHSDh"],
        )

        return attention_BHSDh

    def check(
        self,
        context: _basics.MatchContext,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
        present_key_BHStD,
        present_value_BHStD,
        **_,
    ):
        bindings: dict[str, Dim] = {}
        # Check that inputs to new Attention node have expected shapes
        _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"])
        _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"])
        _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"])
        _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"])
        _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"])
        # We need to check that the Expand/Reshape arguments are as expected.
        # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes.
        # TODO (rama): May be better to check the actual Expand/Reshape arguments.
        _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"])
        _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"])

        return True

    def rewrite(
        self,
        op,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
        mask,
        attention_BHSDh,
        **_,
    ):
        original_attention_node = attention_BHSDh.producer()
        original_attrs = original_attention_node.attributes
        return op.Attention(
            query_BHSD,
            key_BHkvSD,
            value_BHkvSD,
            mask,
            past_key_BHkvSpD,
            past_value_BHkvSpD,
            **original_attrs,
        )


_basic_gqa_rule = OnnxGroupQueryAttention.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
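
The rewrite relies on the opset-23 ONNX Attention op accepting key/value inputs with fewer heads than the query (Hkv dividing H) together with the past key/value caches, so the matched Concat => Unsqueeze => Expand => Reshape chain becomes redundant. That chain is the standard GQA trick of repeating each key/value head G = H / Hkv times; a minimal NumPy sketch of the equivalence (example sizes assumed here, not part of the PR):

import numpy as np

B, Hkv, G, T, D = 2, 4, 2, 12, 64  # assumed example sizes; H = Hkv * G = 8
kv = np.random.rand(B, Hkv, T, D).astype(np.float32)

# Unsqueeze -> Expand -> Reshape, as in the matched subgraph
expanded = np.broadcast_to(kv[:, :, None, :, :], (B, Hkv, G, T, D))
reshaped = expanded.reshape(B, Hkv * G, T, D)

# Equivalent to repeating each key/value head G times along the head axis
assert np.array_equal(reshaped, np.repeat(kv, G, axis=1))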
Unit test for the fusion (new file, +97 lines):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import unittest

import onnx
import onnx_ir as ir
from packaging import version

import onnxscript
import onnxscript.optimizer
import onnxscript.rewriter.testing
from onnxscript import FLOAT, script
from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa

op = onnxscript.values.Opset("", 23)

H = [8]  # Number of attention heads
Hkv = [4]  # Number of key/value heads (H should be divisible by Hkv)
D = [64]  # Head size
G = [2]  # Number of groups


@script(ir_version=10)
def _gqa_script(
    query_BHSD: FLOAT[2, 8, 4, 64],  # B=2, H=8, S=4, D=64
    key_BHkvSD: FLOAT[2, 4, 4, 64],  # B=2, Hkv=4, S=4, D=64
    value_BHkvSD: FLOAT[2, 4, 4, 64],  # B=2, Hkv=4, S=4, D=64
    past_key_BHkvPD: FLOAT[2, 4, 8, 64],  # B=2, Hkv=4, P=8, D=64
    past_value_BHkvPD: FLOAT[2, 4, 8, 64],  # B=2, Hkv=4, P=8, D=64
) -> FLOAT[2, 8, 4, 64]:
    """Basic GQA pattern that should be fused into an Attention op."""

    # Concatenate past_key cache and current key
    present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2)  # [B, Hkv, S+P, D]

    # Unsqueeze to add group dimension
    present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)  # [B, Hkv, 1, S+P, D]

    # Calculate shapes dynamically
    B = op.Shape(query_BHSD, start=0, end=1)  # [B]
    T = op.Shape(present_key_BHkvStD, start=2, end=3)  # [S+P]

    # Create expand shape [B, Hkv, G, S+P, D]
    expand_shape = op.Concat(B, Hkv, G, T, D, axis=0)
    present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape)  # [B, Hkv, G, S+P, D]

    # Create reshape shape [B, H, S+P, D]
    reshape_shape = op.Concat(B, H, T, D, axis=0)
    present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape)  # [B, H, S+P, D]

    # Same for value
    present_value_BHkvStD = op.Concat(
        past_value_BHkvPD, value_BHkvSD, axis=-2
    )  # [B, Hkv, S+P, D]
    present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)  # [B, Hkv, 1, S+P, D]
    present_value_BHkvGStD = op.Expand(
        present_value_BHkv1StD, expand_shape
    )  # [B, Hkv, G, S+P, D]
    present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape)  # [B, H, S+P, D]

    # Attention computation
    attention_BHSDh = op.Attention(
        query_BHSD,
        present_key_BHStD,
        present_value_BHStD,
    )

    return attention_BHSDh


class GQAFusionTest(unittest.TestCase):
    def test_basic_gqa_fusion(self):
        """Test basic GQA fusion pattern."""
        model_proto = _gqa_script.to_model_proto()

        # Apply GQA fusion
        model = ir.serde.deserialize_model(model_proto)
        onnxscript.optimizer.optimize(model)
        count = fuse_gqa(model)
        self.assertGreater(count, 0, "GQA fusion should have occurred")

        # We can't yet test numerical equivalence because of a bug in the op spec/implementation.
        onnx_ver = version.parse(onnx.__version__)
        if onnx_ver >= version.parse("1.19.1") and not (
            onnx_ver.is_prerelease or onnx_ver.is_devrelease
        ):
            # Only official releases >= 1.19.1
            onnxscript.optimizer.remove_unused_nodes(model)
            rewritten_model_proto = ir.serde.serialize_model(model)
            onnxscript.rewriter.testing.assert_numerically_equal(
                model_proto, rewritten_model_proto, use_reference=True
            )


if __name__ == "__main__":
    unittest.main()
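
Applying the fusion outside the test harness follows the same steps as the test above; a minimal sketch, assuming an input model stored at "model.onnx" (a hypothetical path):

import onnx
import onnx_ir as ir

import onnxscript.optimizer
from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa

model_proto = onnx.load("model.onnx")  # hypothetical input model
model = ir.serde.deserialize_model(model_proto)
onnxscript.optimizer.optimize(model)  # fold shape computations first, as the test does
count = fuse_gqa(model)  # returns the number of fusions applied
print(f"Applied {count} GQA fusion(s)")
onnx.save(ir.serde.serialize_model(model), "model_fused.onnx")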