Skip to content

Commit

Permalink
[Relay][OP] Add fast_erf implementation (apache#5241)
Browse files Browse the repository at this point in the history
* add fast erf

* doc

* lint

* fix

* fix indent
  • Loading branch information
icemelon authored and Trevor Morris committed Apr 16, 2020
1 parent c99aaa7 commit 3b1d0b6
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 6 deletions.
2 changes: 1 addition & 1 deletion include/tvm/target/generic_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class GenericFunc : public ObjectRef {
*
* \code
* // Example code on how to call generic function
* void CallGeneirc(GenericFunc f) {
* void CallGeneric(GenericFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
register_injective_schedule("ndarray_size")
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")
register_broadcast_schedule("fast_erf")


# zeros
Expand Down Expand Up @@ -222,3 +223,4 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("tan", False, elemwise_shape_func)
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
11 changes: 11 additions & 0 deletions src/relay/op/tensor/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,17 @@ RELAY_REGISTER_UNARY_OP("erf")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::erf));


RELAY_REGISTER_UNARY_OP("fast_erf")
.describe(R"code(Returns the error function value for input array, computed element-wise.
.. math::
\fast_erf(x)
)code" TVM_ADD_FILELINE)
.set_support_level(1)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_UNARY_COMPUTE(topi::fast_erf));


RELAY_REGISTER_UNARY_OP("sqrt")
.describe(R"code(Returns the sqrt input array, computed element-wise.
Expand Down
4 changes: 4 additions & 0 deletions src/relay/transforms/fast_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,14 @@ class FastMathMutator : public ExprRewriter {
public:
FastMathMutator()
: exp_op_(Op::Get("exp")),
erf_op_(Op::Get("erf")),
tanh_op_(Op::Get("tanh")) {}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (pre->op == exp_op_) {
return FastExp(post.as<CallNode>()->args[0]);
} else if (pre->op == erf_op_) {
return FastErf(post.as<CallNode>()->args[0]);
} else if (pre->op == tanh_op_) {
return FastTanh(post.as<CallNode>()->args[0]);
}
Expand All @@ -51,6 +54,7 @@ class FastMathMutator : public ExprRewriter {
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& exp_op_;
const Op& erf_op_;
const Op& tanh_op_;
};

Expand Down
5 changes: 5 additions & 0 deletions src/relay/transforms/pattern_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ inline Expr FastExp(Expr e) {
return Call(op, {e});
}

inline Expr FastErf(Expr e) {
static const Op& op = Op::Get("fast_erf");
return Call(op, {e});
}

inline Expr FastTanh(Expr e) {
static const Op& op = Op::Get("fast_tanh");
return Call(op, {e});
Expand Down
3 changes: 3 additions & 0 deletions tests/python/relay/test_op_fast_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
import scipy
from scipy import special
import tvm
import tvm.relay as relay
import topi
Expand Down Expand Up @@ -52,6 +54,7 @@ def test_apply(relay_op, name, f_numpy, low, high, step, dtype="float32"):
rtol=1e-5, atol=1e-5)

test_apply(relay.exp, "fast_exp", np.exp, low=-88, high=88, step=0.01)
test_apply(relay.erf, "fast_erf", scipy.special.erf, low=-10, high=10, step=0.01)
test_apply(relay.tanh, "fast_tanh", np.tanh, low=-10, high=10, step=0.01)


Expand Down
73 changes: 72 additions & 1 deletion topi/include/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <topi/tags.h>
#include <algorithm>
#include <string>
#include "broadcast.h"

Expand Down Expand Up @@ -63,7 +64,7 @@ TOPI_DECLARE_UNARY_OP(tanh);
TOPI_DECLARE_UNARY_OP(isfinite);
TOPI_DECLARE_UNARY_OP(isinf);

/*
/*!
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/Eigen/src/Core/MathFunctionsImpl.h#L26
*/
Expand Down Expand Up @@ -461,5 +462,75 @@ inline Tensor fast_exp(const Tensor& x,
}
}

/*!
* \brief Fast_tanh_float implementation from Eigen
* https://github.com/eigenteam/eigen-git-mirror/blob/master/unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h#L290
*/
inline Tensor fast_erf_float32(const Tensor& data,
std::string name,
std::string tag) {
auto plus_4 = make_const(DataType::Float(32), 4.f);
auto minus_4 = make_const(DataType::Float(32), -4.f);

// The monomial coefficients of the numerator polynomial (odd).
auto alpha_1 = make_const(DataType::Float(32), -1.60960333262415e-02f);
auto alpha_3 = make_const(DataType::Float(32), -2.95459980854025e-03f);
auto alpha_5 = make_const(DataType::Float(32), -7.34990630326855e-04f);
auto alpha_7 = make_const(DataType::Float(32), -5.69250639462346e-05f);
auto alpha_9 = make_const(DataType::Float(32), -2.10102402082508e-06f);
auto alpha_11 = make_const(DataType::Float(32), 2.77068142495902e-08f);
auto alpha_13 = make_const(DataType::Float(32), -2.72614225801306e-10f);

// The monomial coefficients of the denominator polynomial (even).
auto beta_0 = make_const(DataType::Float(32), -1.42647390514189e-02f);
auto beta_2 = make_const(DataType::Float(32), -7.37332916720468e-03f);
auto beta_4 = make_const(DataType::Float(32), -1.68282697438203e-03f);
auto beta_6 = make_const(DataType::Float(32), -2.13374055278905e-04f);
auto beta_8 = make_const(DataType::Float(32), -1.45660718464996e-05f);

return compute(data->shape, [&](const Array<Var> &i) {
// clamp x
auto x = tvm::max(tvm::min(data(i), plus_4), minus_4);
auto x2 = x * x;

// Evaluate the numerator polynomial p.
auto p = x2 * alpha_13 + alpha_11;
p = x2 * p + alpha_9;
p = x2 * p + alpha_7;
p = x2 * p + alpha_5;
p = x2 * p + alpha_3;
p = x2 * p + alpha_1;
p = x * p;

// Evaluate the denominator polynomial p.
auto q = x2 * beta_8 + beta_6;
q = x2 * q + beta_4;
q = x2 * q + beta_2;
q = x2 * q + beta_0;

return p / q;
}, name, tag);
}

/*!
* \brief Fast erf implementation
*
* \param x The input tensor
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is erf operation
*/
inline Tensor fast_erf(const Tensor& x,
std::string name = "T_fast_erf",
std::string tag = kElementWise) {
if (x->dtype == DataType::Float(32)) {
auto ret = fast_erf_float32(x, name, tag);
return ret;
} else {
return topi::erf(x);
}
}

} // namespace topi
#endif // TOPI_ELEMWISE_H_
16 changes: 16 additions & 0 deletions topi/python/topi/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,19 @@ def fast_tanh(x):
The result.
"""
return cpp.fast_tanh(x, x.dtype, tag.ELEMWISE)


def fast_erf(x):
"""Take gauss error function of input x using fast_erf implementation.
Parameters
----------
x : tvm.te.Tensor
Input argument.
Returns
-------
y : tvm.te.Tensor
The result.
"""
return cpp.fast_erf(x, x.dtype, tag.ELEMWISE)
5 changes: 5 additions & 0 deletions topi/src/elemwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ TVM_REGISTER_GLOBAL("topi.erf")
*rv = erf(args[0]);
});

TVM_REGISTER_GLOBAL("topi.fast_erf")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = fast_erf(args[0]);
});

TVM_REGISTER_GLOBAL("topi.tan")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = tan(args[0]);
Expand Down
9 changes: 5 additions & 4 deletions topi/tests/python/test_topi_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import numpy as np
import scipy
from scipy import special
import tvm
from tvm import te
import topi
Expand Down Expand Up @@ -238,11 +239,11 @@ def check_device(device):


test_apply(topi.fast_exp, "fast_exp", np.exp,
low=-88, high=88,
step = 0.01)
low=-88, high=88, step=0.01)
test_apply(topi.fast_erf, "fast_erf", scipy.special.erf,
low=-10, high=10, step=0.01)
test_apply(topi.fast_tanh, "fast_tanh", np.tanh,
low=-10, high=10,
step = 0.01)
low=-10, high=10, step=0.01)

if __name__ == "__main__":
test_util()
Expand Down

0 comments on commit 3b1d0b6

Please sign in to comment.