From bfa1068cac0cc250898518543ad042c158dbfd4f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 27 Aug 2025 17:05:56 -0700 Subject: [PATCH 1/6] Run lint Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/onnx_fusions/_gqa.py | 114 ++++++++++++++++++ .../rewriter/onnx_fusions/_onnx_fusions.py | 3 +- 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 onnxscript/rewriter/onnx_fusions/_gqa.py diff --git a/onnxscript/rewriter/onnx_fusions/_gqa.py b/onnxscript/rewriter/onnx_fusions/_gqa.py new file mode 100644 index 0000000000..b4dc3d203a --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/_gqa.py @@ -0,0 +1,114 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from typing import Union + +import onnx_ir as ir + +import onnxscript.rewriter._fusion_utils as _fusion_utils +from onnxscript.rewriter import _basics, pattern + +Dim = Union[int, ir.SymbolicDim] + + +class OnnxGroupQueryAttention(pattern.RewriteRuleClassBase): + def __init__(self): + super().__init__("ONNXGQA", remove_nodes=False) + + def pattern( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + mask, + ): + # Concatenate past_key cache and current key, expand across heads + # that share key/value. + + present_key_BHkvStD = op.Concat(past_key_BHkvSpD, key_BHkvSD, axis=-2) + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, pattern.ANY_VALUE) + present_key_BHStD = op.Reshape( + present_key_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_key_BHStD"] + ) + + # Concatenate past_value cache and current value, expand across heads + # that share key/value. + present_value_BHkvStD = op.Concat(past_value_BHkvSpD, value_BHkvSD, axis=-2) + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) + present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, pattern.ANY_VALUE) + present_value_BHStD = op.Reshape( + present_value_BHkvGStD, pattern.ANY_VALUE, _outputs=["present_value_BHStD"] + ) + + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + mask, + _outputs=["attention_BHSDh"], + ) + + return attention_BHSDh + + def check( + self, + context: _basics.MatchContext, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + present_key_BHStD, + present_value_BHStD, + **_, + ): + bindings: dict[str, Dim] = {} + # Check that inputs to new Attention node have expected shapes + _fusion_utils.check_shape(bindings, query_BHSD, ["B", "H", "S", "D"]) + _fusion_utils.check_shape(bindings, key_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, value_BHkvSD, ["B", "Hkv", "S", "D"]) + _fusion_utils.check_shape(bindings, past_key_BHkvSpD, ["B", "Hkv", "P", "D"]) + _fusion_utils.check_shape(bindings, past_value_BHkvSpD, ["B", "Hkv", "P", "D"]) + # We need to check that the Expand/Reshape arguments are as expected. + # As a substitute, we check that the outputs of Expand=>Reshape have expected shapes. + # TODO (rama): May be better to check the actual Expand/Reshape arguments. 
+ _fusion_utils.check_shape(bindings, present_key_BHStD, ["B", "H", "S+P", "D"]) + _fusion_utils.check_shape(bindings, present_value_BHStD, ["B", "H", "S+P", "D"]) + + return True + + def rewrite( + self, + op, + query_BHSD, + key_BHkvSD, + value_BHkvSD, + past_key_BHkvSpD, + past_value_BHkvSpD, + mask, + attention_BHSDh, + **_, + ): + original_attention_node = attention_BHSDh.producer() + original_attrs = original_attention_node.attributes + return op.Attention( + query_BHSD, + key_BHkvSD, + value_BHkvSD, + mask, + past_key_BHkvSpD, + past_value_BHkvSpD, + **original_attrs, + ) + + +_basic_gqa_rule = OnnxGroupQueryAttention.rule() + +gqa_rules = pattern.RewriteRuleSet([_basic_gqa_rule]) + +fuse_gqa = _fusion_utils.apply_fusion_rules(gqa_rules) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 0a45f3017c..c9297b3699 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.onnx_fusions import _rms_normalization, _rotary_embedding +from onnxscript.rewriter.onnx_fusions import _gqa, _rms_normalization, _rotary_embedding def _get_onnx_opset_version(model: ir.Model) -> int | None: @@ -24,6 +24,7 @@ def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: counts: dict[str, int] = {} counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) counts["RotaryEmbedding"] = _rotary_embedding.fuse_rotary_embedding(model, debug=debug) + counts["GQA"] = _gqa.fuse_gqa(model, debug=debug) return counts From 7cce2f5ffa78a66056489fdd7bc0a9df705d5fbc Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 28 Aug 2025 17:40:21 -0700 Subject: [PATCH 2/6] Add unit test Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/onnx_fusions/_gqa.py | 3 +- onnxscript/rewriter/onnx_fusions/_gqa_test.py | 85 +++++++++++++++++++ onnxscript/rewriter/ort_fusions/_gqa_test.py | 0 onnxscript/rewriter/testing.py | 73 ++++++++++++---- 4 files changed, 142 insertions(+), 19 deletions(-) create mode 100644 onnxscript/rewriter/onnx_fusions/_gqa_test.py create mode 100644 onnxscript/rewriter/ort_fusions/_gqa_test.py diff --git a/onnxscript/rewriter/onnx_fusions/_gqa.py b/onnxscript/rewriter/onnx_fusions/_gqa.py index b4dc3d203a..8d6f156ed5 100644 --- a/onnxscript/rewriter/onnx_fusions/_gqa.py +++ b/onnxscript/rewriter/onnx_fusions/_gqa.py @@ -24,7 +24,6 @@ def pattern( value_BHkvSD, past_key_BHkvSpD, past_value_BHkvSpD, - mask, ): # Concatenate past_key cache and current key, expand across heads # that share key/value. @@ -49,7 +48,7 @@ def pattern( query_BHSD, present_key_BHStD, present_value_BHStD, - mask, + pattern.Var("mask", can_match_none=True), _outputs=["attention_BHSDh"], ) diff --git a/onnxscript/rewriter/onnx_fusions/_gqa_test.py b/onnxscript/rewriter/onnx_fusions/_gqa_test.py new file mode 100644 index 0000000000..6097a4de69 --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/_gqa_test.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import unittest + +import onnx_ir as ir + +import onnxscript +import onnxscript.optimizer +from onnxscript import FLOAT, script +import onnxscript.rewriter.testing +from onnxscript.rewriter.onnx_fusions._gqa import fuse_gqa + +op = onnxscript.values.Opset("", 23) + +H = [8] # Number of attention heads +Hkv = [4] # Number of key/value heads (H should be divisible by Hkv) +D = [64] # Head size +G = [2] # Number of groups + +@script(ir_version=10) +def _gqa_script( + query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64 + key_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + value_BHkvSD: FLOAT[2, 4, 4, 64], # B=2, Hkv=4, S=4, D=64 + past_key_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 + past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 +) -> FLOAT[2, 8, 4, 64]: + """Basic GQA pattern that should be fused into an Attention op.""" + + # Concatenate past_key cache and current key + present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + + # Unsqueeze to add group dimension + present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + + # Calculate shapes dynamically + B = op.Shape(query_BHSD, start=0, end=1) # [B] + T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P] + + # Create expand shape [B, Hkv, G, S+P, D] + expand_shape = op.Concat(B, Hkv, G, T, D, axis=0) + present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + + # Create reshape shape [B, H, S+P, D] + reshape_shape = op.Concat(B, H, T, D, axis=0) + present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Same for value + present_value_BHkvStD = op.Concat(past_value_BHkvPD, value_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D] + present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D] + + # Attention computation + attention_BHSDh = op.Attention( + query_BHSD, + present_key_BHStD, + present_value_BHStD, + ) + + return attention_BHSDh + + +class GQAFusionTest(unittest.TestCase): + def test_basic_gqa_fusion(self): + """Test basic GQA fusion pattern.""" + model_proto = _gqa_script.to_model_proto() + + # Apply GQA fusion + model = ir.serde.deserialize_model(model_proto) + onnxscript.optimizer.optimize(model) + count = fuse_gqa(model) + self.assertGreater(count, 0, "GQA fusion should have occurred") + + onnxscript.optimizer.remove_unused_nodes(model) + + rewritten_model_proto = ir.serde.serialize_model(model) + # onnxscript.rewriter.testing.assert_numerically_equal( + # model_proto, rewritten_model_proto, use_reference=True + # ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/_gqa_test.py b/onnxscript/rewriter/ort_fusions/_gqa_test.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py index 591f9387c2..b4e8d5060a 100644 --- a/onnxscript/rewriter/testing.py +++ b/onnxscript/rewriter/testing.py @@ -6,6 +6,7 @@ import numpy as np import onnx +import onnx.reference import onnxruntime as ort from onnxscript import ir @@ -32,10 +33,11 @@ def generate_random_inputs(model: onnx.ModelProto) -> dict[str, Any]: def assert_numerically_equal( original_model_proto: onnx.ModelProto | ir.Model, rewritten_model_proto: onnx.ModelProto | ir.Model, - args: tuple[Any, ...] 
| dict[str, Any], + args: tuple[Any, ...] | dict[str, Any] | None = None, ort_optimization_level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_ALL, rtol: float = 1, atol: float = 1e-3, + use_reference: bool = False, ): """Assert that the two models are numerically equal. @@ -46,6 +48,7 @@ def assert_numerically_equal( ort_optimization_level: Onnxruntime optimization level. rtol: Relative tolerance. atol: Absolute tolerance. + use_reference: If True, use ONNX reference implementation instead of ONNXRuntime. """ if isinstance(original_model_proto, ir.Model): @@ -53,7 +56,10 @@ def assert_numerically_equal( if isinstance(rewritten_model_proto, ir.Model): rewritten_model_proto = ir.serde.serialize_model(rewritten_model_proto) - if isinstance(args, dict): + if args is None: + original_proto_ort_inputs = generate_random_inputs(original_model_proto) + the_rewritten_proto_ort_inputs = original_proto_ort_inputs + elif isinstance(args, dict): original_proto_ort_inputs = args the_rewritten_proto_ort_inputs = args else: @@ -64,21 +70,39 @@ def assert_numerically_equal( k.name: v for k, v in zip(rewritten_model_proto.graph.input, args) } - original_proto_ort_inference_session = _ort_session_initializer( - original_model_proto.SerializeToString(), ort_optimization_level - ) - run_options = ort.RunOptions() - run_options.log_severity_level = 3 # 3: Error - original_outputs = original_proto_ort_inference_session.run( - None, original_proto_ort_inputs, run_options=run_options - ) - - the_rewritten_proto_ort_inference_session = _ort_session_initializer( - rewritten_model_proto.SerializeToString(), ort_optimization_level - ) - the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( - None, the_rewritten_proto_ort_inputs, run_options=run_options - ) + if use_reference: + # Use ONNX reference implementation + original_evaluator = _reference_session( + original_model_proto.SerializeToString(), ort_optimization_level + ) + original_outputs = original_evaluator.run(None, original_proto_ort_inputs) + + rewritten_evaluator = _reference_session( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = rewritten_evaluator.run(None, the_rewritten_proto_ort_inputs) + else: + # Use ONNXRuntime + original_proto_ort_inference_session = _ort_session_initializer( + original_model_proto.SerializeToString(), ort_optimization_level + ) + run_options = ort.RunOptions() + run_options.log_severity_level = 3 # 3: Error + original_outputs = original_proto_ort_inference_session.run( + None, original_proto_ort_inputs, run_options=run_options + ) + + the_rewritten_proto_ort_inference_session = _ort_session_initializer( + rewritten_model_proto.SerializeToString(), ort_optimization_level + ) + the_rewritten_outputs = the_rewritten_proto_ort_inference_session.run( + None, the_rewritten_proto_ort_inputs, run_options=run_options + ) + + for i, (orig, rewritten) in enumerate(zip(original_outputs, the_rewritten_outputs)): + print(f"==== Output {i} ====") + diff = np.abs(orig - rewritten) + print(diff) np.testing.assert_allclose( original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True @@ -103,3 +127,18 @@ def _ort_session_initializer( provider for provider in possible_providers if provider in available_providers ] return ort.InferenceSession(model, providers=providers, sess_options=session_options) + + +def _reference_session( + model: str | bytes, ort_optimization_level: ort.GraphOptimizationLevel +) -> onnx.reference.ReferenceEvaluator: + 
"""Initialize an ONNX reference evaluator with the specified model.""" + # Parse the model from bytes if needed + if isinstance(model, (str, bytes)): + model_proto = onnx.load_from_string(model) + else: + model_proto = model + + # Note: ort_optimization_level is ignored for reference implementation + # as it doesn't have equivalent optimization levels + return onnx.reference.ReferenceEvaluator(model_proto) From 033125109077d48653ade947dfad83d4fb7d7211 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 22 Sep 2025 10:30:40 -0700 Subject: [PATCH 3/6] Run lint Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/rules/fusion/_gqa_test.py | 33 +++++++++++-------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/onnxscript/rewriter/rules/fusion/_gqa_test.py b/onnxscript/rewriter/rules/fusion/_gqa_test.py index 95ef34f380..c8386bee57 100644 --- a/onnxscript/rewriter/rules/fusion/_gqa_test.py +++ b/onnxscript/rewriter/rules/fusion/_gqa_test.py @@ -2,15 +2,15 @@ # Licensed under the MIT License. import unittest -import onnx -from packaging import version +import onnx import onnx_ir as ir +from packaging import version import onnxscript import onnxscript.optimizer -from onnxscript import FLOAT, script import onnxscript.rewriter.testing +from onnxscript import FLOAT, script from onnxscript.rewriter.rules.fusion._gqa import fuse_gqa op = onnxscript.values.Opset("", 23) @@ -20,6 +20,7 @@ D = [64] # Head size G = [2] # Number of groups + @script(ir_version=10) def _gqa_script( query_BHSD: FLOAT[2, 8, 4, 64], # B=2, H=8, S=4, D=64 @@ -29,38 +30,42 @@ def _gqa_script( past_value_BHkvPD: FLOAT[2, 4, 8, 64], # B=2, Hkv=4, P=8, D=64 ) -> FLOAT[2, 8, 4, 64]: """Basic GQA pattern that should be fused into an Attention op.""" - + # Concatenate past_key cache and current key present_key_BHkvStD = op.Concat(past_key_BHkvPD, key_BHkvSD, axis=-2) # [B, Hkv, S+P, D] - - # Unsqueeze to add group dimension + + # Unsqueeze to add group dimension present_key_BHkv1StD = op.Unsqueeze(present_key_BHkvStD, 2) # [B, Hkv, 1, S+P, D] # Calculate shapes dynamically B = op.Shape(query_BHSD, start=0, end=1) # [B] T = op.Shape(present_key_BHkvStD, start=2, end=3) # [S+P] - + # Create expand shape [B, Hkv, G, S+P, D] expand_shape = op.Concat(B, Hkv, G, T, D, axis=0) present_key_BHkvGStD = op.Expand(present_key_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] - - # Create reshape shape [B, H, S+P, D] + + # Create reshape shape [B, H, S+P, D] reshape_shape = op.Concat(B, H, T, D, axis=0) present_key_BHStD = op.Reshape(present_key_BHkvGStD, reshape_shape) # [B, H, S+P, D] - + # Same for value - present_value_BHkvStD = op.Concat(past_value_BHkvPD, value_BHkvSD, axis=-2) # [B, Hkv, S+P, D] + present_value_BHkvStD = op.Concat( + past_value_BHkvPD, value_BHkvSD, axis=-2 + ) # [B, Hkv, S+P, D] present_value_BHkv1StD = op.Unsqueeze(present_value_BHkvStD, 2) # [B, Hkv, 1, S+P, D] - present_value_BHkvGStD = op.Expand(present_value_BHkv1StD, expand_shape) # [B, Hkv, G, S+P, D] + present_value_BHkvGStD = op.Expand( + present_value_BHkv1StD, expand_shape + ) # [B, Hkv, G, S+P, D] present_value_BHStD = op.Reshape(present_value_BHkvGStD, reshape_shape) # [B, H, S+P, D] - + # Attention computation attention_BHSDh = op.Attention( query_BHSD, present_key_BHStD, present_value_BHStD, ) - + return attention_BHSDh From 9280dd60618ee69ae2be10d6002bb6fcb824753f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 22 Sep 2025 10:33:06 -0700 Subject: [PATCH 4/6] Remove empty file Signed-off-by: Ganesan Ramalingam --- 
onnxscript/rewriter/ort_fusions/_gqa_test.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 onnxscript/rewriter/ort_fusions/_gqa_test.py diff --git a/onnxscript/rewriter/ort_fusions/_gqa_test.py b/onnxscript/rewriter/ort_fusions/_gqa_test.py deleted file mode 100644 index e69de29bb2..0000000000 From ef13d354c0c1d25e43327f5e7e2c9d2e852f7665 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 22 Sep 2025 16:58:02 -0700 Subject: [PATCH 5/6] Fix version check Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/rules/fusion/_gqa_test.py | 6 +++++- onnxscript/rewriter/testing.py | 7 ++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/rules/fusion/_gqa_test.py b/onnxscript/rewriter/rules/fusion/_gqa_test.py index c8386bee57..baf80c4b8c 100644 --- a/onnxscript/rewriter/rules/fusion/_gqa_test.py +++ b/onnxscript/rewriter/rules/fusion/_gqa_test.py @@ -81,7 +81,11 @@ def test_basic_gqa_fusion(self): self.assertGreater(count, 0, "GQA fusion should have occurred") # We can't yet test numerical equivalence because of a bug in the op spec/implementation. - if version.parse(onnx.__version__) >= version.parse("1.19.1"): + onnx_ver = version.parse(onnx.__version__) + if onnx_ver >= version.parse("1.19.1") and not ( + onnx_ver.is_prerelease or onnx_ver.is_devrelease + ): + # Only official releases >= 1.19.1 onnxscript.optimizer.remove_unused_nodes(model) rewritten_model_proto = ir.serde.serialize_model(model) onnxscript.rewriter.testing.assert_numerically_equal( diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py index b4e8d5060a..911a3bf865 100644 --- a/onnxscript/rewriter/testing.py +++ b/onnxscript/rewriter/testing.py @@ -102,7 +102,12 @@ def assert_numerically_equal( for i, (orig, rewritten) in enumerate(zip(original_outputs, the_rewritten_outputs)): print(f"==== Output {i} ====") diff = np.abs(orig - rewritten) - print(diff) + for h in range(diff.shape[1]): + subarray = diff[:, h, :, :] # Select along H + if np.allclose(subarray, 0): + print(f"H={h}: all zeros") + else: + print(f"H={h}: not all zeros") np.testing.assert_allclose( original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True From fab3cb888b987bf402fd2a6bdb51673d7e299b65 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 23 Sep 2025 08:35:11 -0700 Subject: [PATCH 6/6] Remove debugging code Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/testing.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/onnxscript/rewriter/testing.py b/onnxscript/rewriter/testing.py index 911a3bf865..2a9d24ee01 100644 --- a/onnxscript/rewriter/testing.py +++ b/onnxscript/rewriter/testing.py @@ -99,16 +99,6 @@ def assert_numerically_equal( None, the_rewritten_proto_ort_inputs, run_options=run_options ) - for i, (orig, rewritten) in enumerate(zip(original_outputs, the_rewritten_outputs)): - print(f"==== Output {i} ====") - diff = np.abs(orig - rewritten) - for h in range(diff.shape[1]): - subarray = diff[:, h, :, :] # Select along H - if np.allclose(subarray, 0): - print(f"H={h}: all zeros") - else: - print(f"H={h}: not all zeros") - np.testing.assert_allclose( original_outputs, the_rewritten_outputs, rtol=rtol, atol=atol, equal_nan=True )
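
Note on usage: the intended call sequence for the new fusion is the one exercised by the unit test added in patch 2 — deserialize the model into the IR, run the optimizer, apply the GQA rule set, drop the now-unused nodes, and serialize back. A minimal sketch follows; it assumes the module path introduced in patch 1 (later patches refer to the test under rules/fusion instead) and an opset-23 model containing the GQA subgraph:

    import onnx_ir as ir
    import onnxscript.optimizer
    from onnxscript.rewriter.onnx_fusions._gqa import fuse_gqa  # rules.fusion._gqa in later revisions

    model = ir.serde.deserialize_model(model_proto)   # model_proto: opset-23 ONNX model with the GQA subgraph
    onnxscript.optimizer.optimize(model)              # same preprocessing as in the unit test
    count = fuse_gqa(model)                           # number of GQA subgraphs rewritten into a single Attention op
    onnxscript.optimizer.remove_unused_nodes(model)
    rewritten_model_proto = ir.serde.serialize_model(model)

The same rewrite also runs automatically through the opset-23 fusion entry point (_onnx_fusions.py), which records the match count under the "GQA" key.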