Skip to content

Commit

Permalink
[TVM][LANG] Add eager simplification for operations with FloatImm (#2615
Browse files Browse the repository at this point in the history
)

* Add eager simplication for FloatImm

* fix

* fix lint

* Fix gcc warning

* fix

* Add test case
  • Loading branch information
icemelon authored Feb 19, 2019
1 parent 255c187 commit c59a78e
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 35 deletions.
32 changes: 28 additions & 4 deletions include/tvm/ir_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis);
*/
TVM_DLL Expr prod(Expr source, Array<IterVar> axis);

/*!
* \brief Calculate floor(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr floor(Expr x);

/*!
* \brief Calculate ceil(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr ceil(Expr x);

/*!
* \brief Calculate round(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr round(Expr x);

/*!
* \brief Calculate trunc(x)
* \param x The input expression.
* \return The result expression.
*/
TVM_DLL Expr trunc(Expr x);

// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
Expand All @@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh);
TVM_DECLARE_INTRIN_UNARY(sigmoid);
TVM_DECLARE_INTRIN_UNARY(sqrt);
TVM_DECLARE_INTRIN_UNARY(log);
TVM_DECLARE_INTRIN_UNARY(floor);
TVM_DECLARE_INTRIN_UNARY(ceil);
TVM_DECLARE_INTRIN_UNARY(round);
TVM_DECLARE_INTRIN_UNARY(trunc);
TVM_DECLARE_INTRIN_UNARY(popcount);


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,4 +94,4 @@ def cast(src, dtype):
op : tvm.Expr
The result Expr of divide operaton.
"""
return _make.static_cast(dtype, src)
return _make._cast(dtype, src)
8 changes: 4 additions & 4 deletions python/tvm/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def floor(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "floor", x)
return _make.floor(x)


def ceil(x):
Expand All @@ -288,7 +288,7 @@ def ceil(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "ceil", x)
return _make.ceil(x)


def trunc(x):
Expand All @@ -307,7 +307,7 @@ def trunc(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "trunc", x)
return _make.trunc(x)


def abs(x):
Expand Down Expand Up @@ -339,7 +339,7 @@ def round(x):
y : Expr
The result.
"""
return call_pure_intrin(x.dtype, "round", x)
return _make.round(x)


def power(x, y):
Expand Down
25 changes: 25 additions & 0 deletions src/api/api_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs")
*ret = tvm::abs(args[0]);
});

TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::floor(args[0]);
});

TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::ceil(args[0]);
});

TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::round(args[0]);
});

TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::trunc(args[0]);
});

TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::cast(args[0], args[1]);
});

TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
Expand Down
Loading

0 comments on commit c59a78e

Please sign in to comment.