Skip to content

Commit

Permalink
[Frontend][ONNX] add onnx Mish operator (#15415)
Browse files Browse the repository at this point in the history
* added mish operator to onnx frontend

* linter reformat

* fixed lint issues as linter failed on the CI

* added test for mish operator

* added test for mish operator

* pytest skip since ONNX Runtime in CI does not support domain version 18

* linter format
  • Loading branch information
Aarsh2001 authored Jul 30, 2023
1 parent 2d76c97 commit 0556653
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,24 @@ def _impl_v1(cls, inputs, attr, params):
return Gelu._impl_v1([inp], attr, params)


class Mish(OnnxOpConverter):
"""Operator converter for Mish from Microsoft onnxruntime contrib opset.
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
"""

@classmethod
def _impl_v18(cls, inputs, attr, params):
x = inputs[0]
# Declare const
const_dtype = infer_type(x).checked_type.dtype
one = _expr.const(1.0, dtype=const_dtype)

# Compute Mish
term1 = _op.log(one + _op.exp(x))
return _op.multiply(x, _op.tanh(term1))


class LayerNormalization(OnnxOpConverter):
"""Operator converter for LayerNormalization from Microsoft onnxruntime contrib opset."""

Expand Down Expand Up @@ -6536,6 +6554,7 @@ def _get_convert_map(opset):
"Gelu": Gelu.get_converter(opset),
"FastGelu": FastGelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
"Mish": Mish.get_converter(opset),
"LayerNormalization": LayerNormalization.get_converter(opset),
# TODO: We need a better way to handle different domains, in case
# of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
Expand Down
9 changes: 9 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2489,6 +2489,15 @@ def selu_x(x, alpha, gamma):
)


@pytest.mark.skip("Currently ONNX Runtime in CI does not support domain version of 18")
@tvm.testing.parametrize_targets
def test_mish(target, dev):
def mish_x(x):
return x * np.tanh(np.log1p(np.exp(x)))

_test_onnx_op_elementwise(target, dev, (2, 4, 5, 6), mish_x, {}, "float64", "Mish", {})


@tvm.testing.parametrize_targets
def test_prelu(target, dev):
"""test_prelu"""
Expand Down

0 comments on commit 0556653

Please sign in to comment.