From d992468d80af816f0413fc43c2ee1c02f7fe19c3 Mon Sep 17 00:00:00 2001
From: Yao Wang
Date: Thu, 5 Mar 2020 17:17:35 -0800
Subject: [PATCH] [topi][relay] add operation tan to TVM (#4938)

* Add relay operation relay.op.tan.

* Update tan implementation in TVM.

* Update tests.

* Add shape function for tan.

* Add missing main test to python/frontend/tensorflow/test_forward.

* Revert, back to sin/cos.

* Revert "Revert, back to sin/cos."

This reverts commit 4da5b503b921585ba9d80944b29136142b575c40.

* Fix implementation of tan in cuda. Do not support tan for float16.

Simplify topi/tests/python/test_topi_math. Add testing for tan with float32
and float64.

Try again to implement tan as sin/cos in llvm.
---
 docs/frontend/tensorflow.rst                  |  1 +
 include/tvm/tir/op.h                          |  1 +
 python/tvm/relay/frontend/mxnet.py            |  1 +
 python/tvm/relay/frontend/tensorflow.py       |  1 +
 python/tvm/relay/frontend/tflite.py           |  8 ++++++++
 python/tvm/relay/op/_tensor.py                |  2 ++
 python/tvm/relay/op/_tensor_grad.py           |  7 +++++++
 python/tvm/relay/op/tensor.py                 | 15 +++++++++++++++
 python/tvm/te/__init__.py                     |  2 +-
 python/tvm/tir/__init__.py                    |  2 +-
 python/tvm/tir/op.py                          | 16 ++++++++++++++++
 src/relay/op/tensor/unary.cc                  | 11 +++++++++++
 src/target/intrin_rule.cc                     |  3 +++
 src/target/llvm/intrin_rule_llvm.cc           | 14 ++++++++++++++
 src/target/llvm/intrin_rule_nvptx.cc          |  3 +++
 src/target/llvm/intrin_rule_rocm.cc           |  3 +++
 src/target/source/intrin_rule_cuda.cc         | 19 +++++++++++++++++++
 src/tir/ir/expr.cc                            |  2 +-
 .../frontend/tensorflow/test_forward.py       | 10 ++++++++++
 tests/python/frontend/tflite/test_forward.py  |  8 ++++++++
 tests/python/relay/test_op_grad_level1.py     |  1 +
 tests/python/relay/test_op_level1.py          |  1 +
 tests/python/unittest/test_testing.py         |  1 +
 topi/include/topi/elemwise.h                  |  1 +
 topi/python/topi/math.py                      | 17 +++++++++++++++++
 topi/src/topi.cc                              |  5 +++++
 topi/tests/python/test_topi_basic.py          |  1 +
 topi/tests/python/test_topi_math.py           |  2 ++
 28 files changed, 155 insertions(+), 3 deletions(-)

diff --git a/docs/frontend/tensorflow.rst b/docs/frontend/tensorflow.rst
index 87341ab6b7c6..8a54033a2b09 100644
--- a/docs/frontend/tensorflow.rst
+++ b/docs/frontend/tensorflow.rst
@@ -135,6 +135,7 @@ Supported Ops
 - ConcatV2
 - Conv2D
 - Cos
+- Tan
 - CropAndResize
 - DecodeJpeg
 - DepthwiseConv2dNative
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 5172b1496ad2..0a714d8fc8c8 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -515,6 +515,7 @@ TVM_DECLARE_INTRIN_UNARY(sqrt);
 TVM_DECLARE_INTRIN_UNARY(rsqrt);
 TVM_DECLARE_INTRIN_UNARY(log);
 TVM_DECLARE_INTRIN_UNARY(popcount);
+TVM_DECLARE_INTRIN_UNARY(tan);
 TVM_DECLARE_INTRIN_UNARY(cos);
 TVM_DECLARE_INTRIN_UNARY(sin);
 TVM_DECLARE_INTRIN_UNARY(atan);
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 0020a63be625..c2bfd751facb 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -1696,6 +1696,7 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale):
     "ones_like",
     "where",
     "gather_nd",
+    "tan",
     "cos",
     "sin"
 ]
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index 14d2418da710..24164a3d759a 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -1564,6 +1564,7 @@ def _impl(inputs, attr, params):
     'LessEqual' : _broadcast('less_equal'),
     'Log' : AttrCvt('log'),
     'Log1p' : _log1p(),
+    'Tan' : AttrCvt('tan'),
     'Cos' : AttrCvt('cos'),
     'Sin' : AttrCvt('sin'),
     'LogicalAnd' : _logical('logical_and'),
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index bc51c9138d6b..c2ec4d43bfac 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -68,6 +68,7 @@ def __init__(self, model, subgraph, exp_tab):
             'LOG': self.convert_log,
             'SIN': self.convert_sin,
             'COS': self.convert_cos,
+            'TAN': self.convert_tan,
             'SQRT': self.convert_sqrt,
             'RSQRT': self.convert_rsqrt,
             'NEG': self.convert_neg,
@@ -657,6 +658,13 @@ def convert_sin(self, op):
                 'TFlite quantized SIN operator is not supported yet.')
         return self._convert_unary_elemwise(_op.sin, op)

+    def convert_tan(self, op):
+        """Convert TFLite TAN"""
+        if self.is_quantized(op):
+            raise tvm.error.OpNotImplemented(
+                'TFlite quantized TAN operator is not supported yet.')
+        return self._convert_unary_elemwise(_op.tan, op)
+
     def convert_cos(self, op):
         """Convert TFLite COS"""
         if self.is_quantized(op):
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 9f0906b113ba..4480849890b5 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -27,6 +27,7 @@


 register_broadcast_schedule("log")
+register_broadcast_schedule("tan")
 register_broadcast_schedule("cos")
 register_broadcast_schedule("sin")
 register_broadcast_schedule("atan")
@@ -214,3 +215,4 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("sqrt", False, elemwise_shape_func)
 register_shape_func("negative", False, elemwise_shape_func)
 register_shape_func("exp", False, elemwise_shape_func)
+register_shape_func("tan", False, elemwise_shape_func)
diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 944e51e636f5..33a193799288 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -61,6 +61,13 @@ def log_grad(orig, grad):
     return [grad * ones_like(x) / x]


+@register_gradient("tan")
+def tan_grad(orig, grad):
+    """Returns [grad / (cos^2(x))]"""
+    x = orig.args[0]
+    return [grad / (cos(x) * cos(x))]
+
+
 @register_gradient("cos")
 def cos_grad(orig, grad):
     """Returns [grad * (-sin(x))]"""
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index ada1f5e85cac..77969185c0a7 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -47,6 +47,21 @@ def log(data):
     """
     return _make.log(data)

+def tan(data):
+    """Compute elementwise tan of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.tan(data)
+
 def cos(data):
     """Compute elementwise cos of data.

diff --git a/python/tvm/te/__init__.py b/python/tvm/te/__init__.py
index 065cf4e5dbdd..3a1a12462344 100644
--- a/python/tvm/te/__init__.py
+++ b/python/tvm/te/__init__.py
@@ -19,7 +19,7 @@
 """
 # expose all operators in tvm tir.op
 from tvm.tir import any, all, min_value, max_value, trace
-from tvm.tir import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
+from tvm.tir import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
 from tvm.tir import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
 from tvm.tir import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from tvm.tir import comm_reducer, min, max, sum
diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py
index a5c81ac77627..aa5871aa636a 100644
--- a/python/tvm/tir/__init__.py
+++ b/python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@

 from .op import call_packed, call_pure_intrin, call_intrin, call_pure_extern, call_extern
 from .op import call_llvm_intrin, all, any, min_value, max_value, trace
-from .op import exp, erf, tanh, sigmoid, log, cos, sin, atan, sqrt, rsqrt, floor, ceil
+from .op import exp, erf, tanh, sigmoid, log, tan, cos, sin, atan, sqrt, rsqrt, floor, ceil
 from .op import trunc, abs, round, nearbyint, isnan, power, popcount, fmod, if_then_else
 from .op import div, indexdiv, indexmod, truncdiv, truncmod, floordiv, floormod
 from .op import comm_reducer, min, max, sum
diff --git a/python/tvm/tir/op.py b/python/tvm/tir/op.py
index a8aef8f85495..c5b1a0a1cbb5 100644
--- a/python/tvm/tir/op.py
+++ b/python/tvm/tir/op.py
@@ -393,6 +393,22 @@ def log(x):
     """
     return call_pure_intrin(x.dtype, "log", x)

+def tan(x):
+    """Take tan of input x.
+
+    Parameters
+    ----------
+    x : PrimExpr
+        Input argument.
+
+    Returns
+    -------
+    y : PrimExpr
+        The result.
+    """
+    return call_pure_intrin(x.dtype, "tan", x)
+
+
 def cos(x):
     """Take cos of input x.

diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 1169fa801398..b9d247361996 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -51,6 +51,17 @@ RELAY_REGISTER_UNARY_OP("log")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::log));


+RELAY_REGISTER_UNARY_OP("tan")
+.describe(R"code(Returns the tan of input array, computed element-wise.
+
+.. math::
+   Y = tan(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tan));
+
+
 RELAY_REGISTER_UNARY_OP("cos")
 .describe(R"code(Returns the cos of input array, computed element-wise.

diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc
index 7e9ac71bb753..ade3bbdf9304 100644
--- a/src/target/intrin_rule.cc
+++ b/src/target/intrin_rule.cc
@@ -40,6 +40,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.log")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tanh")
 .set_body(DispatchExtern<FloatSuffix>);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.tan")
+.set_body(DispatchExtern<FloatSuffix>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.default.cos")
 .set_body(DispatchExtern<FloatSuffix>);

diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index 758b0afc22ff..6c5a9cd079af 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -91,6 +91,20 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.popcount")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.tan")
+.set_body([](const TVMArgs& targs, TVMRetValue* rv) {
+  PrimExpr e = targs[0];
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  CHECK(call != nullptr);
+  const PrimExpr& x = call->args[0];
+  PrimExpr sin_x = tir::CallNode::make(
+      x.dtype(), "sin", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr cos_x = tir::CallNode::make(
+      x.dtype(), "cos", {x}, tir::CallNode::PureIntrinsic);
+  PrimExpr tan_x = sin_x / cos_x;
+  *rv = tan_x;
+});
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.llvm.cos")
 .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>);

diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc
index 6f7a89cebd97..6d41d132724b 100644
--- a/src/target/llvm/intrin_rule_nvptx.cc
+++ b/src/target/llvm/intrin_rule_nvptx.cc
@@ -81,6 +81,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tanh")
 .set_body(DispatchExternLibDevice);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.tan")
+.set_body(DispatchExternLibDevice);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.nvptx.cos")
 .set_body(DispatchExternLibDevice);

diff --git a/src/target/llvm/intrin_rule_rocm.cc b/src/target/llvm/intrin_rule_rocm.cc
index 31b7bf19419b..4e6a661f298d 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -80,6 +80,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.pow")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tanh")
 .set_body(DispatchExternOCML);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.tan")
+.set_body(DispatchExternOCML);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.rocm.cos")
 .set_body(DispatchExternOCML);

diff --git a/src/target/source/intrin_rule_cuda.cc b/src/target/source/intrin_rule_cuda.cc
index aed6c86e965b..849e098d4ad2 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -54,6 +54,22 @@ struct CUDAFastMath : public CUDAMath {
   }
 };

+struct CUDAFastMathTan : public CUDAMath {
+  std::string operator()(DataType t, std::string name) const {
+    if (t.lanes() == 1 && t.is_float()) {
+      switch (t.bits()) {
+        case 64: return name;
+        // `__tanf` seems to produce values that deviate too much from the
+        // NumPy tan results, so use plain `tanf` instead.
+        case 32: return name + 'f';
+        case 16: LOG(FATAL) << "cuda tan unsupported for float16";
+        default: return "";
+      }
+    }
+    return "";
+  }
+};
+
 struct CUDAPopcount {
   std::string operator()(DataType t, std::string name) const {
     if (t.lanes() == 1 && t.is_uint()) {
@@ -97,6 +113,9 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.erf")
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.log")
 .set_body(DispatchExtern<CUDAFastMath>);

+TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tan")
+.set_body(DispatchExtern<CUDAFastMathTan>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.cos")
 .set_body(DispatchExtern<CUDAFastMath>);

diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 07cae5e2c746..7572f8d6bbb0 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -229,7 +229,7 @@ PrimExpr LetNode::make(Var var, PrimExpr value, PrimExpr body) {

 const char* CallNode::vectorizable_intrinsics[] = {
     "floor", "ceil", "sign", "trunc", "fabs", "round", "exp", "tanh", "sqrt",
-    "log", "sin", "cos", "pow", tir::CallNode::shift_left, tir::CallNode::shift_right,
+    "log", "sin", "cos", "pow", "tan", tir::CallNode::shift_left, tir::CallNode::shift_right,
     tir::CallNode::likely, tir::CallNode::popcount
 };

diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 42408b706111..31c548044846 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -2628,6 +2628,15 @@ def test_forward_cos():
     compare_tf_with_tvm([np_data], ['in_data:0'], 'cos:0')


+def test_forward_tan():
+    """test operator tan """
+    np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
+    tf.reset_default_graph()
+    in_data = tf.placeholder(tf.float32, (2, 3, 5), name="in_data")
+    tf.tan(in_data, name="tan")
+    compare_tf_with_tvm([np_data], ['in_data:0'], 'tan:0')
+
+
 def test_forward_sin():
     """test operator sin """
     np_data = np.random.uniform(1, 100, size=(2, 3, 5)).astype(np.float32)
@@ -3031,6 +3040,7 @@ def test_forward_add_n():
     test_forward_sign()
     test_forward_log()
     test_forward_log1p()
+    test_forward_tan()
     test_forward_cos()
     test_forward_sin()
     test_forward_negative()
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 147839398d37..28216fcd674a 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -723,6 +723,13 @@ def _test_cos(data):
     """ One iteration of cos """
     return _test_unary_elemwise(math_ops.cos, data)
 #######################################################################
+# Tan
+# ---
+
+def _test_tan(data):
+    """ One iteration of tan """
+    return _test_unary_elemwise(math_ops.tan, data)
+#######################################################################
 # Sqrt
 # ----

@@ -772,6 +779,7 @@ def test_all_unary_elemwise():
     if package_version.parse(tf.VERSION) >= package_version.parse('1.14.0'):
         _test_forward_unary_elemwise(_test_ceil)
         _test_forward_unary_elemwise(_test_cos)
+        _test_forward_unary_elemwise(_test_tan)

 #######################################################################
 # Element-wise
diff --git a/tests/python/relay/test_op_grad_level1.py b/tests/python/relay/test_op_grad_level1.py
index 0eb1cec916b6..0579441166ae 100644
--- a/tests/python/relay/test_op_grad_level1.py
+++ b/tests/python/relay/test_op_grad_level1.py
@@ -64,6 +64,7 @@ def check_single_op(opfunc, ref):
                         (relay.nn.relu, lambda x: np.where(x < 0, np.zeros_like(x), np.ones_like(x))),
                         (tvm.relay.cos, lambda x: -1.0 * np.sin(x)),
                         (tvm.relay.sin, lambda x: np.cos(x)),
+                        (tvm.relay.tan, lambda x: 1.0 / (np.cos(x) ** 2)),
                         (tvm.relay.atan, lambda x: 1 / (1 + np.power(x, 2.0)))]:
         check_single_op(opfunc, ref)

diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 0fa07499193a..c58ff0f1aaba 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -76,6 +76,7 @@ def check_single_op(opfunc, ref, dtype):
                         (relay.nn.relu, relu),
                         (tvm.relay.cos, np.cos),
                         (tvm.relay.sin, np.sin),
+                        (tvm.relay.tan, np.tan),
                         (tvm.relay.atan, np.arctan)]:
         for dtype in ['float16', 'float32']:
             check_single_op(opfunc, ref, dtype)
diff --git a/tests/python/unittest/test_testing.py b/tests/python/unittest/test_testing.py
index ecf520d251f1..cfa13845e4e8 100644
--- a/tests/python/unittest/test_testing.py
+++ b/tests/python/unittest/test_testing.py
@@ -31,6 +31,7 @@ def test_check_numerical_grads():
         lambda x: (np.sign(np.sin(1/x)), np.zeros_like(x)),
         lambda x: (x*np.sin(1/x), np.sin(1/x) - np.cos(1/x)/x),
         lambda x: (np.sin(1/x), - np.cos(1/x)/(x*x)),
+        lambda x: (np.tan(x), 1.0 / (np.cos(x) * np.cos(x))),
     ]

     # Avoid values too close to 0 since singularities of our functions are there
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index 3c0822f2b00e..26107ea78d27 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -55,6 +55,7 @@ TOPI_DECLARE_UNARY_OP(round);
 TOPI_DECLARE_UNARY_OP(trunc);
 TOPI_DECLARE_UNARY_OP(abs);
 TOPI_DECLARE_UNARY_OP(cos);
+TOPI_DECLARE_UNARY_OP(tan);
 TOPI_DECLARE_UNARY_OP(sin);
 TOPI_DECLARE_UNARY_OP(atan);
 TOPI_DECLARE_UNARY_OP(isnan);
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 4a63c4535289..3eda88ad8c11 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -109,6 +109,23 @@ def tanh(x):
     return te.compute(x.shape, lambda *i: te.tanh(x(*i)))


+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def tan(x):
+    """Take tan of input x.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+    return te.compute(x.shape, lambda *i: te.tan(x(*i)))
+
+
 @tvm.te.tag_scope(tag=tag.ELEMWISE)
 def cos(x):
     """Take cos of input x.
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 75517b818f45..add01c20a443 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -175,6 +175,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
   *rv = erf(args[0]);
   });

+TVM_REGISTER_GLOBAL("topi.tan")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = tan(args[0]);
+  });
+
 TVM_REGISTER_GLOBAL("topi.cos")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = cos(args[0]);
diff --git a/topi/tests/python/test_topi_basic.py b/topi/tests/python/test_topi_basic.py
index 83f0469dc00f..13f1463da7ff 100644
--- a/topi/tests/python/test_topi_basic.py
+++ b/topi/tests/python/test_topi_basic.py
@@ -45,6 +45,7 @@ def test_apply(func, name):
     test_apply(topi.rsqrt, "rsqrt")
     test_apply(topi.sin, "sin")
     test_apply(topi.cos, "cos")
+    test_apply(topi.tan, "tan")
     test_apply(topi.atan, "atan")


diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
index 30a0f44aad57..3e58518ed4fe 100644
--- a/topi/tests/python/test_topi_math.py
+++ b/topi/tests/python/test_topi_math.py
@@ -127,6 +127,8 @@ def check_device(device):
     test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
     test_apply(topi.rsqrt, "rsqrt", lambda x: np.ones_like(x) / np.sqrt(x), 0, 100, skip_name_check=True)
     test_apply(topi.cos, "cos", np.cos, -2.0*np.pi, 2.0*np.pi)
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float32')
+    test_apply(topi.tan, "tan", np.tan, -2.0*np.pi, 2.0*np.pi, dtype='float64')
     test_apply(topi.sin, "sin", np.sin, -2.0*np.pi, 2.0*np.pi)
     test_apply(topi.erf, "erf", scipy.special.erf, -.1, .1, dtype="float32")
     test_isnan(-100, 100)
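
Reviewer note (not part of the patch): for anyone who wants to exercise the new op end to end, here is a minimal sketch that wraps relay.tan in a one-line Relay function and checks it against NumPy. It follows the pattern of check_single_op in tests/python/relay/test_op_level1.py; the shape, dtype, and tolerance are arbitrary choices, and it assumes a TVM build that already contains this patch.

    import numpy as np
    import tvm
    from tvm import relay

    # Build y = tan(x) over an arbitrary 3x4 float32 tensor.
    x = relay.var("x", shape=(3, 4), dtype="float32")
    func = relay.Function([x], relay.tan(x))

    # Evaluate on LLVM/CPU and compare with NumPy's tan elementwise.
    x_np = np.random.uniform(-1.0, 1.0, size=(3, 4)).astype("float32")
    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    out = intrp.evaluate(func)(x_np)
    np.testing.assert_allclose(out.asnumpy(), np.tan(x_np), rtol=1e-5)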
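
A similar sketch at the TE/TOPI level, modeled on topi/tests/python/test_topi_math.py, shows topi.tan lowering through the new tan intrinsic rules when compiled for LLVM. Again, the tensor size, schedule, and tolerance are illustrative only, and "import topi" assumes the standalone topi package layout this patch targets.

    import numpy as np
    import tvm
    from tvm import te
    import topi

    # Elementwise tan over a 1-D placeholder, default (unscheduled) CPU build.
    n = 1024
    A = te.placeholder((n,), name="A", dtype="float32")
    B = topi.tan(A)
    s = te.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm", name="tan")

    # Run and compare against NumPy.
    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(-1.0, 1.0, size=n).astype(A.dtype), ctx)
    b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), np.tan(a.asnumpy()), rtol=1e-5)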