Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: beginner/onnx/onnx_registry_tutorial.py fails against 2.4 RC binaries #2950

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .jenkins/validate_tutorials_built.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
"intermediate_source/flask_rest_api_tutorial",
"intermediate_source/text_to_speech_with_torchaudio",
"intermediate_source/tensorboard_profiler_tutorial", # reenable after 2.0 release.
"beginner_source/onnx/onnx_registry_tutorial", # reenable after 2941 is fixed.
"intermediate_source/torch_export_tutorial" # reenable after 2940 is fixed.
]

Expand Down
Binary file modified _static/img/onnx/custom_aten_add_function.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed _static/img/onnx/custom_aten_gelu_function.png
Binary file not shown.
Binary file modified _static/img/onnx/custom_aten_gelu_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
31 changes: 11 additions & 20 deletions beginner_source/onnx/onnx_registry_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def forward(self, input_x, input_y):
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
alpha = opset18.CastLike(alpha, input_y)
input_y = opset18.Mul(input_y, alpha)
return opset18.Add(input_x, input_y)

Expand Down Expand Up @@ -130,9 +129,9 @@ def custom_aten_add(input_x, input_y, alpha: float = 1.0):
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert onnx_program.model_proto.functions[0].node[3].domain == ""
assert {node.domain for node in onnx_program.model_proto.functions[0].node} == {""}
# function node name is the standard ONNX operator name
assert onnx_program.model_proto.functions[0].node[3].op_type == "Add"
assert {node.op_type for node in onnx_program.model_proto.functions[0].node} == {"Add", "Mul", "Constant"}


######################################################################
Expand Down Expand Up @@ -231,33 +230,25 @@ def custom_aten_gelu(input_x, approximate: str = "none"):


######################################################################
# Let's inspect the model and verify the model uses :func:`custom_aten_gelu` instead of
# :class:`aten::gelu`. Note the graph has one graph nodes for
# ``custom_aten_gelu``, and inside ``custom_aten_gelu``, there is a function
# node for ``Gelu`` with namespace ``com.microsoft``.
# Let's inspect the model and verify the model uses op_type ``Gelu``
# from namespace ``com.microsoft``.
#
# .. note::
# :func:`custom_aten_gelu` does not exist in the graph because
# functions with fewer than three operators are inlined automatically.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_gelu"
# function node domain is the custom domain we registered
assert onnx_program.model_proto.functions[0].node[0].domain == "com.microsoft"
# function node name is the node name used in the function
assert onnx_program.model_proto.functions[0].node[0].op_type == "Gelu"
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"


######################################################################
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron:
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron,
# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
# :width: 70%
# :align: center
#
# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from module
# ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
#
# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
# and compare the results with PyTorch.
Expand Down
Loading