gramalingam
Collaborator

Add GQA fusion to ONNX fusions.

TODO:

  • Test cases. (Fusion seems to work on Gemma3, but more to be done.)

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

codecov bot commented Aug 28, 2025

Codecov Report

❌ Patch coverage is 69.23077% with 32 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.09%. Comparing base (27c7f09) to head (d839bc1).
⚠️ Report is 1 commits behind head on main.
✅ All tests successful. No failed tests found.

Files with missing lines                         Patch %   Lines
onnxscript/rewriter/rules/fusion/_gqa_test.py    54.54%    18 Missing and 2 partials ⚠️
onnxscript/rewriter/testing.py                   42.85%    10 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2524      +/-   ##
==========================================
- Coverage   70.09%   70.09%   -0.01%     
==========================================
  Files         220      222       +2     
  Lines       26086    26182      +96     
  Branches     2575     2580       +5     
==========================================
+ Hits        18285    18351      +66     
- Misses       6904     6931      +27     
- Partials      897      900       +3     


@gramalingam
Collaborator Author

Putting this on hold for now:

  • This can't be tested with released versions of onnxruntime, since they don't support opset 23 (Attention) yet.
  • Testing with onnx's reference implementation reveals a potential issue with the ONNX op's definition (the ordering convention for groups seems different from what's used in practice).
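
The group-ordering concern can be illustrated with a small NumPy sketch. It rests on an assumption about what kind of mismatch is meant and makes no claim about which convention the ONNX Attention op or any particular model follows. The function `gqa` and its `blocked` flag are hypothetical names used only for illustration. With 4 query heads and 2 key/value heads, a "blocked" mapping assigns query heads 0 and 1 to KV head 0, while a "tiled" mapping assigns query heads 0 and 2 to KV head 0.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def gqa(q, k, v, blocked):
    """Grouped-query attention sketch (hypothetical helper).

    q has shape (batch, num_q_heads, seq, head_dim);
    k and v have shape (batch, num_kv_heads, seq, head_dim).
    """
    group = q.shape[1] // k.shape[1]
    if blocked:
        # KV heads expand to [kv0, kv0, kv1, kv1, ...].
        k_full = np.repeat(k, group, axis=1)
        v_full = np.repeat(v, group, axis=1)
    else:
        # KV heads expand to [kv0, kv1, kv0, kv1, ...].
        k_full = np.tile(k, (1, group, 1, 1))
        v_full = np.tile(v, (1, group, 1, 1))
    scores = q @ k_full.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v_full


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 4, 3, 8))
k = rng.standard_normal((1, 2, 3, 8))
v = rng.standard_normal((1, 2, 3, 8))

out_blocked = gqa(q, k, v, blocked=True)
out_tiled = gqa(q, k, v, blocked=False)

# Query heads 0 and 3 map to the same KV head under both conventions;
# query heads 1 and 2 do not, so only part of the output disagrees.
print(np.allclose(out_blocked[:, 0], out_tiled[:, 0]))
print(np.allclose(out_blocked[:, 1], out_tiled[:, 1]))
```

Under these assumptions, a fused op and an unfused reference that disagree on the mapping would match on some heads and differ on others.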

@justinchuby
Collaborator

FAILED onnxscript/rewriter/rules/fusion/_gqa_test.py::GQAFusionTest::test_basic_gqa_fusion - AssertionError: 
Not equal to tolerance rtol=1, atol=0.001

Mismatched elements: 1977 / 4096 (48.3%)
Max absolute difference among violations: 2.456497
Max relative difference among violations: 956.41864
 ACTUAL: array([[[[[-3.091334e-01,  5.472850e-01, -1.052010e-01, ...,
            1.951529e-01,  7.563670e-01, -1.576316e-01],
          [-1.378765e-02, -2.707998e-01, -3.788515e-01, ...,...
 DESIRED: array([[[[[-0.309133,  0.547285, -0.105201, ...,  0.195153,  0.756367,
           -0.157632],
          [-0.013788, -0.2708  , -0.378851, ..., -0.312114,  0.220327,...
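
For readers unfamiliar with the message format above: it matches what `numpy.testing.assert_allclose` prints, and that function flags an element when `|actual - desired| > atol + rtol * |desired|`. The sketch below applies the same `rtol=1, atol=0.001` to made-up values (not taken from the failing test) to show how an element near zero can fail with a very large relative difference while larger elements pass.

```python
import numpy as np

# Illustrative values only; they are not from the failing test.
desired = np.array([1.0, 0.001, -0.5])
actual = np.array([1.5, 0.5, -0.5])

rtol, atol = 1, 0.001

# Element-wise criterion used by assert_allclose.
abs_diff = np.abs(actual - desired)
violations = abs_diff > atol + rtol * np.abs(desired)

# Relative difference as reported in the failure message.
rel_diff = abs_diff / np.abs(desired)

try:
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)
    raised = False
except AssertionError:
    raised = True

print(violations, rel_diff, raised)
```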

gramalingam enabled auto-merge (squash) September 23, 2025 15:38
gramalingam merged commit f54cf47 into main Sep 23, 2025
32 checks passed
gramalingam deleted the rama/onnxgqa2 branch September 23, 2025 17:23