[Relay][FastMath] Relay pass to use fast exp/tanh #4873

Merged 3 commits on Mar 1, 2020
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
@@ -163,6 +163,13 @@ TVM_DLL Pass PartialEval();
*/
TVM_DLL Pass SimplifyInference();

/*!
* \brief Replaces non-linear activation functions with their fast but approximate counterparts.
*
* \return The Pass.
*/
TVM_DLL Pass FastMath();

/*!
* \brief Infer the type of an expression.
*
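The declaration above plugs into the standard pass infrastructure, so the rewrite can be applied to a module on its own. A minimal Python sketch, mirroring the unit test added later in this diff:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.exp(x)))

    # Apply just the FastMath pass; exp calls are rewritten to fast_exp.
    fast_mod = relay.transform.FastMath()(mod)
    assert "fast_exp" in fast_mod.astext()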
16 changes: 14 additions & 2 deletions python/tvm/relay/transform.py
@@ -57,7 +57,8 @@ def build_config(opt_level=2,
"CanonicalizeCast": 3,
"EliminateCommonSubexpr": 3,
"CombineParallelConv2D": 4,
"CombineParallelDense": 4
"CombineParallelDense": 4,
"FastMath": 4
}

fallback_device : int, str, or tvmContext, optional
@@ -175,11 +176,22 @@ def SimplifyInference():
Returns
-------
ret: tvm.relay.Pass
- The registered to perform operator simplification.
+ The registered pass to perform operator simplification.
"""
return _transform.SimplifyInference()


def FastMath():
""" Converts the expensive non linear functions to their fast but approximate counterparts.

Returns
-------
ret: tvm.relay.Pass
The registered pass to perform fast math operations.
"""
return _transform.FastMath()


def CanonicalizeOps():
"""Canonicalize special operators to basic operators.
This can simplify followed analysis, e.g. expanding bias_add to
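Because OPT_PASS_LEVEL registers FastMath at level 4, the pass is skipped at the default opt_level and has to be requested explicitly or run at level 4. A sketch using required_pass, as the tests below do:

    with relay.build_config(opt_level=3, required_pass=['FastMath']):
        graph, lib, params = relay.build(mod, target='llvm')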
3 changes: 3 additions & 0 deletions src/relay/backend/build_module.cc
@@ -305,6 +305,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

// Create a sequential pass and perform optimizations.
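With the pass wired into the default sequence here, just before FoldConstant so the approximate ops can still be constant-folded, no explicit required_pass should be needed once the opt_level reaches 4. A hedged sketch:

    # FastMath is gated at level 4 by the pass table above; note that
    # level 4 also pulls in the other aggressive optimizations.
    with relay.build_config(opt_level=4):
        graph, lib, params = relay.build(mod, target='llvm')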
22 changes: 22 additions & 0 deletions src/relay/op/tensor/unary.cc
@@ -95,6 +95,17 @@ RELAY_REGISTER_UNARY_OP("exp")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::exp));


RELAY_REGISTER_UNARY_OP("fast_exp")
.describe(R"code(Returns the fast_exp input array, computed element-wise.

.. math::
\fast_exp(x)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_exp));


RELAY_REGISTER_UNARY_OP("erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.

@@ -250,6 +261,17 @@ RELAY_REGISTER_UNARY_OP("tanh")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::tanh));


RELAY_REGISTER_UNARY_OP("fast_tanh")
.describe(R"code(Returns the fast_tanh of input array, computed element-wise.

.. math::
Y = sinh(X) / cosh(X)

)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_tanh));


RELAY_REGISTER_UNARY_OP("negative")
.describe(R"code(Returns the numeric negative of input array, computed element-wise.

79 changes: 79 additions & 0 deletions src/relay/pass/fast_math.cc
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file fast_math.cc
* \brief Replaces non-linear activation functions with their fast but approximate counterparts.
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/op.h>
#include "pattern_util.h"

namespace tvm {
namespace relay {

class FastMathMutator : public ExprMutator {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
tanh_op_(Op::Get("tanh")) {}

Expr VisitExpr_(const CallNode* n) {
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op == exp_op_) {
return FastExp(new_n.as<CallNode>()->args[0]);
} else if (n->op == tanh_op_) {
return FastTanh(new_n.as<CallNode>()->args[0]);
}
return new_n;
}

private:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& exp_op_;
const Op& tanh_op_;
};

Expr FastMath(const Expr& e) {
return FastMathMutator().Mutate(e);
}

namespace transform {

Pass FastMath() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FastMath(f));
};
return CreateFunctionPass(pass_func, 4, "FastMath",
{tir::StringImmNode::make("InferType")});
}

TVM_REGISTER_GLOBAL("relay._transform.FastMath")
.set_body_typed(FastMath);

} // namespace transform

} // namespace relay
} // namespace tvm
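For readers who think in the Python API, a rough Python analogue of FastMathMutator (illustrative only; the real pass is the C++ above, and the ExprMutator/relay.op.get names are assumed from the Python API of this era):

    from tvm import relay
    from tvm.relay.expr_functor import ExprMutator

    class PyFastMath(ExprMutator):
        def __init__(self):
            super().__init__()
            # Cache the ops once, mirroring the C++ member initializers.
            self._exp = relay.op.get("exp")
            self._tanh = relay.op.get("tanh")

        def visit_call(self, call):
            new_call = super().visit_call(call)
            if call.op == self._exp:
                return relay.Call(relay.op.get("fast_exp"), new_call.args)
            if call.op == self._tanh:
                return relay.Call(relay.op.get("fast_tanh"), new_call.args)
            return new_call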
10 changes: 10 additions & 0 deletions src/relay/pass/pattern_util.h
@@ -316,6 +316,16 @@ inline Expr Exp(Expr e) {
return CallNode::make(op, {e});
}

inline Expr FastExp(Expr e) {
static const Op& op = Op::Get("fast_exp");
return CallNode::make(op, {e});
}

inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fast_tanh");
return CallNode::make(op, {e});
}

inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
52 changes: 52 additions & 0 deletions tests/python/relay/test_pass_fast_math.py
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.ir import IRModule
from tvm import relay
from tvm.relay.transform import FastMath

def test_exp():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.exp(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

fast_mod = FastMath()(mod)
assert "fast_exp" in fast_mod.astext()

# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_exp" in fast_mod[0].astext()

def test_tanh():
x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32")
y = relay.tanh(x)
func = relay.Function([x], y)
mod = tvm.IRModule.from_expr(func)

fast_mod = FastMath()(mod)
assert "fast_tanh" in fast_mod.astext()

# Check that FastMath option works for relay.build.
with relay.build_config(opt_level=3, required_pass=['FastMath']):
fast_mod = relay.optimize(mod, target='llvm', params=None)
assert "fast_tanh" in fast_mod[0].astext()

if __name__ == "__main__":
test_exp()
test_tanh()
7 changes: 4 additions & 3 deletions topi/include/topi/elemwise.h
@@ -58,6 +58,7 @@ TOPI_DECLARE_UNARY_OP(cos);
TOPI_DECLARE_UNARY_OP(sin);
TOPI_DECLARE_UNARY_OP(atan);
TOPI_DECLARE_UNARY_OP(isnan);
TOPI_DECLARE_UNARY_OP(tanh);

/*
* \brief Fast_tanh_float implementation from Eigen
@@ -113,9 +114,9 @@ inline Tensor fast_tanh_float(const Tensor& in,
*
* \return A Tensor whose op member is tanh
*/
- inline Tensor tanh(const Tensor& x,
-                    std::string name = "T_tanh",
-                    std::string tag = kElementWise) {
+ inline Tensor fast_tanh(const Tensor& x,
+                         std::string name = "T_fast_tanh",
+                         std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
// invoke fast_tanh_float implementation
return fast_tanh_float(x, name, tag);
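fast_tanh_float, referenced above, follows Eigen's approach: clamp the input, then evaluate an odd rational polynomial. A sketch of the structure (the exact clamp bounds and coefficients are the ones in the Eigen-derived code, not reproduced here):

    % LaTeX sketch of the Eigen-style approximation:
    % p and q are low-degree polynomials in x^2; beyond the clamp range
    % tanh has already saturated to +-1 within float32 precision.
    \tanh(x) \approx \frac{x \, p(x^2)}{q(x^2)}, \qquad x \text{ clamped to } [-9, 9]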
16 changes: 16 additions & 0 deletions topi/python/topi/math.py
@@ -467,3 +467,19 @@ def fast_exp(x):
The result.
"""
return cpp.fast_exp(x, x.dtype, tag.ELEMWISE)


def fast_tanh(x):
"""Take tanhonential of input x using fast_tanh implementation

Parameters
----------
x : tvm.Tensor
Input argument.

Returns
-------
y : tvm.Tensor
The result.
"""
return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)
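A quick numerical sanity check of the TOPI op (a sketch; the tvm.placeholder/create_schedule/build names follow the pre-0.7 top-level API current when this PR merged and may live under tvm.te in later releases):

    import numpy as np
    import tvm
    import topi

    x = tvm.placeholder((1024,), name="x", dtype="float32")
    y = topi.fast_tanh(x)
    s = tvm.create_schedule(y.op)
    f = tvm.build(s, [x, y], "llvm")

    a = tvm.nd.array(np.random.uniform(-5.0, 5.0, 1024).astype("float32"))
    b = tvm.nd.array(np.zeros(1024, dtype="float32"))
    f(a, b)
    # The approximation should track numpy's tanh to a loose tolerance.
    np.testing.assert_allclose(b.asnumpy(), np.tanh(a.asnumpy()), rtol=1e-5, atol=1e-5)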
5 changes: 4 additions & 1 deletion topi/src/topi.cc
@@ -188,7 +188,10 @@ TVM_REGISTER_GLOBAL("topi.tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tanh(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fast_tanh")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_tanh(args[0]);
});
TVM_REGISTER_GLOBAL("topi.atan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = atan(args[0]);
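The TVM_REGISTER_GLOBAL entries expose the compute over the FFI, so the packed function can also be fetched by name from Python (a sketch):

    import tvm

    # Returns the packed function registered in topi.cc above.
    f = tvm.get_global_func("topi.fast_tanh")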