diff --git a/_static/img/onnx/custom_aten_add_function.png b/_static/img/onnx/custom_aten_add_function.png
index d9f927ce707..8ef05a747a0 100644
Binary files a/_static/img/onnx/custom_aten_add_function.png and b/_static/img/onnx/custom_aten_add_function.png differ
diff --git a/_static/img/onnx/custom_aten_gelu_function.png b/_static/img/onnx/custom_aten_gelu_function.png
deleted file mode 100644
index 5cb573e7dcb..00000000000
Binary files a/_static/img/onnx/custom_aten_gelu_function.png and /dev/null differ
diff --git a/_static/img/onnx/custom_aten_gelu_model.png b/_static/img/onnx/custom_aten_gelu_model.png
index 6bc46337b48..5b326690eb7 100644
Binary files a/_static/img/onnx/custom_aten_gelu_model.png and b/_static/img/onnx/custom_aten_gelu_model.png differ
diff --git a/beginner_source/onnx/onnx_registry_tutorial.py b/beginner_source/onnx/onnx_registry_tutorial.py
index dfb54d60974..6063b8ac356 100644
--- a/beginner_source/onnx/onnx_registry_tutorial.py
+++ b/beginner_source/onnx/onnx_registry_tutorial.py
@@ -99,7 +99,6 @@ def forward(self, input_x, input_y):
 # NOTE: All attributes must be annotated with type hints.
 @onnxscript.script(custom_aten)
 def custom_aten_add(input_x, input_y, alpha: float = 1.0):
-    alpha = opset18.CastLike(alpha, input_y)
     input_y = opset18.Mul(input_y, alpha)
     return opset18.Add(input_x, input_y)
 
@@ -130,9 +129,9 @@ def custom_aten_add(input_x, input_y, alpha: float = 1.0):
 # graph node name is the function name
 assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_add"
 # function node domain is empty because we use standard ONNX operators
-assert onnx_program.model_proto.functions[0].node[3].domain == ""
+assert onnx_program.model_proto.functions[0].node[2].domain == ""
 # function node name is the standard ONNX operator name
-assert onnx_program.model_proto.functions[0].node[3].op_type == "Add"
+assert onnx_program.model_proto.functions[0].node[2].op_type == "Add"
 
 
 ######################################################################
@@ -231,33 +230,24 @@ def custom_aten_gelu(input_x, approximate: str = "none"):
 
 
 ######################################################################
-# Let's inspect the model and verify the model uses :func:`custom_aten_gelu` instead of
-# :class:`aten::gelu`. Note the graph has one graph nodes for
-# ``custom_aten_gelu``, and inside ``custom_aten_gelu``, there is a function
-# node for ``Gelu`` with namespace ``com.microsoft``.
+# Let's inspect the model and verify that it uses the op_type ``Gelu``
+# from the ``com.microsoft`` namespace.
+#
+# Note that :func:`custom_aten_gelu` does not appear in the graph, because
+# functions with fewer than 3 operators are inlined automatically.
 #
 
 # graph node domain is the custom domain we registered
 assert onnx_program.model_proto.graph.node[0].domain == "com.microsoft"
 # graph node name is the function name
-assert onnx_program.model_proto.graph.node[0].op_type == "custom_aten_gelu"
-# function node domain is the custom domain we registered
-assert onnx_program.model_proto.functions[0].node[0].domain == "com.microsoft"
-# function node name is the node name used in the function
-assert onnx_program.model_proto.functions[0].node[0].op_type == "Gelu"
+assert onnx_program.model_proto.graph.node[0].op_type == "Gelu"
 
 
 ######################################################################
-# The following diagram shows ``custom_aten_gelu_model`` ONNX graph using Netron:
+# The following diagram shows the ``custom_aten_gelu_model`` ONNX graph using Netron,
+# where we can see the ``Gelu`` node from the ``com.microsoft`` domain used directly in the graph:
 #
 # .. image:: /_static/img/onnx/custom_aten_gelu_model.png
-#    :width: 70%
-#    :align: center
-#
-# Inside the ``custom_aten_gelu`` function, we can see the ``Gelu`` node from module
-# ``com.microsoft`` used in the function:
-#
-# .. image:: /_static/img/onnx/custom_aten_gelu_function.png
 #
 # That is all we need to do. As an additional step, we can use ONNX Runtime to run the model,
 # and compare the results with PyTorch.
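
For reference, a minimal sketch of that last ONNX Runtime comparison step, assuming ``onnx_program`` was produced by ``torch.onnx.dynamo_export`` and that the Gelu model and its sample input are bound to ``aten_gelu_model`` and ``input_gelu_x`` (hypothetical names, not taken from the diff above):

    import onnxruntime
    import torch

    # Serialize the exported program and load it into an ONNX Runtime session.
    # ``com.microsoft::Gelu`` is an ONNX Runtime contrib op, so the default CPU
    # execution provider can run it.
    onnx_program.save("custom_gelu_model.onnx")
    ort_session = onnxruntime.InferenceSession(
        "custom_gelu_model.onnx", providers=["CPUExecutionProvider"]
    )

    # Run the same input through ONNX Runtime and through PyTorch, then compare.
    input_name = ort_session.get_inputs()[0].name
    ort_output = ort_session.run(None, {input_name: input_gelu_x.numpy()})[0]
    torch_output = aten_gelu_model(input_gelu_x)
    torch.testing.assert_close(torch_output, torch.from_numpy(ort_output))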