Skip to content

Commit

Permalink
[EXPR] is_const_value to check whether non-ints are consts
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrechanik-h committed Aug 12, 2019
1 parent 4f12046 commit 7e1b346
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,16 @@ inline const uint64_t* as_const_uint(const Expr& x) {
*/
inline bool is_const_int(const Expr& x, int64_t value);

/*!
* \brief Check if the given expr is a const of any type equal to the given integer value.
* \param e The expression.
* \param value The value to compare to.
* \return Whether the expression is a const equal to the value.
* \tparam ValueType The value type
*/
template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value);

/*!
* \brief Check whether stmt is nop.
* \param stmt The input statement
Expand Down Expand Up @@ -551,18 +561,31 @@ inline bool is_negative_const(const Expr& a) {
}
}

template <typename ValueType>
inline bool is_const_value(const Expr& e, ValueType value) {
static_assert(std::is_integral<ValueType>::value,
"Comparison to non-integer values is forbidden.");
// This implementation was copy-pasted from HalideIR
if (const ir::IntImm* i = e.as<ir::IntImm>()) {
return i->value == value;
} else if (const ir::UIntImm* i = e.as<ir::UIntImm>()) {
return (value >= 0) && (i->value == static_cast<uint64_t>(value));
} else if (const ir::FloatImm* i = e.as<ir::FloatImm>()) {
return i->value == value;
} else if (const ir::Cast* c = e.as<ir::Cast>()) {
return is_const_value(c->value, value);
} else if (const ir::Broadcast* b = e.as<ir::Broadcast>()) {
return is_const_value(b->value, value);
} else {
return false;
}
}

inline bool is_const_int(const Expr& x, int64_t value) {
if (const auto* op = x.as<ir::IntImm>()) {
return op->value == value;
} else if (const auto* op = x.as<ir::UIntImm>()) {
return op->value == static_cast<uint64_t>(value);
if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) {
return is_const_value(x, value);
} else if (const auto* op = x.as<ir::Broadcast>()) {
const Expr& val = op->value;
if (const auto* opv = val.as<ir::IntImm>()) {
return opv->value == value;
} else if (const auto* opv = val.as<ir::UIntImm>()) {
return opv->value == static_cast<uint64_t>(value);
}
return !op->value.as<ir::Broadcast>() && is_const_int(op->value, value);
}
return false;
}
Expand Down

0 comments on commit 7e1b346

Please sign in to comment.