Skip to content

Conversation

gramalingam
Copy link
Collaborator

Fix a seeming bug in handling of cross-attention in MHA (to be verified): In MHA fusion, we start with an input graph where attention is applied to 4D query/key/value, and it is transformed into a MHA op on 3D query/key/value.

In the case of cross-attention (with no rotary-embedding): the fusion seems to convert just query to 3D, and seems to leave key and value as 4D, which seems wrong.

This PR adds the necessary 4D=>3D conversion for key/value before MHA.

Note: This is a quick fix for the relevant case (that shows up). Other combinations may be worth checking out separately.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Copy link

codecov bot commented May 23, 2025

❌ 3 Tests Failed:

Tests completed Failed Passed Skipped
15414 3 15411 2360
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0129_test_bitwise_xor_i32_2d
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.12.10\x64\Lib\importlib\__init__.py:90: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_bitwise_xor_i32_2d'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_bitwise_xor_i32_2d' (e=No module named 'tests.onnx_backend_test_code.test_bitwise_xor_i32_2d') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_bitwise_xor_i32_2d.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_bitwise_xor_i32_2d.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import INT32
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_bitwise_xor_i32_2d(x: INT32[3,4], y: INT32[3,4]) -> (INT32[3,4]):
E       bitwisexor = opset18.BitwiseXor(x, y)
E       return bitwisexor
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0534_test_layer_normalization_4d_axis3
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.12.10\x64\Lib\importlib\__init__.py:90: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_layer_normalization_4d_axis3'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_layer_normalization_4d_axis3' (e=No module named 'tests.onnx_backend_test_code.test_layer_normalization_4d_axis3') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_layer_normalization_4d_axis3.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_layer_normalization_4d_axis3.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset17
E   
E   @script()
E   def bck_test_layer_normalization_4d_axis3(X: FLOAT[2,3,4,5], W: FLOAT[5], B: FLOAT[5]) -> (FLOAT[2,3,4,5], FLOAT[2,3,4,1], FLOAT[2,3,4,1]):
E       Y, Mean, InvStdDev = opset17.LayerNormalization(X, W, B, axis=3)
E       return Y, Mean, InvStdDev
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0938_test_reshape_allowzero_reordered
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.12.10\x64\Lib\importlib\__init__.py:90: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_reshape_allowzero_reordered'

The above exception was the direct cause of the following exception:
.nox\test\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_reshape_allowzero_reordered' (e=No module named 'tests.onnx_backend_test_code.test_reshape_allowzero_reordered') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reshape_allowzero_reordered.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_reshape_allowzero_reordered.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_reshape_allowzero_reordered(data: FLOAT[0,3,4], shape: INT64[3]) -> (FLOAT[3,4,0]):
E       reshaped = opset21.Reshape(data, shape, allowzero=1)
E       return reshaped

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

gramalingam and others added 2 commits May 27, 2025 12:07
Co-authored-by: Justin Chu <justinchuby@users.noreply.github.com>
@gramalingam gramalingam enabled auto-merge (squash) May 28, 2025 00:43
@gramalingam gramalingam merged commit 5a8b9e6 into main May 28, 2025
22 of 29 checks passed
@gramalingam gramalingam deleted the rama/mha_cross branch May 28, 2025 01:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants