Skip to content

Commit

Permalink
[TIR] Improve Let/LetStmt support.
Browse files Browse the repository at this point in the history
Let/LetStmt are useful primitives to create variable bindings.
While let binding are harmful for simplification and integer analysis,
they are useful for other cases:

- C0: LetStmt is useful to represent a step that has side effect(e.g. call a PRNG)
- C1: Let expression can be used to create deep nested expression for complicated functions.

This PR improves the let support in the following ways:
- Enable vectorization support for let
- Change let simplification strategy to simplify the most trivial case
  while ignore more complicated cases(to avoid deep nest explosion)
- Enhance arith module to handle const bound and modular set for let.

The overall recommendation is to only use Let in the cases when necessary(C0, C1).
  • Loading branch information
tqchen committed Jun 28, 2020
1 parent e99e116 commit 9dadf68
Show file tree
Hide file tree
Showing 22 changed files with 341 additions and 135 deletions.
32 changes: 16 additions & 16 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,17 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false);
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
* \param allow_override Whether we allow overriding an existing var's range.
*/
TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -217,9 +217,9 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false);
TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -256,9 +256,9 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

std::function<void()> EnterConstraint(const PrimExpr& constraint);

Expand Down Expand Up @@ -290,9 +290,9 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -404,9 +404,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const PrimExpr& expr, bool override = false);
void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
Expand All @@ -415,16 +415,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const Range& range, bool override = false);
void Bind(const Var& var, const Range& range, bool allow_override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Map<Var, Range>& variables, bool override = false);
void Bind(const Map<Var, Range>& variables, bool allow_override = false);
/*!
* \brief Whether can we prove expr >= val.
Expand Down
26 changes: 22 additions & 4 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,18 @@ inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }

/*!
* \brief Check whether x is a constant.
* \brief Check whether x is an integer constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const(const PrimExpr& x);
inline bool is_const_int(const PrimExpr& x);

/*!
* \brief Check whether x is an integer/float constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const_number(const PrimExpr& x);

/*!
* \brief Left fold.
Expand All @@ -699,7 +706,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);

// Implementation details after this
inline bool is_const(const PrimExpr& x) {
inline bool is_const_int(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<tir::BroadcastNode>()) {
Expand All @@ -711,6 +718,17 @@ inline bool is_const(const PrimExpr& x) {
return false;
}

inline bool is_const_number(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
return true;
} else if (x.as<tir::FloatImmNode>()) {
return true;
} else if (const auto* op = x.as<tir::BroadcastNode>()) {
return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
}
return false;
}

inline bool is_positive_const(const PrimExpr& a) {
if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
return op->value > 0;
Expand Down Expand Up @@ -742,7 +760,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) {
inline bool is_no_op(const tir::Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<tir::EvaluateNode>()) {
return is_const(op->value);
return is_const_int(op->value);
}
if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
return op->seq.size() == 0;
Expand Down
24 changes: 12 additions & 12 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,31 @@ Analyzer::Analyzer()
canonical_simplify(this),
int_set(this) {}

void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);

this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
this->modular_set.Update(var, this->modular_set(new_expr), override);
this->rewrite_simplify.Update(var, new_expr, override);
this->canonical_simplify.Update(var, new_expr, override);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override);
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool override) {
void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min, override);
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, override);
this->const_int_bound.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
}

void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second, override);
this->Bind(iter.first, iter.second, allow_override);
}
}

Expand Down Expand Up @@ -116,9 +116,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
}

PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
if (tir::is_const(expr)) return expr;
if (tir::is_const_int(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (tir::is_const(res)) return res;
if (tir::is_const_int(res)) return res;
res = this->canonical_simplify(res);
return res;
}
Expand Down
33 changes: 23 additions & 10 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ class ConstIntBoundAnalyzer::Impl
BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
};

void Bind(const Var& var, const Range& range, bool override) {
void Bind(const Var& var, const Range& range, bool allow_override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, override);
Update(var, ret, allow_override);
}

void Update(const Var& var, const Entry& info, bool override) {
if (!override) {
void Update(const Var& var, const Entry& info, bool allow_override) {
if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
Expand All @@ -119,8 +119,21 @@ class ConstIntBoundAnalyzer::Impl
var_map_[var] = info;
}

void Update(const Var& var, const ConstIntBound& info, bool override) {
Update(var, MakeBound(info->min_value, info->max_value), override);
Entry VisitExpr_(const LetNode* op) final {
auto it = var_map_.find(op->var);
// if the var has not been binded, update the info.
if (it == var_map_.end()) {
var_map_[op->var] = this->VisitExpr(op->value);
Entry ret = VisitExpr(op->body);
var_map_.erase(op->var);
return ret;
} else {
return VisitExpr(op->body);
}
}

void Update(const Var& var, const ConstIntBound& info, bool allow_override) {
Update(var, MakeBound(info->min_value, info->max_value), allow_override);
}

// Override visitor behaviors
Expand Down Expand Up @@ -558,12 +571,12 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapTy
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) {
impl_->Update(var, info, override);
void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
impl_->Bind(var, range, override);
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
impl_->Bind(var, range, allow_override);
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
Expand Down
21 changes: 17 additions & 4 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
public:
explicit Impl(Analyzer* parent) : parent_(parent) {}

void Update(const Var& var, const ModularSet& info, bool override) {
if (!override) {
void Update(const Var& var, const ModularSet& info, bool allow_override) {
if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
Expand Down Expand Up @@ -118,6 +118,19 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
// Override visitor behaviors
Entry VisitExprDefault_(const Object* op) final { return Everything(); }

Entry VisitExpr_(const LetNode* op) final {
auto it = var_map_.find(op->var);
// if the var has not been binded, update the info.
if (it == var_map_.end()) {
var_map_[op->var] = this->VisitExpr(op->value);
Entry ret = VisitExpr(op->body);
var_map_.erase(op->var);
return ret;
} else {
return VisitExpr(op->body);
}
}

Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }

Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
Expand Down Expand Up @@ -315,8 +328,8 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
return ModularSet(ret.coeff, ret.base);
}

void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) {
impl_->Update(var, info, override);
void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
Expand Down
16 changes: 12 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
op = ret.as<CallNode>();
if (op == nullptr) return ret;

if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) {
if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
return op->args[0];
} else if (op->op.same_as(tir::builtin::shift_right())) {
if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
Expand Down Expand Up @@ -1559,9 +1559,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) {
return cast(op->dtype, op->value);
}

bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) {
// Only inline trivial bindings to avoid deep expression explosion
// when we need let to construct complicated expressions.
if (is_const_number(op->value)) return true;
if (op->value.as<VarNode>()) return true;
return false;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
if (CanInlineLet(op)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
Expand All @@ -1587,8 +1595,8 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) {
return res;
}

void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) {
impl_->Update(var, info, override);
void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
Expand Down
7 changes: 7 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
CompareResult TryCompare(const PrimExpr& x, int64_t val);

/*!
* \brief Internal function to check whether or not to inline let.
* \param op The let expr.
* \return The inline decision.
*/
bool CanInlineLet(const LetNode* op);

private:
// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) {

bool is_noop(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value);
if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value);
return false;
}

Expand Down Expand Up @@ -409,7 +409,7 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
}

void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
if (is_const_int(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty()) stream << str << "\n";
}
Expand Down
Loading

0 comments on commit 9dadf68

Please sign in to comment.