Skip to content

Commit f54cf47

Browse files
Add GQA fusion to ONNX fusions (#2524)
Add GQA fusion to ONNX fusions. TODO: * Test cases. (Fusion seems to work on Gemma3, but more to be done.) --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com> Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
1 parent 27c7f09 commit f54cf47

File tree

4 files changed

+263
-18
lines changed

4 files changed

+263
-18
lines changed

onnxscript/rewriter/onnx_fusions/_onnx_fusions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import onnx_ir as ir
66

7-
from onnxscript.rewriter.rules.fusion import _rms_normalization, _rotary_embedding
7+
from onnxscript.rewriter.rules.fusion import _gqa, _rms_normalization, _rotary_embedding
88

99

1010
def _get_onnx_opset_version(model: ir.Model) -> int | None:
@@ -24,6 +24,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
2424
counts: dict[str, int] = {}
2525
counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
2626
counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug)
27+
counts["GQA"] = _gqa.fuse_gqa(model, debug=debug)
2728
return counts
2829

2930

onnxscript/rewriter/rules/fusion/_gqa.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
from __future__ import annotations
4+
5+
from typing import Union
6+
7+
import onnx_ir as ir
8+
9+
import onnxscript.rewriter._fusion_utils as _fusion_utils
10+
from onnxscript.rewriter import _basics, pattern
11+
12+
# A bound dimension value: either a concrete int or a symbolic dimension.
Dim = Union[int, ir.SymbolicDim]
13+
14+
15+
class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase):
    """Fuses a grouped-query-attention subgraph into a single Attention node.

    Matches the pattern where the past key/value caches are concatenated with
    the current key/value, broadcast across the heads that share each KV head
    (Unsqueeze -> Expand -> Reshape), and fed to Attention; rewrites it so a
    single Attention node consumes the unexpanded key/value plus the caches
    directly.
    """

    def __init__(self):
        # remove_nodes=False: leave the matched nodes in place (their outputs,
        # e.g. the present key/value, may be consumed elsewhere).
        super().__init__("ONNXGQA", remove_nodes=False)

    def pattern(
        self,
        op,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
    ):
        # Key path: append the current key to the cache, then replicate each
        # shared KV head across its group and flatten back to H heads.
        key_with_cache = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2)
        key_with_group_dim = op.Unsqueeze(key_with_cache, 2)
        key_broadcast = op.Expand(key_with_group_dim, pattern.ANY_VALUE)
        present_key_BHStD = op.Reshape(
            key_broadcast, pattern.ANY_VALUE, _outputs=["present_key_BHStD"]
        )

        # Value path: same cache/expand/flatten treatment as the key path.
        value_with_cache = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2)
        value_with_group_dim = op.Unsqueeze(value_with_cache, 2)
        value_broadcast = op.Expand(value_with_group_dim, pattern.ANY_VALUE)
        present_value_BHStD = op.Reshape(
            value_broadcast, pattern.ANY_VALUE, _outputs=["present_value_BHStD"]
        )

        # Attention over the expanded key/value; the mask input is optional.
        return op.Attention(
            query_BHSD,
            present_key_BHStD,
            present_value_BHStD,
            pattern.Var("mask", can_match_none=True),
            _outputs=["attention_BHSDh"],
        )

    def check(
        self,
        context: _basics.MatchContext,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
        present_key_BHStD,
        present_value_BHStD,
        **_,
    ):
        # Shared bindings enforce that dimensions with the same name agree
        # across all of the checked values.
        bindings: dict[str, Dim] = {}
        shape_constraints = [
            # Inputs that will feed the rewritten Attention node.
            (query_BHSD, ["B", "H", "S", "D"]),
            (key_BHkvSD, ["B", "Hkv", "S", "D"]),
            (value_BHkvSD, ["B", "Hkv", "S", "D"]),
            (past_key_BHkvSpD, ["B", "Hkv", "P", "D"]),
            (past_value_BHkvSpD, ["B", "Hkv", "P", "D"]),
            # We need to check that the Expand/Reshape arguments are as
            # expected. As a substitute, we check that the outputs of
            # Expand=>Reshape have the expected shapes.
            # TODO (rama): May be better to check the actual Expand/Reshape
            # arguments.
            (present_key_BHStD, ["B", "H", "S+P", "D"]),
            (present_value_BHStD, ["B", "H", "S+P", "D"]),
        ]
        # NOTE(review): the status of check_shape is not inspected here —
        # confirm that it signals mismatches itself (e.g. by raising) rather
        # than via its return value.
        for value, dims in shape_constraints:
            _fusion_utils.check_shape(bindings, value, dims)

        return True

    def rewrite(
        self,
        op,
        query_BHSD,
        key_BHkvSD,
        value_BHkvSD,
        past_key_BHkvSpD,
        past_value_BHkvSpD,
        mask,
        attention_BHSDh,
        **_,
    ):
        # Carry every attribute of the matched Attention node over to the
        # replacement, which takes the caches as explicit inputs.
        matched_attention = attention_BHSDh.producer()
        attrs = matched_attention.attributes
        return op.Attention(
            query_BHSD,
            key_BHkvSD,
            value_BHkvSD,
            mask,
            past_key_BHkvSpD,
            past_value_BHkvSpD,
            **attrs,
        )
107+
108+
109+
# Rule instance and rule set for the GQA fusion defined above.
_basic_gqa_rule = OnnxGroupQueryAttention.rule()

gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule])

# Public entry point: applies the GQA fusion rules to a model.
fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules)
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import unittest
5+
6+
import onnx
7+
import onnx_ir as ir
8+
from packaging import version
9+
10+
import onnxscript
11+
import onnxscript.optimizer
12+
import onnxscript.rewriter.testing
13+
from onnxscript import FLOAT, script
14+
from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa
15+
16+
# Default ONNX opset 23 — the opset targeted by these fusions.
op = onnxscript.values.Opset("", 23)

# Constants used as graph inputs by the test script below (wrapped in lists
# so they trace as 1-D tensors when passed to Concat).
H = [8]  # Number of attention heads
Hkv = [4]  # Number of key/value heads (H should be divisible by Hkv)
D = [64]  # Head size
G = [2]  # Number of groups
23+
24+
@script(ir_version=10)
def _gqa_script(
    query_BHSD: FLOAT[2, 8, 4, 64],  # B=2, H=8, S=4, D=64
    key_BHkvSD: FLOAT[2, 4, 4, 64],  # B=2, Hkv=4, S=4, D=64
    value_BHkvSD: FLOAT[2, 4, 4, 64],  # B=2, Hkv=4, S=4, D=64
    past_key_BHkvPD: FLOAT[2, 4, 8, 64],  # B=2, Hkv=4, P=8, D=64
    past_value_BHkvPD: FLOAT[2, 4, 8, 64],  # B=2, Hkv=4, P=8, D=64
) -> FLOAT[2, 8, 4, 64]:
    """Basic GQA pattern that should be fused into an Attention op.

    Builds the cache-concat / Unsqueeze / Expand / Reshape subgraph that
    replicates each key/value head across its group before Attention — the
    exact shape the ONNXGQA rule is written to match.
    """

    # Concatenate past_key cache and current key
    present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2)  # [B, Hkv, S+P, D]

    # Unsqueeze to add group dimension
    present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2)  # [B, Hkv, 1, S+P, D]

    # Calculate shapes dynamically (Shape ops, so B and S+P stay symbolic)
    B = op.Shape(query_BHSD, start=0, end=1)  # [B]
    T = op.Shape(present_key_BHkvStD, start=2, end=3)  # [S+P]

    # Create expand shape [B, Hkv, G, S+P, D]
    expand_shape = op.Concat(B, Hkv, G, T, D, axis=0)
    present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape)  # [B, Hkv, G, S+P, D]

    # Create reshape shape [B, H, S+P, D]
    reshape_shape = op.Concat(B, H, T, D, axis=0)
    present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape)  # [B, H, S+P, D]

    # Same for value (reuses expand_shape/reshape_shape computed above)
    present_value_BHkvStD = op.Concat(
        past_value_BHkvPD, value_BHkvSD, axis=-2
    )  # [B, Hkv, S+P, D]
    present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2)  # [B, Hkv, 1, S+P, D]
    present_value_BHkvGStD = op.Expand(
        present_value_BHkv1StD, expand_shape
    )  # [B, Hkv, G, S+P, D]
    present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape)  # [B, H, S+P, D]

    # Attention computation
    attention_BHSDh = op.Attention(
        query_BHSD,
        present_key_BHStD,
        present_value_BHStD,
    )

    return attention_BHSDh
70+
71+
72+
class GQAFusionTest(unittest.TestCase):
    def test_basic_gqa_fusion(self):
        """The basic GQA pattern should collapse into a fused Attention op."""
        original_proto = _gqa_script.to_model_proto()

        # Optimize first, then run the GQA fusion over the IR model.
        ir_model = ir.serde.deserialize_model(original_proto)
        onnxscript.optimizer.optimize(ir_model)
        fusion_count = fuse_gqa(ir_model)
        self.assertGreater(fusion_count, 0, "GQA fusion should have occurred")

        # We can't yet test numerical equivalence on older onnx versions
        # because of a bug in the op spec/implementation; only official
        # releases >= 1.19.1 are exercised.
        installed = version.parse(onnx.__version__)
        is_official_release = not (installed.is_prerelease or installed.is_devrelease)
        if installed >= version.parse("1.19.1") and is_official_release:
            onnxscript.optimizer.remove_unused_nodes(ir_model)
            fused_proto = ir.serde.serialize_model(ir_model)
            onnxscript.rewriter.testing.assert_numerically_equal(
                original_proto, fused_proto, use_reference=True
            )
96+
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

onnxscript/rewriter/testing.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import numpy as np
88
import onnx
9+
import onnx.reference
910
import onnxruntime as ort
1011

1112
from onnxscript import ir
@@ -32,10 +33,11 @@ def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]:
3233
def assert_numerically_equal(
3334
original_model_proto: onnx.ModelProto | ir.Model,
3435
rewritten_model_proto: onnx.ModelProto | ir.Model,
35-
args: tuple[Any, ...] | dict[str, Any],
36+
args: tuple[Any, ...] | dict[str, Any] | None = None,
3637
ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL,
3738
rtol: float = 1,
3839
atol: float = 1e-3,
40+
use_reference: bool = False,
3941
):
4042
"""Assert that the two models are numerically equal.
4143
@@ -46,14 +48,18 @@ def assert_numerically_equal(
4648
ort_optimization_level: Onnxruntime optimization level.
4749
rtol: Relative tolerance.
4850
atol: Absolute tolerance.
51+
use_reference: If True, use ONNX reference implementation instead of ONNXRuntime.
4952
"""
5053

5154
if isinstance(original_model_proto, ir.Model):
5255
original_model_proto = ir.serde.serialize_model(original_model_proto)
5356
if isinstance(rewritten_model_proto, ir.Model):
5457
rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto)
5558

56-
if isinstance(args, dict):
59+
if args is None:
60+
original_proto_ort_inputs = generate_random_inputs(original_model_proto)
61+
the_rewritten_proto_ort_inputs = original_proto_ort_inputs
62+
elif isinstance(args, dict):
5763
original_proto_ort_inputs = args
5864
the_rewritten_proto_ort_inputs = args
5965
else:
@@ -64,21 +70,34 @@ def assert_numerically_equal(
6470
k.name: v for k, v in zip(rewritten_model_proto.graph.input, args)
6571
}
6672

67-
original_proto_ort_inference_session = _ort_session_initializer(
68-
original_model_proto.SerializeToString(), ort_optimization_level
69-
)
70-
run_options = ort.RunOptions()
71-
run_options.log_severity_level = 3 # 3: Error
72-
original_outputs = original_proto_ort_inference_session.run(
73-
None, original_proto_ort_inputs, run_options=run_options
74-
)
75-
76-
the_rewritten_proto_ort_inference_session = _ort_session_initializer(
77-
rewritten_model_proto.SerializeToString(), ort_optimization_level
78-
)
79-
the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
80-
None, the_rewritten_proto_ort_inputs, run_options=run_options
81-
)
73+
if use_reference:
74+
# Use ONNX reference implementation
75+
original_evaluator = _reference_session(
76+
original_model_proto.SerializeToString(), ort_optimization_level
77+
)
78+
original_outputs = original_evaluator.run(None, original_proto_ort_inputs)
79+
80+
rewritten_evaluator = _reference_session(
81+
rewritten_model_proto.SerializeToString(), ort_optimization_level
82+
)
83+
the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs)
84+
else:
85+
# Use ONNXRuntime
86+
original_proto_ort_inference_session = _ort_session_initializer(
87+
original_model_proto.SerializeToString(), ort_optimization_level
88+
)
89+
run_options = ort.RunOptions()
90+
run_options.log_severity_level = 3 # 3: Error
91+
original_outputs = original_proto_ort_inference_session.run(
92+
None, original_proto_ort_inputs, run_options=run_options
93+
)
94+
95+
the_rewritten_proto_ort_inference_session = _ort_session_initializer(
96+
rewritten_model_proto.SerializeToString(), ort_optimization_level
97+
)
98+
the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run(
99+
None, the_rewritten_proto_ort_inputs, run_options=run_options
100+
)
82101

83102
np.testing.assert_allclose(
84103
original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True
@@ -103,3 +122,18 @@ def _ort_session_initializer(
103122
provider for provider in possible_providers if provider in available_providers
104123
]
105124
return ort.InferenceSession(model, providers=providers, sess_options=session_options)
125+
126+
127+
def _reference_session(
    model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel
) -> onnx.reference.ReferenceEvaluator:
    """Initialize an ONNX reference evaluator with the specified model.

    Note: ort_optimization_level is accepted for signature parity with
    _ort_session_initializer but ignored, as the reference implementation
    has no equivalent optimization levels.
    """
    # Parse serialized input into a ModelProto; pass anything else through
    # unchanged.
    model_proto = (
        onnx.load_from_string(model) if isinstance(model, (str, bytes)) else model
    )
    return onnx.reference.ReferenceEvaluator(model_proto)

0 commit comments

Comments
 (0)