
Conversation

@justinchuby (Collaborator) commented Oct 9, 2025

This PR implements #2580 by combining all overloads in torchlib and removing the ability to register new ones. It is done in a backward-compatible fashion and should work with released versions of PyTorch.

From now on, all logic for a single aten OpOverload should be implemented by a single torchlib function, to ensure a 1-to-1 mapping.

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Oct 9, 2025
@justinchuby justinchuby added this to the 0.5.4 milestone Oct 9, 2025
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR consolidates overloaded functions in the torch_lib by removing the concept of private functions and preventing new overloads from being created for the same operation name. The changes focus on simplifying the registration system and merging boolean indexing operations with their regular counterparts.

  • Removes the private parameter and functionality from the registration system
  • Consolidates boolean and regular index operations into unified functions
  • Adds validation to prevent duplicate overload registrations
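The duplicate-registration check described above can be sketched as a registry whose decorator raises when a name is registered twice. This is a hypothetical, minimal sketch; the class and method names do not reflect the actual onnxscript registration API.

```python
class Registry:
    """Hypothetical registry enforcing one torchlib function per aten overload."""

    def __init__(self):
        self._functions = {}

    def register(self, name):
        def decorator(func):
            # Reject a second registration for the same overload name,
            # enforcing the 1-to-1 mapping the PR description calls for.
            if name in self._functions:
                raise ValueError(
                    f"'{name}' is already registered; each aten OpOverload "
                    "must map to exactly one torchlib function"
                )
            self._functions[name] = func
            return func
        return decorator


registry = Registry()

@registry.register("aten::index.Tensor")
def aten_index(self, indices):
    ...  # a single implementation handles both bool and int indices
```

A second `@registry.register("aten::index.Tensor")` would raise `ValueError` at import time, surfacing accidental overload duplication immediately rather than at dispatch.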

Reviewed Changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.

Summary per file:

| File | Description |
| --- | --- |
| onnxscript/function_libs/torch_lib/registration.py | Removes private function support and adds overload duplication prevention |
| onnxscript/function_libs/torch_lib/ops/core.py | Consolidates index and index_put functions to handle both boolean and integer indexing |
| tests/function_libs/torch_lib/ops_test_data.py | Removes separate boolean index test entries and duplicates |
| onnxscript/function_libs/torch_lib/ops/nn.py | Removes outdated comment about private functions |
| onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py | Updates to exclude removed private functions from iteration |
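The boolean-indexing lowering that the failing graphs below use (NonZero → Transpose → GatherND) can be mimicked in NumPy to show it is equivalent to PyTorch-style boolean indexing. This is an illustrative sketch of the ONNX op sequence, not the actual torchlib implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.integers(0, 10, size=(5, 5, 5, 5))   # data tensor
mask = rng.random((5, 5, 5)) > 0.5           # boolean mask over leading dims

# ONNX-style lowering of aten::index with a boolean index:
idx = np.stack(np.nonzero(mask))   # NonZero: coordinates, shape (rank, N)
idx = idx.T                        # Transpose perm=(1, 0): shape (N, rank)
gathered = x[tuple(idx.T)]         # GatherND with batch_dims=0: shape (N, 5)

# Reference: direct boolean indexing
expected = x[mask]
assert np.array_equal(gathered, expected)
assert gathered.shape == (int(mask.sum()), 5)
```

Note that the output's leading dimension equals `mask.sum()`, i.e. it depends on the mask's contents, not just its shape.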

Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby changed the title Consolidate all overloads and prevent new ones from being created [torchlib] Consolidate all overloads and prevent new ones from being created Oct 9, 2025
@titaiwangms (Contributor) left a comment

CI is failing

codecov bot commented Oct 10, 2025

❌ 5 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 11120 | 5 | 11115 | 903 |
Failed tests, by shortest run time:
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_index_Tensor_bool_cpu_int64
Stack Traces | 1.36s run time
.../function_libs/torch_lib/ops_test_common.py:584: in _capture_graph_and_evaluate_torch_script_evaluator
    onnx.checker.check_model(model_proto, full_check=True)
..../test/lib/python3.11.../site-packages/onnx/checker.py:179: in check_model
    C.check_model(
E   onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Transpose, node name: node_Transpose_4): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (5) vs (103)

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:217: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:586: in _capture_graph_and_evaluate_torch_script_evaluator
    raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
E   AssertionError: ONNX model is invalid. Model:
E   <
E       ir_version=10,
E       opset_imports={'': 18, 'pkg.torch.onnx': 1, 'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1},
E       producer_name='torch_test',
E       producer_version=None,
E       domain=None,
E       model_version=None,
E   >
E   graph(
E       name=main_graph,
E       inputs=(
E           %"input_0"<INT64,[5,5,5,5]>,
E           %"input_1_0"<BOOL,[5,5,5]>
E       ),
E       outputs=(
E           %"val_4"<INT64,[103,5]>
E       ),
E   ) {
E       0 |  # node_NonZero_0
E            %"val_0"<?,?> ⬅️ ::NonZero(%"input_1_0")
E       1 |  # node_Transpose_1
E            %"val_1"<?,?> ⬅️ ::Transpose(%"val_0") {perm=(1, 0)}
E       2 |  # node_GatherND_2
E            %"val_2"<?,?> ⬅️ ::GatherND(%"input_0", %"val_1") {batch_dims=0}
E       3 |  # node_Transpose_3
E            %"val_3"<?,?> ⬅️ ::Transpose(%"input_0") {perm=(1, 0)}
E       4 |  # node_Transpose_4
E            %"val_4"<INT64,[103,5]> ⬅️ ::Transpose(%"val_2") {perm=(1, 0)}
E       return %"val_4"<INT64,[103,5]>
E   }
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_index_Tensor_bool_cpu_bool
Stack Traces | 1.43s run time
.../function_libs/torch_lib/ops_test_common.py:584: in _capture_graph_and_evaluate_torch_script_evaluator
    onnx.checker.check_model(model_proto, full_check=True)
..../test/lib/python3.11.../site-packages/onnx/checker.py:179: in check_model
    C.check_model(
E   onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Transpose, node name: node_Transpose_4): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (5) vs (103)

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:217: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:586: in _capture_graph_and_evaluate_torch_script_evaluator
    raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
E   AssertionError: ONNX model is invalid. Model:
E   <
E       ir_version=10,
E       opset_imports={'': 18, 'pkg.torch.onnx': 1, 'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1},
E       producer_name='torch_test',
E       producer_version=None,
E       domain=None,
E       model_version=None,
E   >
E   graph(
E       name=main_graph,
E       inputs=(
E           %"input_0"<BOOL,[5,5,5,5]>,
E           %"input_1_0"<BOOL,[5,5,5]>
E       ),
E       outputs=(
E           %"val_4"<BOOL,[103,5]>
E       ),
E   ) {
E       0 |  # node_NonZero_0
E            %"val_0"<?,?> ⬅️ ::NonZero(%"input_1_0")
E       1 |  # node_Transpose_1
E            %"val_1"<?,?> ⬅️ ::Transpose(%"val_0") {perm=(1, 0)}
E       2 |  # node_GatherND_2
E            %"val_2"<?,?> ⬅️ ::GatherND(%"input_0", %"val_1") {batch_dims=0}
E       3 |  # node_Transpose_3
E            %"val_3"<?,?> ⬅️ ::Transpose(%"input_0") {perm=(1, 0)}
E       4 |  # node_Transpose_4
E            %"val_4"<BOOL,[103,5]> ⬅️ ::Transpose(%"val_2") {perm=(1, 0)}
E       return %"val_4"<BOOL,[103,5]>
E   }
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_index_Tensor_bool_cpu_float16
Stack Traces | 1.45s run time
.../function_libs/torch_lib/ops_test_common.py:584: in _capture_graph_and_evaluate_torch_script_evaluator
    onnx.checker.check_model(model_proto, full_check=True)
..../test/lib/python3.11.../site-packages/onnx/checker.py:179: in check_model
    C.check_model(
E   onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Transpose, node name: node_Transpose_4): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (5) vs (103)

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:217: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:586: in _capture_graph_and_evaluate_torch_script_evaluator
    raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
E   AssertionError: ONNX model is invalid. Model:
E   <
E       ir_version=10,
E       opset_imports={'': 18, 'pkg.torch.onnx': 1, 'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1},
E       producer_name='torch_test',
E       producer_version=None,
E       domain=None,
E       model_version=None,
E   >
E   graph(
E       name=main_graph,
E       inputs=(
E           %"input_0"<FLOAT16,[5,5,5,5]>,
E           %"input_1_0"<BOOL,[5,5,5]>
E       ),
E       outputs=(
E           %"val_4"<FLOAT16,[103,5]>
E       ),
E   ) {
E       0 |  # node_NonZero_0
E            %"val_0"<?,?> ⬅️ ::NonZero(%"input_1_0")
E       1 |  # node_Transpose_1
E            %"val_1"<?,?> ⬅️ ::Transpose(%"val_0") {perm=(1, 0)}
E       2 |  # node_GatherND_2
E            %"val_2"<?,?> ⬅️ ::GatherND(%"input_0", %"val_1") {batch_dims=0}
E       3 |  # node_Transpose_3
E            %"val_3"<?,?> ⬅️ ::Transpose(%"input_0") {perm=(1, 0)}
E       4 |  # node_Transpose_4
E            %"val_4"<FLOAT16,[103,5]> ⬅️ ::Transpose(%"val_2") {perm=(1, 0)}
E       return %"val_4"<FLOAT16,[103,5]>
E   }
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_index_Tensor_bool_cpu_float32
Stack Traces | 1.58s run time
.../function_libs/torch_lib/ops_test_common.py:584: in _capture_graph_and_evaluate_torch_script_evaluator
    onnx.checker.check_model(model_proto, full_check=True)
..../test/lib/python3.11.../site-packages/onnx/checker.py:179: in check_model
    C.check_model(
E   onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Transpose, node name: node_Transpose_4): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (5) vs (103)

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:217: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:586: in _capture_graph_and_evaluate_torch_script_evaluator
    raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
E   AssertionError: ONNX model is invalid. Model:
E   <
E       ir_version=10,
E       opset_imports={'': 18, 'pkg.torch.onnx': 1, 'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1},
E       producer_name='torch_test',
E       producer_version=None,
E       domain=None,
E       model_version=None,
E   >
E   graph(
E       name=main_graph,
E       inputs=(
E           %"input_0"<FLOAT,[5,5,5,5]>,
E           %"input_1_0"<BOOL,[5,5,5]>
E       ),
E       outputs=(
E           %"val_4"<FLOAT,[103,5]>
E       ),
E   ) {
E       0 |  # node_NonZero_0
E            %"val_0"<?,?> ⬅️ ::NonZero(%"input_1_0")
E       1 |  # node_Transpose_1
E            %"val_1"<?,?> ⬅️ ::Transpose(%"val_0") {perm=(1, 0)}
E       2 |  # node_GatherND_2
E            %"val_2"<?,?> ⬅️ ::GatherND(%"input_0", %"val_1") {batch_dims=0}
E       3 |  # node_Transpose_3
E            %"val_3"<?,?> ⬅️ ::Transpose(%"input_0") {perm=(1, 0)}
E       4 |  # node_Transpose_4
E            %"val_4"<FLOAT,[103,5]> ⬅️ ::Transpose(%"val_2") {perm=(1, 0)}
E       return %"val_4"<FLOAT,[103,5]>
E   }
tests.function_libs.torch_lib.ops_test.TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__ops_aten_index_Tensor_bool_cpu_int32
Stack Traces | 1.69s run time
.../function_libs/torch_lib/ops_test_common.py:584: in _capture_graph_and_evaluate_torch_script_evaluator
    onnx.checker.check_model(model_proto, full_check=True)
..../test/lib/python3.11.../site-packages/onnx/checker.py:179: in check_model
    C.check_model(
E   onnx.onnx_cpp2py_export.shape_inference.InferenceError: [ShapeInferenceError] Inference error(s): (op_type:Transpose, node name: node_Transpose_4): [ShapeInferenceError] Inferred shape and existing shape differ in dimension 0: (5) vs (103)

The above exception was the direct cause of the following exception:
.../function_libs/torch_lib/ops_test.py:217: in run_test_output_match
    function_output = function_executor(test_name, reference_torch_outputs)(
.../function_libs/torch_lib/ops_test_common.py:586: in _capture_graph_and_evaluate_torch_script_evaluator
    raise AssertionError(f"ONNX model is invalid. Model:\n{onnx_model}") from e
E   AssertionError: ONNX model is invalid. Model:
E   <
E       ir_version=10,
E       opset_imports={'': 18, 'pkg.torch.onnx': 1, 'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1},
E       producer_name='torch_test',
E       producer_version=None,
E       domain=None,
E       model_version=None,
E   >
E   graph(
E       name=main_graph,
E       inputs=(
E           %"input_0"<INT32,[5,5,5,5]>,
E           %"input_1_0"<BOOL,[5,5,5]>
E       ),
E       outputs=(
E           %"val_4"<INT32,[103,5]>
E       ),
E   ) {
E       0 |  # node_NonZero_0
E            %"val_0"<?,?> ⬅️ ::NonZero(%"input_1_0")
E       1 |  # node_Transpose_1
E            %"val_1"<?,?> ⬅️ ::Transpose(%"val_0") {perm=(1, 0)}
E       2 |  # node_GatherND_2
E            %"val_2"<?,?> ⬅️ ::GatherND(%"input_0", %"val_1") {batch_dims=0}
E       3 |  # node_Transpose_3
E            %"val_3"<?,?> ⬅️ ::Transpose(%"input_0") {perm=(1, 0)}
E       4 |  # node_Transpose_4
E            %"val_4"<INT32,[103,5]> ⬅️ ::Transpose(%"val_2") {perm=(1, 0)}
E       return %"val_4"<INT32,[103,5]>
E   }
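A rough reading of the repeated `ShapeInferenceError` above, offered as an assumption rather than a confirmed diagnosis: each graph pins the final output to a static shape `(103, 5)`, but the trailing `Transpose` with `perm=(1, 0)` applied to the `(N, 5)` GatherND result yields `(5, N)`, so shape inference derives 5 for dimension 0 and conflicts with the recorded 103. The row count is also data-dependent to begin with.

```python
import numpy as np

# node_Transpose_4: transposing an (N, 5) result with perm=(1, 0)
# gives (5, N), so inferred dim 0 is 5, not the recorded 103.
val_2 = np.zeros((103, 5))            # GatherND result, N = 103 here
val_4 = np.transpose(val_2, (1, 0))
assert val_4.shape == (5, 103)

# The leading dimension is data-dependent: it equals mask.sum(),
# so baking 103 into the graph only holds for one particular mask.
x = np.arange(5 ** 4).reshape(5, 5, 5, 5)
mask = np.ones((5, 5, 5), dtype=bool)
assert x[mask].shape == (int(mask.sum()), 5)
```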


Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
@justinchuby justinchuby added the do not merge Do not merge this PR label Oct 10, 2025
@justinchuby justinchuby modified the milestones: 0.5.4, 0.5.5 Oct 14, 2025