Skip to content

Commit

Permalink
[ARITH] Explicitly state truncdiv/mod in pattern matching. (apache#3986)
Browse files Browse the repository at this point in the history
* [ARITH] Explicitly state truncdiv/mod in pattern matching.

* Fix the dependent cpp test
  • Loading branch information
tqchen authored and wweic committed Oct 1, 2019
1 parent c753860 commit aead397
Show file tree
Hide file tree
Showing 9 changed files with 350 additions and 243 deletions.
34 changes: 34 additions & 0 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,20 @@ TVM_DLL Expr operator||(Expr a, Expr b);
* \note This operator does eager constant folding.
*/
TVM_DLL Expr operator!(Expr a);
/*!
* \brief compute division in C semantics.
*
* a / b as in C/C++.
*
* When operands are integers, it directly corresponds to truncdiv.
*
* \param a left operand
* \param b right operand
* \return The result expression.
* \note this function does eager constant folding for
* index types(int32, int64) when possible.
*/
TVM_DLL Expr div(Expr a, Expr b);
/*!
* \brief compute trunc(a / b)
*
Expand Down Expand Up @@ -640,6 +654,21 @@ inline Expr make_zero(Type t) {
return make_const(t, 0);
}

/*!
* \brief Helper function to raise a compiler error about division ambiguity.
* \note The call to this function will always results in a compiler error.
* \tparam TA Any class type.
*/
template<typename TA>
inline void DivAmbiguityError(const TA& a) {
constexpr bool div_ambiguity = !std::is_class<TA>::value;
static_assert(div_ambiguity,
"TVM supports multiple types of integer divisions, "
"please call div, floordiv/floormod or truncdiv/truncmod directly "
"to avoid ambiguity in the code. "
"Checkout these functions in expr_operator.h.");
}

// additional const expression overloading
#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \
inline Expr Name(Expr& a, Expr b) { \
Expand Down Expand Up @@ -688,12 +717,17 @@ TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(div);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=);
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*)
TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=);
// integer related ops
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncmod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floordiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(floormod);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(truncdiv);
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*)
TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&);
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ enum DivMode {

inline Expr ModImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a % b;
return truncmod(a, b);
} else {
CHECK_EQ(mode, kFloorDiv);
return floormod(a, b);
Expand All @@ -76,7 +76,7 @@ inline Expr ModImpl(Expr a, Expr b, DivMode mode) {

inline Expr DivImpl(Expr a, Expr b, DivMode mode) {
if (mode == kTruncDiv) {
return a / b;
return truncdiv(a, b);
} else {
CHECK_EQ(mode, kFloorDiv);
return floordiv(a, b);
Expand Down
20 changes: 20 additions & 0 deletions src/arithmetic/int_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,26 @@ inline bool WillOverflow<ir::Mod>(int64_t x,
return y == 0;
}

/*!
* \brief Peform trunc division of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncdiv(int64_t x, int64_t y) {
return x / y;
}

/*!
* \brief Compute the truncdiv remainder of two integers.
* \param x The left operand.
* \param y The right operand.
* \return the result.
*/
inline int64_t truncmod(int64_t x, int64_t y) {
return x % y;
}

/*!
* \brief Peform floor division of two integers.
* \param x The left operand.
Expand Down
4 changes: 2 additions & 2 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

/*!
* Copyright (c) 2019 by Contributors
* \file modular_set.cc
* \brief Modular set analysis
*/
Expand Down Expand Up @@ -111,7 +110,8 @@ class ModularSetAnalyzer::Impl :
PVar<Var> var;
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
if ((truncmod(var, coeff) == base).Match(constraint) ||
(floormod(var, coeff) == base).Match(constraint)) {
Entry entry(coeff.Eval()->value, base.Eval()->value);
return UpdateByIntersect(var.Eval(), entry);
}
Expand Down
16 changes: 13 additions & 3 deletions src/arithmetic/pattern_match.h
Original file line number Diff line number Diff line change
Expand Up @@ -300,31 +300,41 @@ class PConstWithTypeLike :
};


#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep) \
template<typename TA, typename TB> \
inline PBinaryExpr<NodeName, TA, TB> \
FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
CheckStep; \
return PBinaryExpr<NodeName, TA, TB>(a.derived(), b.derived()); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA> > \
FuncName(const Pattern<TA>& a, int64_t b) { \
CheckStep; \
return FuncName(a, PConstWithTypeLike<TA>(a.derived(), b)); \
} \
template<typename TA> \
inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> \
FuncName(int64_t b, const Pattern<TA>& a) { \
CheckStep; \
return FuncName(PConstWithTypeLike<TA>(a.derived(), b), a); \
}

#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) \
TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )


// raise ambiguity error for operator overload of / and %
TVM_PATTERN_BINARY_OP_EX(operator/, ir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, ir::Mod, DivAmbiguityError(a));

// arithmetic expressions
TVM_PATTERN_BINARY_OP(operator+, ir::Add);
TVM_PATTERN_BINARY_OP(operator-, ir::Sub);
TVM_PATTERN_BINARY_OP(operator*, ir::Mul);
TVM_PATTERN_BINARY_OP(operator/, ir::Div);
TVM_PATTERN_BINARY_OP(operator%, ir::Mod);
TVM_PATTERN_BINARY_OP(min, ir::Min);
TVM_PATTERN_BINARY_OP(max, ir::Max);
TVM_PATTERN_BINARY_OP(div, ir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, ir::Div);
TVM_PATTERN_BINARY_OP(truncmod, ir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, ir::FloorDiv);
Expand Down
Loading

0 comments on commit aead397

Please sign in to comment.