Skip to content

Commit

Permalink
fix build error
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Jun 20, 2024
1 parent 0740801 commit 50cc591
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 20 deletions.
Binary file modified _static/img/onnx/custom_aten_add_function.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed _static/img/onnx/custom_aten_gelu_function.png
Binary file not shown.
Binary file modified _static/img/onnx/custom_aten_gelu_model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
30 changes: 10 additions & 20 deletions beginner_source/onnx/onnx_registry_tutorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def forward(self, input_x, input_y):
# NOTE: All attributes must be annotated with type hints.
@onnxscript.script(custom_aten)
def custom_aten_add(input_x, input_y, alpha: float = 1.0):
alpha = opset18.CastLike(alpha, input_y)
input_y = opset18.Mul(input_y, alpha)
return opset18.Add(input_x, input_y)

Expand Down Expand Up @@ -130,9 +129,9 @@ def custom_aten_add(input_x, input_y, alpha: float = 1.0):
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
# function node domain is empty because we use standard ONNX operators
assert onnx_program.model_proto.functions[0].node[3].domain == ""
assert onnx_program.model_proto.functions[0].node[2].domain == ""
# function node name is the standard ONNX operator name
assert onnx_program.model_proto.functions[0].node[3].op_type == "Add"
assert onnx_program.model_proto.functions[0].node[2].op_type == "Add"


######################################################################
Expand Down Expand Up @@ -231,33 +230,24 @@ def custom_aten_gelu(input_x, approximate: str = "none"):


######################################################################
# Let's inspect the model and verify the model uses :func:`custom_aten_gelu` instead of
# :class:`aten::gelu`. Note the graph has one graph nodes for
# ``custom_aten_gelu``, and inside ``custom_aten_gelu``, there is a function
# node for ``Gelu`` with namespace ``com.microsoft``.
# Let's inspect the model and verify the model uses op_type ``Gelu``
# from namespace ``com.microsoft``.
#
# Note that :func:`custom_aten_gelu` does not exist in the graph, because
# the funtions with less than 3 operators are inlined automatically.
#

# graph node domain is the custom domain we registered
assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
# graph node name is the function name
assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_gelu"
# function node domain is the custom domain we registered
assert onnx_program.model_proto.functions[0].node[0].domain == "com.microsoft"
# function node name is the node name used in the function
assert onnx_program.model_proto.functions[0].node[0].op_type == "Gelu"
assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"


######################################################################
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron:
# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron,
# we can see the ``Gelu`` node from module ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_model.png
# :width: 70%
# :align: center
#
# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from module
# ``com.microsoft`` used in the function:
#
# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
#
# That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
# and compare the results with PyTorch.
Expand Down

0 comments on commit 50cc591

Please sign in to comment.