From c795a053af68a7afcd55752fcaa1daebc7bf9c04 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Thu, 13 Feb 2020 00:13:57 +0000 Subject: [PATCH 1/3] [Relay][FastMath] Relay pass to use fast exp/tanh --- include/tvm/relay/transform.h | 7 ++ python/tvm/relay/transform.py | 14 +++- src/relay/backend/build_module.cc | 3 + src/relay/op/tensor/unary.cc | 22 +++++++ src/relay/pass/fast_math.cc | 79 +++++++++++++++++++++++ src/relay/pass/pattern_util.h | 10 +++ tests/python/relay/test_pass_fast_math.py | 52 +++++++++++++++ topi/include/topi/elemwise.h | 7 +- topi/python/topi/math.py | 16 +++++ topi/src/topi.cc | 5 +- 10 files changed, 210 insertions(+), 5 deletions(-) create mode 100644 src/relay/pass/fast_math.cc create mode 100644 tests/python/relay/test_pass_fast_math.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 8d886aa09ea2..28628006105b 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -163,6 +163,13 @@ TVM_DLL Pass PartialEval(); */ TVM_DLL Pass SimplifyInference(); +/*! + * \brief Replaces non linear activation functions with their fast but approximate counterparts. + * + * \return The Pass. + */ +TVM_DLL Pass FastMath(); + /*! * \brief Infer the type of an expression. * diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 45535afc486c..71c011d9b4cc 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -57,7 +57,8 @@ def build_config(opt_level=2, "CanonicalizeCast": 3, "EliminateCommonSubexpr": 3, "CombineParallelConv2D": 4, - "CombineParallelDense": 4 + "CombineParallelDense": 4, + "FastMath": 4 } fallback_device : int, str, or tvmContext, optional @@ -180,6 +181,17 @@ def SimplifyInference(): return _transform.SimplifyInference() +def FastMath(): + """ Converts the expensive non linear functions to their fast but approximate counterparts. + + Returns + ------- + ret: tvm.relay.Pass + The registered to perform operator simplification. 
+ """ + return _transform.FastMath() + + def CanonicalizeOps(): """Canonicalize special operators to basic operators. This can simplify followed analysis, e.g. expanding bias_add to diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index ff64d4a3acbb..0c0a8b8cbfa8 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode { if (targets.size() == 1) { pass_seqs.push_back(transform::AlterOpLayout()); } + + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); // Create a sequential pass and perform optimizations. diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index 2c7345865095..1169fa801398 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp") .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp)); +RELAY_REGISTER_UNARY_OP("fast_exp") +.describe(R"code(Returns the fast_exp input array, computed element-wise. + +.. math:: + \fast_exp(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp)); + + RELAY_REGISTER_UNARY_OP("erf") .describe(R"code(Returns the error function value for input array, computed element-wise. @@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh") .set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh)); +RELAY_REGISTER_UNARY_OP("fast_tanh") +.describe(R"code(Returns the fast_tanh of input array, computed element-wise. + +.. math:: + Y = sinh(X) / cosh(X) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.set_attr("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh)); + + RELAY_REGISTER_UNARY_OP("negative") .describe(R"code(Returns the numeric negative of input array, computed element-wise. 
diff --git a/src/relay/pass/fast_math.cc b/src/relay/pass/fast_math.cc new file mode 100644 index 000000000000..898f760fdb50 --- /dev/null +++ b/src/relay/pass/fast_math.cc @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file fast_math.cc + * \brief Replaces non linear activation functions with their fast but approximate counterparts. + */ +#include +#include +#include +#include +#include +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +class FastMathMutator : public ExprMutator { + public: + FastMathMutator() + : exp_op_(Op::Get("exp")), + tanh_op_(Op::Get("tanh")) {} + + Expr VisitExpr_(const CallNode* n) { + auto new_n = ExprMutator::VisitExpr_(n); + if (n->op == exp_op_) { + return FastExp(new_n.as()->args[0]); + } else if (n->op == tanh_op_) { + return FastTanh(new_n.as()->args[0]); + } + return new_n; + } + + private: + // Cache the following ops. They will be used in the passes repeatedly for + // operator equivalence checking so that the registry lookup overhead can be + // reduced. 
+ const Op& exp_op_; + const Op& tanh_op_; +}; + +Expr FastMath(const Expr& e) { + return FastMathMutator().Mutate(e); +} + +namespace transform { + +Pass FastMath() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(FastMath(f)); + }; + return CreateFunctionPass(pass_func, 4, "FastMath", + {tir::StringImmNode::make("InferType")}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FastMath") +.set_body_typed(FastMath); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index f7d8f9c4665e..85750f5e2601 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -316,6 +316,16 @@ inline Expr Exp(Expr e) { return CallNode::make(op, {e}); } +inline Expr FastExp(Expr e) { + static const Op& op = Op::Get("fast_exp"); + return CallNode::make(op, {e}); +} + +inline Expr FastTanh(Expr e) { + static const Op& op = Op::Get("fast_tanh"); + return CallNode::make(op, {e}); +} + inline Expr Log(Expr e) { static const Op& op = Op::Get("log"); return CallNode::make(op, {e}); diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py new file mode 100644 index 000000000000..ba119072d293 --- /dev/null +++ b/tests/python/relay/test_pass_fast_math.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm.ir import IRModule +from tvm import relay +from tvm.relay.transform import FastMath + +def test_exp(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + y = relay.exp(x) + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + + fast_mod = FastMath()(mod) + assert "fast_exp" in fast_mod.astext() + + # Check that opt level 4 triggers the transformation. + with relay.build_config(opt_level=4): + fast_mod = relay.optimize(mod, target='llvm', params=None) + assert "fast_exp" in fast_mod[0].astext() + +def test_tanh(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + y = relay.tanh(x) + func = relay.Function([x], y) + mod = tvm.IRModule.from_expr(func) + + fast_mod = FastMath()(mod) + assert "fast_tanh" in fast_mod.astext() + + # Check that opt level 4 triggers the transformation. 
+ with relay.build_config(opt_level=4): + fast_mod = relay.optimize(mod, target='llvm', params=None) + assert "fast_tanh" in fast_mod[0].astext() + +if __name__ == "__main__": + test_exp() + test_tanh() diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index e35e3e424d6e..3c0822f2b00e 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos); TOPI_DECLARE_UNARY_OP(sin); TOPI_DECLARE_UNARY_OP(atan); TOPI_DECLARE_UNARY_OP(isnan); +TOPI_DECLARE_UNARY_OP(tanh); /* * \brief Fast_tanh_float implementation from Eigen @@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in, * * \return A Tensor whose op member is tanh */ -inline Tensor tanh(const Tensor& x, - std::string name = "T_tanh", - std::string tag = kElementWise) { +inline Tensor fast_tanh(const Tensor& x, + std::string name = "T_fast_tanh", + std::string tag = kElementWise) { if (x->dtype == DataType::Float(32)) { // invoke fast_tanh_float implementation return fast_tanh_float(x, name, tag); diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py index 5b6b9ab8da75..4a63c4535289 100644 --- a/topi/python/topi/math.py +++ b/topi/python/topi/math.py @@ -467,3 +467,19 @@ def fast_exp(x): The result. """ return cpp.fast_exp(x, x.dtype, tag.ELEMWISE) + + +def fast_tanh(x): + """Take hyperbolic tangent of input x using fast_tanh implementation + + Parameters + ---------- + x : tvm.Tensor + Input argument. + + Returns + ------- + y : tvm.Tensor + The result. 
+ """ + return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 79e223c30975..75517b818f45 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = tanh(args[0]); }); - +TVM_REGISTER_GLOBAL("topi.fast_tanh") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = fast_tanh(args[0]); + }); TVM_REGISTER_GLOBAL("topi.atan") .set_body([](TVMArgs args, TVMRetValue *rv) { *rv = atan(args[0]); From 200e9acf2f32673ba864825d6413b89df7155a26 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Tue, 25 Feb 2020 20:21:33 +0000 Subject: [PATCH 2/3] Adding required_pass to the tests. --- tests/python/relay/test_pass_fast_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py index ba119072d293..5053b8072c94 100644 --- a/tests/python/relay/test_pass_fast_math.py +++ b/tests/python/relay/test_pass_fast_math.py @@ -43,7 +43,7 @@ def test_tanh(): assert "fast_tanh" in fast_mod.astext() # Check that opt level 4 triggers the transformation. - with relay.build_config(opt_level=4): + with relay.build_config(opt_level=3, required_pass=['FastMath']): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_tanh" in fast_mod[0].astext() From 41f05ecbd9754b2976d092a57e187406e568f637 Mon Sep 17 00:00:00 2001 From: Animesh Jain Date: Wed, 26 Feb 2020 22:26:24 +0000 Subject: [PATCH 3/3] FastMath test changes. 
--- python/tvm/relay/transform.py | 4 ++-- tests/python/relay/test_pass_fast_math.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 71c011d9b4cc..f773835d5c29 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -176,7 +176,7 @@ def SimplifyInference(): Returns ------- ret: tvm.relay.Pass - The registered to perform operator simplification. + The registered pass to perform operator simplification. """ return _transform.SimplifyInference() @@ -187,7 +187,7 @@ def FastMath(): Returns ------- ret: tvm.relay.Pass - The registered to perform operator simplification. + The registered pass to perform fast math operations. """ return _transform.FastMath() diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py index 5053b8072c94..e75316f1e04b 100644 --- a/tests/python/relay/test_pass_fast_math.py +++ b/tests/python/relay/test_pass_fast_math.py @@ -28,8 +28,8 @@ def test_exp(): fast_mod = FastMath()(mod) assert "fast_exp" in fast_mod.astext() - # Check that opt level 4 triggers the transformation. - with relay.build_config(opt_level=4): + # Check that FastMath option works for relay.build. + with relay.build_config(opt_level=3, required_pass=['FastMath']): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_exp" in fast_mod[0].astext() @@ -42,7 +42,7 @@ def test_tanh(): fast_mod = FastMath()(mod) assert "fast_tanh" in fast_mod.astext() - # Check that opt level 4 triggers the transformation. + # Check that FastMath option works for relay.build. with relay.build_config(opt_level=3, required_pass=['FastMath']): fast_mod = relay.optimize(mod, target='llvm', params=None) assert "fast_tanh" in fast_mod[0].astext()