Skip to content

Commit

Permalink
[FRONTEND]onnx, mxnet, pytorch mathops added (apache#5561)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 18, 2020
1 parent 6b7eaeb commit 0f9e55c
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 5 deletions.
13 changes: 10 additions & 3 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,16 +1749,18 @@ def impl(inputs, input_types):
"floor",
"ceil",
"sigmoid",
"tanh",
"negative",
"reshape_like",
"zeros_like",
"ones_like",
"where",
"gather_nd",
"tan",
"cos",
"sin"
"cosh",
"sin",
"sinh",
"tan",
"tanh",
]

_convert_map = {
Expand All @@ -1774,7 +1776,12 @@ def impl(inputs, input_types):
"broadcast_maximum" : _rename(_op.maximum),
"broadcast_minimum" : _rename(_op.minimum),
"broadcast_power" : _rename(_op.power),
"arccos" : _rename(_op.acos),
"arcsin" : _rename(_op.asin),
"arctan" : _rename(_op.atan),
"arccosh" : _rename(_op.acosh),
"arcsinh" : _rename(_op.asinh),
"arctanh" : _rename(_op.atanh),
"broadcast_equal" : _mx_compare(_op.equal, _rename),
"broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
"broadcast_greater" : _mx_compare(_op.greater, _rename),
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,17 @@ def _get_convert_map(opset):
'Greater': Greater.get_converter(opset),
'Less': Less.get_converter(opset),
'Log': Renamer('log'),
'ACos': Renamer('acos'),
'ACosh': Renamer('acosh'),
'ASin': Renamer('asin'),
'ASinh': Renamer('asinh'),
'ATan': Renamer('atan'),
'ATanh': Renamer('atanh'),
'Cos': Renamer('cos'),
'Cosh': Renamer('cosh'),
'Sin': Renamer('sin'),
'Sinh': Renamer('sinh'),
'Tan': Renamer('tan'),
'Tanh': Renamer('tanh'),
'Pow': Renamer('power'),
'PRelu': Prelu.get_converter(opset),
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1705,6 +1705,8 @@ def _get_convert_map(prelude):
"aten::sinh" : _unary("sinh"),
"aten::tan" : _unary("tan"),
"aten::tanh" : _unary("tanh"),
"aten::acos" : _unary("acos"),
"aten::asin" : _unary("asin"),
"aten::atan" : _unary("atan"),
"aten::log" : _unary("log"),
"aten::log2" : _unary("log2"),
Expand Down
21 changes: 21 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,26 @@ def test_forward_elemwise_ops():
op_res = intrp.evaluate()(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_unary_ops():
for op in ["cos", "sin", "tan",
"cosh", "sinh", "tanh",
"arccos", "arcsin", "arctan",
"arccosh", "arcsinh", "arctanh"]:
shape = (1, 3, 4, 5)
dtype = 'float32'
a_np = np.random.uniform(size=shape).astype(dtype)
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a')])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np)])
shapes = {'a': shape}
mod, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
op_res = intrp.evaluate()(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)


def test_forward_scalar_ops():
for op in [operator.add, operator.sub, operator.mul, operator.truediv,
operator.pow, operator.lt, operator.le, operator.eq,
Expand Down Expand Up @@ -1113,6 +1133,7 @@ def verify(shape, blocksize=2):
test_forward_broadcast_to()
test_forward_logical_not()
test_forward_elemwise_ops()
test_forward_unary_ops()
test_forward_scalar_ops()
test_forward_slice_like()
test_forward_slice_axis()
Expand Down
11 changes: 11 additions & 0 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1598,6 +1598,17 @@ def verify_single_ops(op, x, out_np, rtol=1e-5, atol=1e-5):
verify_single_ops("Exp", x, np.exp(x))
verify_single_ops("Log", x, np.log(x))
verify_single_ops("Log", x, np.log(x))
verify_single_ops("ACos", x, np.arccos(x))
verify_single_ops("ACosh", x, np.arccosh(x))
verify_single_ops("ASin", x, np.arcsin(x))
verify_single_ops("ASinh", x, np.arcsinh(x))
verify_single_ops("ATan", x, np.arctan(x))
verify_single_ops("ATanh", x, np.arctanh(x))
verify_single_ops("Cos", x, np.cos(x))
verify_single_ops("Cosh", x, np.cosh(x))
verify_single_ops("Sin", x, np.sin(x))
verify_single_ops("Sinh", x, np.sinh(x))
verify_single_ops("Tan", x, np.tan(x))
verify_single_ops("Tanh", x, np.tanh(x))
verify_single_ops("Sigmoid", x, 1 / (1 + np.exp(-x)))
verify_single_ops("Softsign", x, x / (1 + np.abs(x)))
Expand Down
14 changes: 12 additions & 2 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1906,7 +1906,15 @@ class Tanh1(Module):
def forward(self, *args):
return torch.tanh(args[0])

class ATanh1(Module):
class Acos1(Module):
def forward(self, *args):
return torch.acos(args[0])

class Asin1(Module):
def forward(self, *args):
return torch.asin(args[0])

class Atan1(Module):
def forward(self, *args):
return torch.atan(args[0])

Expand Down Expand Up @@ -1967,7 +1975,9 @@ def forward(self, *args):
verify_model(Sinh1().float().eval(), input_data=input_data)
verify_model(Tan1().float().eval(), input_data=input_data)
verify_model(Tanh1().float().eval(), input_data=input_data)
verify_model(ATanh1().float().eval(), input_data=input_data)
verify_model(Acos1().float().eval(), input_data=input_data)
verify_model(Asin1().float().eval(), input_data=input_data)
verify_model(Atan1().float().eval(), input_data=input_data)
verify_model(Log1().float().eval(), input_data=input_data)
verify_model(Log2_1().float().eval(), input_data=input_data)
verify_model(Log10_1().float().eval(), input_data=input_data)
Expand Down

0 comments on commit 0f9e55c

Please sign in to comment.