diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index efbe37ef4d68..a576a28c6828 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -104,6 +104,16 @@ inline const uint64_t* as_const_uint(const Expr& x) { */ inline bool is_const_int(const Expr& x, int64_t value); +/*! + * \brief Check if the given expr is a const of any type equal to the given integer value. + * \param e The expression. + * \param value The value to compare to. + * \return Whether the expression is a const equal to the value. + * \tparam ValueType The value type + */ +template +inline bool is_const_value(const Expr& e, ValueType value); + /*! * \brief Check whether stmt is nop. * \param stmt The input statement @@ -551,18 +561,31 @@ inline bool is_negative_const(const Expr& a) { } } +template +inline bool is_const_value(const Expr& e, ValueType value) { + static_assert(std::is_integral::value, + "Comparison to non-integer values is forbidden."); + // This implementation was copy-pasted from HalideIR + if (const ir::IntImm* i = e.as()) { + return i->value == value; + } else if (const ir::UIntImm* i = e.as()) { + return (value >= 0) && (i->value == static_cast(value)); + } else if (const ir::FloatImm* i = e.as()) { + return i->value == value; + } else if (const ir::Cast* c = e.as()) { + return is_const_value(c->value, value); + } else if (const ir::Broadcast* b = e.as()) { + return is_const_value(b->value, value); + } else { + return false; + } +} + inline bool is_const_int(const Expr& x, int64_t value) { - if (const auto* op = x.as()) { - return op->value == value; - } else if (const auto* op = x.as()) { - return op->value == static_cast(value); + if (x.as() || x.as()) { + return is_const_value(x, value); } else if (const auto* op = x.as()) { - const Expr& val = op->value; - if (const auto* opv = val.as()) { - return opv->value == value; - } else if (const auto* opv = val.as()) { - return opv->value == static_cast(value); - } + return !op->value.as() && is_const_int(op->value, value); } return false; }