diff --git a/tests/filecheck/dialects/onnx/onnx_ops.mlir b/tests/filecheck/dialects/onnx/onnx_ops.mlir index b3419e10e9..b668984c5a 100644 --- a/tests/filecheck/dialects/onnx/onnx_ops.mlir +++ b/tests/filecheck/dialects/onnx/onnx_ops.mlir @@ -67,3 +67,7 @@ %res_max_pool_single_out = "onnx.MaxPoolSingleOut"(%t26) {onnx_node_name = "/MaxPoolSingleOut", "auto_pad" = "VALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]}: (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> //CHECK: %res_max_pool_single_out = onnx.MaxPoolSingleOut(%t26) {"onnx_node_name" = "/MaxPoolSingleOut", "auto_pad" = "VALID", "ceil_mode" = 0 : i64, "kernel_shape" = [3 : i64, 3 : i64], "dilations" = [1 : i64, 1 : i64], "pads" = [0 : i64, 0 : i64, 0 : i64, 0 : i64], "storage_order" = 0 : i64, "strides" = [1 : i64, 1 : i64]}: (tensor<5x5x32x32xf32>) -> tensor<5x5x30x30xf32> + +"onnx.EntryPoint"() {onnx_node_name = "/EntryPoint", "func" = @main_graph} : () -> () +//CHECK: "onnx.EntryPoint"() {"onnx_node_name" = "/EntryPoint", "func" = @main_graph} : () -> () + diff --git a/xdsl/dialects/onnx.py b/xdsl/dialects/onnx.py index 85c8108bd6..30072b850a 100644 --- a/xdsl/dialects/onnx.py +++ b/xdsl/dialects/onnx.py @@ -17,6 +17,7 @@ NoneType, SSAValue, StringAttr, + SymbolRefAttr, TensorType, ) from xdsl.ir import ( @@ -825,6 +826,24 @@ def verify_(self) -> None: ) +@irdl_op_definition +class EntryPoint(IRDLOperation): + """ + Indicate ONNX entry point + The "onnx.EntryPoint" function indicates the main entry point of ONNX model. + """ + + name = "onnx.EntryPoint" + func = attr_def(SymbolRefAttr) + + def __init__(self, func: Attribute): + super().__init__( + attributes={ + "func": func, + }, + ) + + ONNX = Dialect( "onnx", [ @@ -833,6 +852,7 @@ def verify_(self) -> None: Constant, Conv, Div, + EntryPoint, Gemm, MaxPoolSingleOut, Mul,