Skip to content

Conversation

gramalingam
Copy link
Collaborator

The generation of the causal mask's shape (produced by the translation of scalar_dot_product_attention) interferes with the subsequent fusion optimizations (because it makes use of the shape of the intermediate matmul value).

This PR introduces a very specific fusion/rewrite to eliminate this redundant computation of the "sequence length" dimension.

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Copy link

codecov bot commented May 22, 2025

❌ 8 Tests Failed:

Tests completed Failed Passed Skipped
15997 8 15989 1883
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0271_test_concat_3d_axis_2
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_concat_3d_axis_2'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_concat_3d_axis_2' (e=No module named 'tests.onnx_backend_test_code.test_concat_3d_axis_2') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_concat_3d_axis_2.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_concat_3d_axis_2.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_concat_3d_axis_2(value0: FLOAT[2,2,2], value1: FLOAT[2,2,2]) -> (FLOAT[2,2,4]):
E       output = opset13.Concat(value0, value1, axis=2)
E       return output
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1274_test_unsqueeze_two_axes
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_unsqueeze_two_axes'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_unsqueeze_two_axes' (e=No module named 'tests.onnx_backend_test_code.test_unsqueeze_two_axes') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_unsqueeze_two_axes.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_unsqueeze_two_axes.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset21
E   
E   @script()
E   def bck_test_unsqueeze_two_axes(x: FLOAT[3,4,5], axes: INT64[2]) -> (FLOAT[3,1,4,5,1]):
E       y = opset21.Unsqueeze(x, axes)
E       return y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1139_test_softmax_axis_2_expanded_ver18
Stack Traces | 0.005s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
C:\hostedtoolcache\windows\Python\3.10.11\x64\lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18'

The above exception was the direct cause of the following exception:
.nox\test\lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18' (e=No module named 'tests.onnx_backend_test_code.test_softmax_axis_2_expanded_ver18') (file: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_softmax_axis_2_expanded_ver18.py', absolute path: 'D:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_softmax_axis_2_expanded_ver18.py', current folder: D:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_softmax_axis_2_expanded_ver18(x: FLOAT[3,4,5]) -> (FLOAT[3,4,5]):
E       Softmax_test_softmax_axis_2_expanded_function_axes = opset18.Constant(value=make_tensor("value", 7, dims=[1], vals=[2]))
E       Softmax_test_softmax_axis_2_expanded_function_X_ReduceMax = opset18.ReduceMax(x, Softmax_test_softmax_axis_2_expanded_function_axes, keepdims=1)
E       Softmax_test_softmax_axis_2_expanded_function_X_Sub = opset18.Sub(x, Softmax_test_softmax_axis_2_expanded_function_X_ReduceMax)
E       Softmax_test_softmax_axis_2_expanded_function_X_Exp = opset18.Exp(Softmax_test_softmax_axis_2_expanded_function_X_Sub)
E       Softmax_test_softmax_axis_2_expanded_function_X_ReduceSum = opset18.ReduceSum(Softmax_test_softmax_axis_2_expanded_function_X_Exp, Softmax_test_softmax_axis_2_expanded_function_axes, keepdims=1)
E       y = opset18.Div(Softmax_test_softmax_axis_2_expanded_function_X_Exp, Softmax_test_softmax_axis_2_expanded_function_X_ReduceSum)
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

justinchuby pushed a commit that referenced this pull request May 22, 2025
The MHA-Bias rules can be simplified using pattern-disjunction.

(This _may_ help with Whisper ... that was my original motivation, but
not sure, after I fixed another issue in PR #2325, which may be the
primary issue ). But the cleanup is useful anyway, and it makes fusion
more efficient.)

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Copy link
Collaborator

@justinchuby justinchuby left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So constant folding doesn’t get this properly?

bmehta001 pushed a commit to bmehta001/onnxscript that referenced this pull request May 22, 2025
The MHA-Bias rules can be simplified using pattern-disjunction.

(This _may_ help with Whisper ... that was my original motivation, but
not sure, after I fixed another issue in PR microsoft#2325, which may be the
primary issue ). But the cleanup is useful anyway, and it makes fusion
more efficient.)

Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
@gramalingam
Copy link
Collaborator Author

So constant folding doesn’t get this properly?

Good question (though it is the "optimizer", though we call it constant-folding, since it goes beyond pure constant folding). I think not. It does the necessary analysis for shape-inference. May be worth checking. I thought we might need a more generic optimization pass, but perhaps not.

@gramalingam gramalingam enabled auto-merge (squash) May 22, 2025 18:45
@gramalingam gramalingam merged commit b34cd9c into main May 22, 2025
25 of 29 checks passed
@gramalingam gramalingam deleted the rama/causal-mask-shape-opt branch May 22, 2025 19:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants