From c59a78e5d259c00837fba39aaced6ceed98311b2 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 18 Feb 2019 22:38:46 -0800 Subject: [PATCH] [TVM][LANG] Add eager simplification for operations with FloatImm (#2615) * Add eager simplication for FloatImm * fix * fix lint * Fix gcc warning * fix * Add test case --- include/tvm/ir_operator.h | 32 +++- python/tvm/generic.py | 2 +- python/tvm/intrin.py | 8 +- src/api/api_ir.cc | 25 ++++ src/lang/ir_operator.cc | 150 +++++++++++++++---- tests/python/unittest/test_arith_simplify.py | 17 +++ 6 files changed, 199 insertions(+), 35 deletions(-) diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index af5b23ed6552..c2cdc5e7a923 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -430,6 +430,34 @@ TVM_DLL Expr min(Expr source, Array axis); */ TVM_DLL Expr prod(Expr source, Array axis); +/*! + * \brief Calculate floor(x) + * \param x The input expression. + * \return The result expression. + */ +TVM_DLL Expr floor(Expr x); + +/*! + * \brief Calculate ceil(x) + * \param x The input expression. + * \return The result expression. + */ +TVM_DLL Expr ceil(Expr x); + +/*! + * \brief Calculate round(x) + * \param x The input expression. + * \return The result expression. + */ +TVM_DLL Expr round(Expr x); + +/*! + * \brief Calculate trunc(x) + * \param x The input expression. + * \return The result expression. + */ +TVM_DLL Expr trunc(Expr x); + // Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ @@ -441,10 +469,6 @@ TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); TVM_DECLARE_INTRIN_UNARY(sqrt); TVM_DECLARE_INTRIN_UNARY(log); -TVM_DECLARE_INTRIN_UNARY(floor); -TVM_DECLARE_INTRIN_UNARY(ceil); -TVM_DECLARE_INTRIN_UNARY(round); -TVM_DECLARE_INTRIN_UNARY(trunc); TVM_DECLARE_INTRIN_UNARY(popcount); diff --git a/python/tvm/generic.py b/python/tvm/generic.py index ab1a80d3f612..fb96ff0131ba 100644 --- a/python/tvm/generic.py +++ b/python/tvm/generic.py @@ -94,4 +94,4 @@ def cast(src, dtype): op : tvm.Expr The result Expr of divide operaton. """ - return _make.static_cast(dtype, src) + return _make._cast(dtype, src) diff --git a/python/tvm/intrin.py b/python/tvm/intrin.py index 320f838cd975..bb15c314ff23 100644 --- a/python/tvm/intrin.py +++ b/python/tvm/intrin.py @@ -272,7 +272,7 @@ def floor(x): y : Expr The result. """ - return call_pure_intrin(x.dtype, "floor", x) + return _make.floor(x) def ceil(x): @@ -288,7 +288,7 @@ def ceil(x): y : Expr The result. """ - return call_pure_intrin(x.dtype, "ceil", x) + return _make.ceil(x) def trunc(x): @@ -307,7 +307,7 @@ def trunc(x): y : Expr The result. """ - return call_pure_intrin(x.dtype, "trunc", x) + return _make.trunc(x) def abs(x): @@ -339,7 +339,7 @@ def round(x): y : Expr The result. """ - return call_pure_intrin(x.dtype, "round", x) + return _make.round(x) def power(x, y): diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 1040f6ce6f66..fa2d52e9fe85 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -22,6 +22,31 @@ TVM_REGISTER_API("make.abs") *ret = tvm::abs(args[0]); }); +TVM_REGISTER_API("make.floor") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::floor(args[0]); + }); + +TVM_REGISTER_API("make.ceil") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::ceil(args[0]); + }); + +TVM_REGISTER_API("make.round") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::round(args[0]); + }); + +TVM_REGISTER_API("make.trunc") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::trunc(args[0]); + }); + +TVM_REGISTER_API("make._cast") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = tvm::cast(args[0], args[1]); + }); + TVM_REGISTER_API("make._range_by_min_extent") .set_body([](TVMArgs args, TVMRetValue *ret) { *ret = Range::make_by_min_extent(args[0], args[1]); diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index 27053f43d81f..beceb094c620 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -5,6 +5,7 @@ #include #include #include +#include namespace tvm { @@ -49,17 +50,17 @@ void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) // and also help user to find potential type conversion problems. if (!lhs.type().is_float() && rhs.type().is_float()) { // int->float - lhs = ir::Cast::make(rhs.type(), lhs); + lhs = cast(rhs.type(), lhs); } else if (lhs.type().is_float() && !rhs.type().is_float()) { // int->float - rhs = ir::Cast::make(lhs.type(), rhs); + rhs = cast(lhs.type(), rhs); } else if ((lhs.type().is_int() && rhs.type().is_int()) || (lhs.type().is_uint() && rhs.type().is_uint())) { // promote int to higher bits if (lhs.type().bits() < rhs.type().bits()) { - lhs = ir::Cast::make(rhs.type(), lhs); + lhs = cast(rhs.type(), lhs); } else { - rhs = ir::Cast::make(lhs.type(), rhs); + rhs = cast(lhs.type(), rhs); } } else if ((lhs.type().is_int() && rhs.type().is_uint()) || (lhs.type().is_uint() && rhs.type().is_int())) { @@ -98,11 +99,14 @@ bool is_const_power_of_two_integer(const Expr& x, int* shift) { Expr cast(const Type& t, Expr value) { using ir::IntImm; + using ir::FloatImm; if (value.type() == t) return value; // const fold IntImm as they are used in index computations if (t.lanes() == 1) { if (const IntImm* op = value.as()) { return make_const(t, op->value); + } else if (const FloatImm* op = value.as()) { + return make_const(t, op->value); } return ir::Cast::make(t, value); } else { @@ -112,6 +116,8 @@ Expr cast(const Type& t, Expr value) { if (value.type() != vtype) { if (const IntImm* op = value.as()) { value = make_const(vtype, op->value); + } else if (const FloatImm* op = value.as()) { + value = make_const(vtype, op->value); } else { value = ir::Cast::make(vtype, value); } @@ -129,7 +135,7 @@ Expr reinterpret(const Type& t, Expr value) { return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); } -#define TVM_CONST_PROPAGATION(BODY) \ +#define TVM_INDEX_CONST_PROPAGATION(BODY) \ using ir::IntImm; \ using ir::UIntImm; \ const IntImm* pa = a.as(); \ @@ -141,37 +147,60 @@ Expr reinterpret(const Type& t, Expr value) { } \ BinaryOpMatchTypes(a, b); +#define TVM_ARITH_CONST_PROPAGATION(BODY) \ + using ir::IntImm; \ + using ir::UIntImm; \ + using ir::FloatImm; \ + BinaryOpMatchTypes(a, b); \ + const IntImm* pa = a.as(); \ + const IntImm* pb = b.as(); \ + const FloatImm* fa = a.as(); \ + const FloatImm* fb = b.as(); \ + BODY; + Expr operator+(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); if (pa && pa->value == 0) return SimpleCast(rtype, b); if (pb && pb->value == 0) return SimpleCast(rtype, a); + if (fa && fb) return FloatImm::make(rtype, fa->value + fb->value); + if (fa && fa->value == 0) return SimpleCast(rtype, b); + if (fb && fb->value == 0) return SimpleCast(rtype, a); }); return ir::Add::make(a, b); } Expr operator-(Expr a) { using ir::IntImm; + using ir::FloatImm; const IntImm* pa = a.as(); - if (pa) { - return ir::IntImm::make(a.type(), -pa->value); - } + const FloatImm* fa = a.as(); + if (pa) return ir::IntImm::make(a.type(), -pa->value); + if (fa) return ir::FloatImm::make(a.type(), -fa->value); return make_zero(a.type()) - a; } Expr operator-(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); if (pb && pb->value == 0) return SimpleCast(rtype, a); + if (fa && fb) return FloatImm::make(rtype, fa->value - fb->value); + if (fb && fb->value == 0) return SimpleCast(rtype, a); }); return ir::Sub::make(a, b); } Expr operator*(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); if (pa) { @@ -182,12 +211,23 @@ Expr operator*(Expr a, Expr b) { if (pb->value == 1) return SimpleCast(rtype, a); if (pb->value == 0) return SimpleCast(rtype, b); } + if (fa && fb) return FloatImm::make(rtype, fa->value * fb->value); + if (fa) { + if (fa->value == 1) return SimpleCast(rtype, b); + if (fa->value == 0) return SimpleCast(rtype, a); + } + if (fb) { + if (fb->value == 1) return SimpleCast(rtype, a); + if (fb->value == 0) return SimpleCast(rtype, b); + } }); return ir::Mul::make(a, b); } Expr operator/(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; // due to division and mod can have different modes // only constant fold positive number where rule is fixed. @@ -201,12 +241,22 @@ Expr operator/(Expr a, Expr b) { if (pb->value == 1) return SimpleCast(rtype, a); CHECK_NE(pb->value, 0) << "Divide by zero"; } + if (fa && fb && fb->value != 0) { + return FloatImm::make(rtype, fa->value / fb->value); + } + if (fa && fa->value == 0) { + return SimpleCast(rtype, a); + } + if (fb) { + if (fb->value == 1) return SimpleCast(rtype, a); + CHECK_NE(fb->value, 0) << "Divide by zero"; + } }); return ir::Div::make(a, b); } Expr operator%(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; // due to division and mod can have different modes // only constant fold positive number where rule is fixed. @@ -225,17 +275,23 @@ Expr operator%(Expr a, Expr b) { } Expr min(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); + if (fa && fb) return FloatImm::make(rtype, std::min(fa->value, fb->value)); }); return ir::Min::make(a, b); } Expr max(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ + const Type& ta = a.type(); + const Type& tb = b.type(); Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); + if (fa && fb) return FloatImm::make(rtype, std::max(fa->value, fb->value)); }); return ir::Max::make(a, b); } @@ -272,43 +328,49 @@ Expr likely(Expr cond) { } Expr operator>(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value > fb->value); }); return ir::GT::make(a, b); } Expr operator>=(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value >= fb->value); }); return ir::GE::make(a, b); } Expr operator<(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value < fb->value); }); return ir::LT::make(a, b); } Expr operator<=(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value <= fb->value); }); return ir::LE::make(a, b); } Expr operator==(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value == fb->value); }); return ir::EQ::make(a, b); } Expr operator!=(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_ARITH_CONST_PROPAGATION({ if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); + if (fa && fb) return UIntImm::make(UInt(1), fa->value != fb->value); }); return ir::NE::make(a, b); } @@ -349,7 +411,7 @@ Expr operator!(Expr a) { } Expr operator>>(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); if (pb) { @@ -360,7 +422,7 @@ Expr operator>>(Expr a, Expr b) { } Expr operator<<(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); if (pb) { @@ -371,7 +433,7 @@ Expr operator<<(Expr a, Expr b) { } Expr operator&(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); }); @@ -379,7 +441,7 @@ Expr operator&(Expr a, Expr b) { } Expr operator|(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); }); @@ -387,7 +449,7 @@ Expr operator|(Expr a, Expr b) { } Expr operator^(Expr a, Expr b) { - TVM_CONST_PROPAGATION({ + TVM_INDEX_CONST_PROPAGATION({ Type rtype = ta.bits() >= tb.bits() ? ta : tb; if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); }); @@ -414,6 +476,11 @@ Expr abs(Expr x) { } return ir::Select::make(x >= make_zero(x.type()), x, -x); } else if (x.type().is_float()) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) { + return ir::FloatImm::make(x.type(), std::fabs(fx->value)); + } return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); } else if (x.type().is_uint()) { return x; @@ -466,4 +533,35 @@ Expr fmod(Expr x, Expr y) { return ir::Call::make(x.type(), "fmod", { x, y }, ir::Call::PureIntrinsic); } +Expr floor(Expr x) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) return FloatImm::make(x.type(), std::floor(fx->value)); + return ir::Call::make(x.type(), "floor", {x}, ir::Call::PureIntrinsic); +} + +Expr ceil(Expr x) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) return FloatImm::make(x.type(), std::ceil(fx->value)); + return ir::Call::make(x.type(), "ceil", {x}, ir::Call::PureIntrinsic); +} + +Expr round(Expr x) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) return FloatImm::make(x.type(), std::nearbyint(fx->value)); + return ir::Call::make(x.type(), "round", {x}, ir::Call::PureIntrinsic); +} + +Expr trunc(Expr x) { + using ir::FloatImm; + const FloatImm* fx = x.as(); + if (fx) { + return FloatImm::make(x.type(), (fx->value < 0 ? std::ceil(fx->value) : + std::floor(fx->value))); + } + return ir::Call::make(x.type(), "trunc", {x}, ir::Call::PureIntrinsic); +} + } // namespace tvm diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index 4bd482a1e5db..71818708fbf6 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -100,6 +100,22 @@ def test_modular(): assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 +def test_const_propagation(): + x1 = tvm.const(4, "int32") + x2 = x1 + 5 + assert isinstance(x2, tvm.expr.IntImm) and x2.value == 9 + x3 = x2 / 3 + assert isinstance(x3, tvm.expr.IntImm) and x3.value == 3 + x4 = x3 + 0.5 + assert isinstance(x4, tvm.expr.FloatImm) and x4.value == 3.5 + x5 = tvm.ceil(x4) + assert isinstance(x5, tvm.expr.FloatImm) and x5.value == 4 + x6 = x5.astype('int') + assert isinstance(x6, tvm.expr.IntImm) and x6.value == 4 + y = (tvm.round((tvm.const(6.5, 'float32') - 1) / 1.5) + 2).astype('int') + assert isinstance(y, tvm.expr.IntImm) and y.value == 6 + + if __name__ == "__main__": test_simplify_div() test_simplify_mod() @@ -107,3 +123,4 @@ def test_modular(): test_simplify() test_mul() test_simplify_minmax() + test_const_propagation()