From 621a61df5afd05637e0e781477b9154e3865253a Mon Sep 17 00:00:00 2001 From: Samuel Date: Tue, 12 May 2020 00:26:23 +0530 Subject: [PATCH] [FRONTEND]onnx, mxnet, pytorch mathops added (#5561) --- python/tvm/relay/frontend/mxnet.py | 13 +++++++++--- python/tvm/relay/frontend/onnx.py | 11 ++++++++++ python/tvm/relay/frontend/pytorch.py | 2 ++ tests/python/frontend/mxnet/test_forward.py | 21 +++++++++++++++++++ tests/python/frontend/onnx/test_forward.py | 11 ++++++++++ tests/python/frontend/pytorch/test_forward.py | 14 +++++++++++-- 6 files changed, 67 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 7dbc7881f43c..4cb7a2a75bad 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1749,16 +1749,18 @@ def impl(inputs, input_types): "floor", "ceil", "sigmoid", - "tanh", "negative", "reshape_like", "zeros_like", "ones_like", "where", "gather_nd", - "tan", "cos", - "sin" + "cosh", + "sin", + "sinh", + "tan", + "tanh", ] _convert_map = { @@ -1774,7 +1776,12 @@ def impl(inputs, input_types): "broadcast_maximum" : _rename(_op.maximum), "broadcast_minimum" : _rename(_op.minimum), "broadcast_power" : _rename(_op.power), + "arccos" : _rename(_op.acos), + "arcsin" : _rename(_op.asin), "arctan" : _rename(_op.atan), + "arccosh" : _rename(_op.acosh), + "arcsinh" : _rename(_op.asinh), + "arctanh" : _rename(_op.atanh), "broadcast_equal" : _mx_compare(_op.equal, _rename), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename), "broadcast_greater" : _mx_compare(_op.greater, _rename), diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 1a4aee0a0d6c..58ec4ee56a93 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1627,6 +1627,17 @@ def _get_convert_map(opset): 'Greater': Greater.get_converter(opset), 'Less': Less.get_converter(opset), 'Log': Renamer('log'), + 'ACos': Renamer('acos'), + 'ACosh': Renamer('acosh'), + 'ASin': Renamer('asin'), + 'ASinh': Renamer('asinh'), + 'ATan': Renamer('atan'), + 'ATanh': Renamer('atanh'), + 'Cos': Renamer('cos'), + 'Cosh': Renamer('cosh'), + 'Sin': Renamer('sin'), + 'Sinh': Renamer('sinh'), + 'Tan': Renamer('tan'), 'Tanh': Renamer('tanh'), 'Pow': Renamer('power'), 'PRelu': Prelu.get_converter(opset), diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 64f30f35b376..3af1051344ef 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1699,6 +1699,8 @@ def _get_convert_map(prelude): "aten::sinh" : _unary("sinh"), "aten::tan" : _unary("tan"), "aten::tanh" : _unary("tanh"), + "aten::acos" : _unary("acos"), + "aten::asin" : _unary("asin"), "aten::atan" : _unary("atan"), "aten::log" : _unary("log"), "aten::log2" : _unary("log2"), diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 84c8acf20f3e..3fb8e30acb88 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -363,6 +363,26 @@ def test_forward_elemwise_ops(): op_res = intrp.evaluate()(a_np, b_np) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy()) + +def test_forward_unary_ops(): + for op in ["cos", "sin", "tan", + "cosh", "sinh", "tanh", + "arccos", "arcsin", "arctan", + "arccosh", "arcsinh", "arctanh"]: + shape = (1, 3, 4, 5) + dtype = 'float32' + a_np = np.random.uniform(size=shape).astype(dtype) + mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a')]) + ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)]) + shapes = {'a': shape} + mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(a_np) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5) + + def test_forward_scalar_ops(): for op in [operator.add, operator.sub, operator.mul, operator.truediv, operator.pow, operator.lt, operator.le, operator.eq, @@ -1113,6 +1133,7 @@ def verify(shape, blocksize=2): test_forward_broadcast_to() test_forward_logical_not() test_forward_elemwise_ops() + test_forward_unary_ops() test_forward_scalar_ops() test_forward_slice_like() test_forward_slice_axis() diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index a26c6137b32b..614041401026 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1598,6 +1598,17 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5): verify_single_ops("Exp", x, np.exp(x)) verify_single_ops("Log", x, np.log(x)) verify_single_ops("Log", x, np.log(x)) + verify_single_ops("ACos", x, np.arccos(x)) + verify_single_ops("ACosh", x, np.arccosh(x)) + verify_single_ops("ASin", x, np.arcsin(x)) + verify_single_ops("ASinh", x, np.arcsinh(x)) + verify_single_ops("ATan", x, np.arctan(x)) + verify_single_ops("ATanh", x, np.arctanh(x)) + verify_single_ops("Cos", x, np.cos(x)) + verify_single_ops("Cosh", x, np.cosh(x)) + verify_single_ops("Sin", x, np.sin(x)) + verify_single_ops("Sinh", x, np.sinh(x)) + verify_single_ops("Tan", x, np.tan(x)) verify_single_ops("Tanh", x, np.tanh(x)) verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x))) verify_single_ops("Softsign", x, x / (1 + np.abs(x))) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index a53f3540ef29..e1c276bb95f6 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1895,7 +1895,15 @@ class Tanh1(Module): def forward(self, *args): return torch.tanh(args[0]) - class ATanh1(Module): + class Acos1(Module): + def forward(self, *args): + return torch.acos(args[0]) + + class Asin1(Module): + def forward(self, *args): + return torch.asin(args[0]) + + class Atan1(Module): def forward(self, *args): return torch.atan(args[0]) @@ -1956,7 +1964,9 @@ def forward(self, *args): verify_model(Sinh1().float().eval(), input_data=input_data) verify_model(Tan1().float().eval(), input_data=input_data) verify_model(Tanh1().float().eval(), input_data=input_data) - verify_model(ATanh1().float().eval(), input_data=input_data) + verify_model(Acos1().float().eval(), input_data=input_data) + verify_model(Asin1().float().eval(), input_data=input_data) + verify_model(Atan1().float().eval(), input_data=input_data) verify_model(Log1().float().eval(), input_data=input_data) verify_model(Log2_1().float().eval(), input_data=input_data) verify_model(Log10_1().float().eval(), input_data=input_data)