Skip to content

Conversation

justinchuby
Copy link
Collaborator

Stop decomposing gemm to matmul add by default because it is a more compact representation. Move the ort fusion rules so it keeps functioning for ort.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Copy link
Contributor

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR moves the gemm_to_matmul_add rewrite rule out of the default optimization pipeline and into the ORT-specific fusion rules, ensuring it’s only applied when running optimize_for_ort.

  • Removed gemm_to_matmul_add from the default rewrites in onnxscript/rewriter/__init__.py.
  • Imported and applied gemm_to_matmul_add.rule in optimize_for_ort within onnxscript/rewriter/ort_fusions/_core.py.
  • Cleaned up the optimize tutorial docs, simplified the table format, and removed the default rule list.

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

File Description
onnxscript/rewriter/ort_fusions/_core.py Imported gemm_to_matmul_add, applied its rule in optimize_for_ort.
onnxscript/rewriter/init.py Removed gemm_to_matmul_add from default rewrite rules.
docs/tutorial/optimizer/optimize.md Reformatted optimization table and removed the list of default patterns.
Comments suppressed due to low confidence (1)

docs/tutorial/optimizer/optimize.md:28

  • [nitpick] The docs no longer mention that gemm_to_matmul_add has been removed from the default optimization pipeline. Consider adding a note under the API section to explain that this rule now only runs in ORT-specific optimizations.
| Optimization | Description |

Copy link

codecov bot commented Jun 18, 2025

❌ 10 Tests Failed:

Tests completed Failed Passed Skipped
16446 10 16436 2363
View the top 3 failed test(s) by shortest run time
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_0049_test_argmax_no_keepdims_example_select_last_index
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index' (e=No module named 'tests.onnx_backend_test_code.test_argmax_no_keepdims_example_select_last_index') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmax_no_keepdims_example_select_last_index.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_argmax_no_keepdims_example_select_last_index.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset13
E   
E   @script()
E   def bck_test_argmax_no_keepdims_example_select_last_index(data: FLOAT[2,2]) -> (INT64[2]):
E       result = opset13.ArgMax(data, axis=1, keepdims=0, select_last_index=1)
E       return result
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1043_test_regex_full_match_email_domain
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_regex_full_match_email_domain'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_regex_full_match_email_domain' (e=No module named 'tests.onnx_backend_test_code.test_regex_full_match_email_domain') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_regex_full_match_email_domain.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_regex_full_match_email_domain.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import BOOL, STRING
E   from onnxscript.onnx_opset import opset20
E   
E   @script()
E   def bck_test_regex_full_match_email_domain(X: STRING[2,2]) -> (BOOL[2,2]):
E       Y = opset20.RegexFullMatch(X, pattern='(\\W|^)[\\w.\\-]{0,25}@(yahoo|gmail)\\.com(\\W|$)')
E       return Y
onnxscript.backend.onnx_export_test.TestOnnxBackEnd::test_export2python_produces_correct_onnx_script_model_1160_test_scatternd_add
Stack Traces | 0.004s run time
onnxscript\backend\onnx_export_test.py:137: in extract_functions
    mod = importlib.import_module(import_name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
C:\hostedtoolcache\windows\Python\3.11.9\x64\Lib\importlib\__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E   ModuleNotFoundError: No module named 'tests.onnx_backend_test_code.test_scatternd_add'

The above exception was the direct cause of the following exception:
.nox\test_torch_nightly\Lib\site-packages\parameterized\parameterized.py:620: in standalone_func
    return func(*(a + p.args), **p.kwargs, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:271: in test_export2python_produces_correct_onnx_script_model
    functions = extract_functions(backend_test.name, code, self.test_folder)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxscript\backend\onnx_export_test.py:139: in extract_functions
    raise AssertionError(
E   AssertionError: Unable to import 'tests.onnx_backend_test_code.test_scatternd_add' (e=No module named 'tests.onnx_backend_test_code.test_scatternd_add') (file: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_scatternd_add.py', absolute path: 'C:\\a\\onnxscript\\onnxscript\\tests\\onnx_backend_test_code\\test_scatternd_add.py', current folder: C:\a\onnxscript\onnxscript
E   ---- CONTENT --
E   import numpy
E   from onnx import TensorProto
E   from onnx.helper import make_tensor
E   from onnxscript import script, external_tensor
E   from onnxscript.values import Opset
E   from onnxscript.onnx_types import FLOAT, INT64
E   from onnxscript.onnx_opset import opset18
E   
E   @script()
E   def bck_test_scatternd_add(data: FLOAT[4,4,4], indices: INT64[2,1], updates: FLOAT[2,4,4]) -> (FLOAT[4,4,4]):
E       y = opset18.ScatterND(data, indices, updates, reduction='add')
E       return y

To view more test analytics, go to the Test Analytics Dashboard
📋 Got 3 mins? Take this short survey to help us improve Test Analytics.

@justinchuby justinchuby merged commit e71c889 into main Jun 19, 2025
25 of 30 checks passed
@justinchuby justinchuby deleted the justinchu/gemm-rule branch June 19, 2025 04:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
Development

Successfully merging this pull request may close these issues.

2 participants