Skip to content

Conversation

gramalingam
Copy link
Collaborator

@gramalingam gramalingam commented Aug 8, 2025

  • Cleanup MHA fusion rules by eliminating some redundant rule-variations
  • Fix handling of scale attribute in MHA fusion
  • Introduce can_match_none attribute-variables in patterns
  • Introduce fusion rule for MHA and scale

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Copy link

codecov bot commented Aug 8, 2025

❌ 9 Tests Failed:

Tests completed Failed Passed Skipped
14395 9 14386 1901
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0301_test_cosh_example
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:132: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_cosh_example'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:266: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:134: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_cosh_example' (e=No module named 'tests.onnx_backend_test_code.test_cosh_example') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cosh_example.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_cosh_example.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy as np
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset22
E   
E   @script()
E   def bck_test_cosh_example(x: FLOAT[3]) -> (FLOAT[3]):
E       y = opset22.Cosh(x)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1102_test_shape_example
Stack Traces | 0.003s run time
onnxscript\backend\onnx_export_test.py:132: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_shape_example'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:266: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:134: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_shape_example' (e=No module named 'tests.onnx_backend_test_code.test_shape_example') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_example.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_shape_example.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy as np
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_shape_example(x: FLOAT[2,3]) -> (INT64[2]):
E       y = opset21.Shape(x)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0052_test_argmin_default_axis_example
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:132: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_argmin_default_axis_example'

The above exception was the direct cause of the following exception:
.nox\test_ort_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:266: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:134: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_argmin_default_axis_example' (e=No module named 'tests.onnx_backend_test_code.test_argmin_default_axis_example') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmin_default_axis_example.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmin_default_axis_example.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy as np
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_argmin_default_axis_example(data: FLOAT[2,2]) -> (INT64[1,2]):
E       result = opset13.ArgMin(data, keepdims=1)
E       return result

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam gramalingam changed the title [DRAFT] MHA fusion cleanup MHA fusion cleanup Aug 8, 2025
@gramalingam gramalingam marked this pull request as ready for review August 8, 2025 19:14
@justinchuby justinchuby requested a review from Copilot August 8, 2025 19:25
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR cleans up MHA (Multi-Head Attention) fusion rules by simplifying pattern variations and improving scale attribute handling. The main improvements include removing redundant rule variations, introducing proper scale attribute handling through the new AttrVar pattern, and adding common subexpression elimination.

Key changes:

  • Eliminates redundant MHA fusion rule variations by removing transpose_4d and pre_scale_q parameters
  • Introduces AttrVar pattern with can_match_none support for optional attributes like scale
  • Fixes scale attribute handling across MHA fusion patterns

Reviewed Changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
onnxscript/rewriter/pattern.py Exports new AttrVar pattern for use in fusion rules
onnxscript/rewriter/_pattern_ir.py Implements AttrVar pattern with can_match_none support for optional attributes
onnxscript/rewriter/ort_fusions/mha.py Removes redundant rule variations and simplifies MHA pattern generation
onnxscript/rewriter/ort_fusions/mha_bias.py Updates to use AttrVar for scale attribute handling
onnxscript/rewriter/ort_fusions/attention.py Extracts scale from matched nodes and uses named outputs
onnxscript/rewriter/ort_fusions/_core.py Adds common subexpression elimination pass

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam gramalingam enabled auto-merge (squash) August 11, 2025 02:50
@gramalingam gramalingam merged commit c219dce into main Aug 11, 2025
23 of 30 checks passed
@gramalingam gramalingam deleted the rama/mha-cleanup branch August 11, 2025 15:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants