Skip to content

Commit

Permalink
dialects: (onnx) Add onnx.Sigmoid (#2479)
Browse files Browse the repository at this point in the history
Implementation of Sigmoid onnx operator.

@superlopuh @compor

---------

Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
  • Loading branch information
alecerio and superlopuh authored May 30, 2024
1 parent ae12eed commit ac40191
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/filecheck/dialects/onnx/onnx_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -560,3 +560,12 @@ builtin.module {
// CHECK: Operation does not verify: axes to squeeze must be between 0 and 2, axes: 3
%res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 3 : i64} : (tensor<1x2x4xf32>) -> tensor<2x4xf32>
}

// -----

builtin.module {
%t0 = "test.op"() : () -> (tensor<3x4xf32>)

// CHECK: Operation does not verify: tensor input shape (3, 4) is not equal to tensor output shape (7, 3)
%res_sigmoid = "onnx.Sigmoid"(%t0) {onnx_node_name = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<7x3xf32>
}
3 changes: 3 additions & 0 deletions tests/filecheck/dialects/onnx/onnx_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,6 @@

%res_squeeze = "onnx.Squeeze"(%t0) {onnx_node_name = "/Squeeze", "axes" = 0}: (tensor<1x2x6xf32>) -> tensor<2x6xf32>
// CHECK: %res_squeeze = onnx.Squeeze(%t0) {"onnx_node_name" = "/Squeeze", "axes" = 0 : i64} : (tensor<1x2x6xf32>) -> tensor<2x6xf32>

%res_sigmoid = "onnx.Sigmoid"(%t8) {onnx_node_name = "/Sigmoid"}: (tensor<3x4xf32>) -> tensor<3x4xf32>
// CHECK: %res_sigmoid = onnx.Sigmoid(%t8) {"onnx_node_name" = "/Sigmoid"} : (tensor<3x4xf32>) -> tensor<3x4xf32>
52 changes: 52 additions & 0 deletions xdsl/dialects/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1067,6 +1067,57 @@ def verify_(self) -> None:
)


@irdl_op_definition
class Sigmoid(IRDLOperation):
"""
Applies the sigmoid function element-wise to all elements of the input tensor.
The sigmoid function, denoted by sigma(x), is a common mathematical function used in machine learning and neural networks. It is defined as:
sigma(x) = 1 / (1 + e^-x)
where e is the base of the natural logarithm. The sigmoid function maps any real-valued number to the range of [0, 1].
The sigmoid function is used as an activation function.
Args:
- input_tensor (TensorType): The input tensor to which the sigmoid function will be applied.
Returns:
- output_tensor (TensorType): The output tensor after applying the sigmoid function element-wise to the input tensor.
"""

name = "onnx.Sigmoid"

T = Annotated[AnyFloat, ConstraintVar("T")]
input_tensor = operand_def(TensorType[T])
output_tensor = result_def(TensorType[T])

assembly_format = "`(` $input_tensor`)` attr-dict `:` `(` type($input_tensor) `)` `->` type($output_tensor) "

def __init__(
self,
input_tensor: SSAValue,
):
super().__init__(
operands=[input_tensor],
result_types=[input_tensor.type],
)

def verify_(self) -> None:
if not isinstance(
input_tensor_type := self.input_tensor.type, TensorType
) or not isinstance(output_tensor_type := self.output_tensor.type, TensorType):
assert (
False
), "onnx elementwise operation operands and result must be of type TensorType"

input_tensor_shape = input_tensor_type.get_shape()
output_tensor_shape = output_tensor_type.get_shape()

# check if input tensor and output tensor have the same shape
if input_tensor_shape != output_tensor_shape:
raise VerifyException(
f"tensor input shape {input_tensor_shape} is not equal to tensor output shape {output_tensor_shape}"
)


ONNX = Dialect(
"onnx",
[
Expand All @@ -1085,5 +1136,6 @@ def verify_(self) -> None:
Sub,
Transpose,
Squeeze,
Sigmoid,
],
)

0 comments on commit ac40191

Please sign in to comment.