diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 8d886aa09ea2..28628006105b 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -163,6 +163,13 @@ TVM_DLL Pass PartialEval();
  */
 TVM_DLL Pass SimplifyInference();
 
+/*!
+ * \brief Replaces non-linear activation functions with their fast but approximate counterparts.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FastMath();
+
 /*!
  * \brief Infer the type of an expression.
  *
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index 45535afc486c..f773835d5c29 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -57,7 +57,8 @@ def build_config(opt_level=2,
                 "CanonicalizeCast": 3,
                 "EliminateCommonSubexpr": 3,
                 "CombineParallelConv2D": 4,
-                "CombineParallelDense": 4
+                "CombineParallelDense": 4,
+                "FastMath": 4
             }
 
     fallback_device : int, str, or tvmContext, optional
@@ -175,11 +176,22 @@ def SimplifyInference():
     Returns
     -------
     ret: tvm.relay.Pass
-        The registered to perform operator simplification.
+        The registered pass to perform operator simplification.
     """
     return _transform.SimplifyInference()
 
 
+def FastMath():
+    """Converts the expensive non-linear functions to their fast but approximate counterparts.
+
+    Returns
+    -------
+    ret: tvm.relay.Pass
+        The registered pass to perform fast math operations.
+    """
+    return _transform.FastMath()
+
+
 def CanonicalizeOps():
     """Canonicalize special operators to basic operators.
     This can simplify followed analysis, e.g. expanding bias_add to
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index ff64d4a3acbb..0c0a8b8cbfa8 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
     if (targets.size() == 1) {
       pass_seqs.push_back(transform::AlterOpLayout());
     }
+
+    // Fast math optimizations.
+    pass_seqs.push_back(transform::FastMath());
     pass_seqs.push_back(transform::FoldConstant());
 
     // Create a sequential pass and perform optimizations.
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 2c7345865095..1169fa801398 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));
 
 
+RELAY_REGISTER_UNARY_OP("fast_exp")
+.describe(R"code(Returns the fast_exp of the input array, computed element-wise.
+
+.. math::
+   y \approx exp(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));
+
+
 RELAY_REGISTER_UNARY_OP("erf")
 .describe(R"code(Returns the error function value for input array, computed element-wise.
 
@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));
 
 
+RELAY_REGISTER_UNARY_OP("fast_tanh")
+.describe(R"code(Returns the fast_tanh of the input array, computed element-wise.
+
+.. math::
+   Y = sinh(X) / cosh(X)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));
+
+
 RELAY_REGISTER_UNARY_OP("negative")
 .describe(R"code(Returns the numeric negative of input array, computed element-wise.
 
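> Note: the unary.cc hunk above registers `fast_exp` and `fast_tanh` as ordinary Relay unary ops, so they can be constructed directly by name rather than only through the pass. A minimal sketch (the variable name and shape are illustrative, not part of this patch):

```python
import tvm
from tvm import relay

# Look up the newly registered op by name and build a call to it.
x = relay.var("x", shape=(4,), dtype="float32")
y = relay.Call(relay.op.get("fast_exp"), [x])
mod = tvm.IRModule.from_expr(relay.Function([x], y))
mod = relay.transform.InferType()(mod)
print(mod)  # the printed IR should contain a well-typed call to fast_exp
```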
diff --git a/src/relay/pass/fast_math.cc b/src/relay/pass/fast_math.cc
new file mode 100644
index 000000000000..898f760fdb50
--- /dev/null
+++ b/src/relay/pass/fast_math.cc
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file fast_math.cc
+ * \brief Replaces non-linear activation functions with their fast but approximate counterparts.
+ */
+#include <tvm/ir/module.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/expr.h>
+#include "pattern_util.h"
+
+namespace tvm {
+namespace relay {
+
+class FastMathMutator : public ExprMutator {
+ public:
+  FastMathMutator()
+      : exp_op_(Op::Get("exp")),
+        tanh_op_(Op::Get("tanh")) {}
+
+  Expr VisitExpr_(const CallNode* n) {
+    auto new_n = ExprMutator::VisitExpr_(n);
+    if (n->op == exp_op_) {
+      return FastExp(new_n.as<CallNode>()->args[0]);
+    } else if (n->op == tanh_op_) {
+      return FastTanh(new_n.as<CallNode>()->args[0]);
+    }
+    return new_n;
+  }
+
+ private:
+  // Cache the following ops. They will be used in the passes repeatedly for
+  // operator equivalence checking so that the registry lookup overhead can be
+  // reduced.
+  const Op& exp_op_;
+  const Op& tanh_op_;
+};
+
+Expr FastMath(const Expr& e) {
+  return FastMathMutator().Mutate(e);
+}
+
+namespace transform {
+
+Pass FastMath() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+    [=](Function f, IRModule m, PassContext pc) {
+      return Downcast<Function>(FastMath(f));
+    };
+  return CreateFunctionPass(pass_func, 4, "FastMath",
+                            {tir::StringImmNode::make("InferType")});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FastMath")
+.set_body_typed(FastMath);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h
index f7d8f9c4665e..85750f5e2601 100644
--- a/src/relay/pass/pattern_util.h
+++ b/src/relay/pass/pattern_util.h
@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
   return CallNode::make(op, {e});
 }
 
+inline Expr FastExp(Expr e) {
+  static const Op& op = Op::Get("fast_exp");
+  return CallNode::make(op, {e});
+}
+
+inline Expr FastTanh(Expr e) {
+  static const Op& op = Op::Get("fast_tanh");
+  return CallNode::make(op, {e});
+}
+
 inline Expr Log(Expr e) {
   static const Op& op = Op::Get("log");
   return CallNode::make(op, {e});
diff --git a/tests/python/relay/test_pass_fast_math.py b/tests/python/relay/test_pass_fast_math.py
new file mode 100644
index 000000000000..e75316f1e04b
--- /dev/null
+++ b/tests/python/relay/test_pass_fast_math.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.ir import IRModule
+from tvm import relay
+from tvm.relay.transform import FastMath
+
+def test_exp():
+    x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+    y = relay.exp(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    fast_mod = FastMath()(mod)
+    assert "fast_exp" in fast_mod.astext()
+
+    # Check that the FastMath option works for relay.build.
+    with relay.build_config(opt_level=3, required_pass=['FastMath']):
+        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    assert "fast_exp" in fast_mod[0].astext()
+
+def test_tanh():
+    x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
+    y = relay.tanh(x)
+    func = relay.Function([x], y)
+    mod = tvm.IRModule.from_expr(func)
+
+    fast_mod = FastMath()(mod)
+    assert "fast_tanh" in fast_mod.astext()
+
+    # Check that the FastMath option works for relay.build.
+    with relay.build_config(opt_level=3, required_pass=['FastMath']):
+        fast_mod = relay.optimize(mod, target='llvm', params=None)
+    assert "fast_tanh" in fast_mod[0].astext()
+
+if __name__ == "__main__":
+    test_exp()
+    test_tanh()
diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h
index e35e3e424d6e..3c0822f2b00e 100644
--- a/topi/include/topi/elemwise.h
+++ b/topi/include/topi/elemwise.h
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
 TOPI_DECLARE_UNARY_OP(sin);
 TOPI_DECLARE_UNARY_OP(atan);
 TOPI_DECLARE_UNARY_OP(isnan);
+TOPI_DECLARE_UNARY_OP(tanh);
 
 /*
  * \brief Fast_tanh_float implementation from Eigen
@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
  *
  * \return A Tensor whose op member is tanh
  */
-inline Tensor tanh(const Tensor& x,
-                   std::string name = "T_tanh",
-                   std::string tag = kElementWise) {
+inline Tensor fast_tanh(const Tensor& x,
+                        std::string name = "T_fast_tanh",
+                        std::string tag = kElementWise) {
   if (x->dtype == DataType::Float(32)) {
     // invoke fast_tanh_float implementation
     return fast_tanh_float(x, name, tag);
diff --git a/topi/python/topi/math.py b/topi/python/topi/math.py
index 5b6b9ab8da75..4a63c4535289 100644
--- a/topi/python/topi/math.py
+++ b/topi/python/topi/math.py
@@ -467,3 +467,19 @@ def fast_exp(x):
         The result.
     """
     return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)
+
+
+def fast_tanh(x):
+    """Take the tanh of input x using the fast_tanh implementation.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
+    """
+    return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 79e223c30975..75517b818f45 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = tanh(args[0]);
   });
-
+TVM_REGISTER_GLOBAL("topi.fast_tanh")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = fast_tanh(args[0]);
+  });
 TVM_REGISTER_GLOBAL("topi.atan")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
   *rv = atan(args[0]);
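> Note: beyond the unit tests, a hedged end-to-end sketch of how the pass would typically be exercised: build the same module with and without `FastMath` in `required_pass` (the pass sits at opt level 4, so a default opt_level=3 build skips it) and compare outputs numerically. The `graph_runtime` host flow and the tolerance below are assumptions matching the TVM APIs of this era, not code from this patch:

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

x = relay.var("x", shape=(1, 16), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.exp(x)))
data = np.random.uniform(-2, 2, size=(1, 16)).astype("float32")

def build_and_run(required_pass=None):
    with relay.build_config(opt_level=3, required_pass=required_pass):
        graph, lib, params = relay.build(mod, target="llvm")
    m = graph_runtime.create(graph, lib, tvm.cpu())
    m.set_input("x", data)
    m.run()
    return m.get_output(0).asnumpy()

exact = build_and_run()
approx = build_and_run(required_pass=["FastMath"])
# fast_exp trades accuracy for speed; the tolerance is an assumed bound.
np.testing.assert_allclose(approx, exact, rtol=1e-5, atol=1e-5)
```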
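> Note: the `fast_tanh_float` routine referenced in the elemwise.h hunk is Eigen's clamped rational approximation: clamp x to [-9, 9] (where float32 tanh has already saturated), then evaluate an odd 13th-degree numerator over an even 6th-degree denominator in x². A NumPy sketch of that scheme follows; the coefficients are the ones published for Eigen's float32 tanh and are quoted from memory here, so treat the exact constants as an assumption:

```python
import numpy as np

# Numerator (odd powers) and denominator (even powers) coefficients,
# highest degree first (Eigen float32 tanh; exact values assumed).
ALPHA = [-2.76076847742355e-16, 2.00018790482477e-13, -8.60467152213735e-11,
         5.12229709037114e-08, 1.48572235717979e-05, 6.37261928875436e-04,
         4.89352455891786e-03]
BETA = [1.19825839466702e-06, 1.18534705686654e-04,
        2.26843463243900e-03, 4.89352518554385e-03]

def fast_tanh_f32(x):
    # Outside [-9, 9] float32 tanh is +/-1 to machine precision.
    x = np.clip(np.asarray(x, dtype="float32"), -9.0, 9.0)
    x2 = x * x
    p = np.full_like(x, ALPHA[0])
    for a in ALPHA[1:]:
        p = p * x2 + a          # Horner step in x^2
    q = np.full_like(x, BETA[0])
    for b in BETA[1:]:
        q = q * x2 + b
    return x * p / q            # odd numerator: one final multiply by x

xs = np.linspace(-4.0, 4.0, 1001, dtype="float32")
print(abs(fast_tanh_f32(xs) - np.tanh(xs)).max())  # small approximation error
```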