Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Improve Let/LetStmt support. #5949

Merged
merged 1 commit into from
Jun 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,17 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool override = false);
TVM_DLL void Update(const Var& var, const ConstIntBound& info, bool allow_override = false);
/*!
* \brief Bind variable to a range.
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
* \param allow_override Whether we allow overriding an existing var's range.
*/
TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);
TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -217,9 +217,9 @@ class ModularSetAnalyzer {
*
* \param var The variable of interest.
* \param info The bound information.
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const ModularSet& info, bool override = false);
TVM_DLL void Update(const Var& var, const ModularSet& info, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -256,9 +256,9 @@ class RewriteSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

std::function<void()> EnterConstraint(const PrimExpr& constraint);

Expand Down Expand Up @@ -290,9 +290,9 @@ class CanonicalSimplifier {
*
* \param var The variable of interest.
* \param new_expr
* \param override Whether do we allow override of existing information.
* \param allow_override Whether do we allow override of existing information.
*/
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool override = false);
TVM_DLL void Update(const Var& var, const PrimExpr& new_expr, bool allow_override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -404,9 +404,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const PrimExpr& expr, bool override = false);
void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
Expand All @@ -415,16 +415,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const Range& range, bool override = false);
void Bind(const Var& var, const Range& range, bool allow_override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
* \param allow_override Whether we allow overriding an existing var's expression.
*/
void Bind(const Map<Var, Range>& variables, bool override = false);
void Bind(const Map<Var, Range>& variables, bool allow_override = false);
/*!
* \brief Whether can we prove expr >= val.

Expand Down
26 changes: 22 additions & 4 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -671,11 +671,18 @@ inline bool is_one(const PrimExpr& x) { return is_const_int(x, 1); }
inline bool is_zero(const PrimExpr& x) { return is_const_int(x, 0); }

/*!
* \brief Check whether x is a constant.
* \brief Check whether x is an integer constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const(const PrimExpr& x);
inline bool is_const_int(const PrimExpr& x);

/*!
* \brief Check whether x is an integer/float constant.
* \note This only return true for integer types.
* \return whether x is constant
*/
inline bool is_const_number(const PrimExpr& x);

/*!
* \brief Left fold.
Expand All @@ -699,7 +706,7 @@ inline PrimExpr foldl(FReduce freduce, PrimExpr init_value, const Array<PrimExpr
TVM_DLL bool is_const_power_of_two_integer(const PrimExpr& x, int* shift);

// Implementation details after this
inline bool is_const(const PrimExpr& x) {
inline bool is_const_int(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
return true;
} else if (const auto* op = x.as<tir::BroadcastNode>()) {
Expand All @@ -711,6 +718,17 @@ inline bool is_const(const PrimExpr& x) {
return false;
}

inline bool is_const_number(const PrimExpr& x) {
if (x.as<tir::IntImmNode>()) {
return true;
} else if (x.as<tir::FloatImmNode>()) {
return true;
} else if (const auto* op = x.as<tir::BroadcastNode>()) {
return (op->value->IsInstance<tir::IntImmNode>() || op->value->IsInstance<tir::FloatImmNode>());
}
return false;
}

inline bool is_positive_const(const PrimExpr& a) {
if (const tir::IntImmNode* op = a.as<tir::IntImmNode>()) {
return op->value > 0;
Expand Down Expand Up @@ -742,7 +760,7 @@ inline bool is_const_int(const PrimExpr& x, int64_t value) {
inline bool is_no_op(const tir::Stmt& stmt) {
if (!stmt.defined()) return true;
if (const auto* op = stmt.as<tir::EvaluateNode>()) {
return is_const(op->value);
return is_const_int(op->value);
}
if (const auto* op = stmt.as<tir::SeqStmtNode>()) {
return op->seq.size() == 0;
Expand Down
24 changes: 12 additions & 12 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,31 @@ Analyzer::Analyzer()
canonical_simplify(this),
int_set(this) {}

void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);

this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
this->modular_set.Update(var, this->modular_set(new_expr), override);
this->rewrite_simplify.Update(var, new_expr, override);
this->canonical_simplify.Update(var, new_expr, override);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), allow_override);
this->modular_set.Update(var, this->modular_set(new_expr), allow_override);
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool override) {
void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min, override);
this->Bind(var, range->min, allow_override);
} else {
this->const_int_bound.Bind(var, range, override);
this->const_int_bound.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
}

void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
void Analyzer::Bind(const Map<Var, Range>& variables, bool allow_override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second, override);
this->Bind(iter.first, iter.second, allow_override);
}
}

Expand Down Expand Up @@ -116,9 +116,9 @@ bool Analyzer::CanProve(const PrimExpr& expr) {
}

PrimExpr Analyzer::Simplify(const PrimExpr& expr) {
if (tir::is_const(expr)) return expr;
if (tir::is_const_int(expr)) return expr;
auto res = this->rewrite_simplify(expr);
if (tir::is_const(res)) return res;
if (tir::is_const_int(res)) return res;
res = this->canonical_simplify(res);
return res;
}
Expand Down
33 changes: 23 additions & 10 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ class ConstIntBoundAnalyzer::Impl
BoundInfo(PrimExpr expr, Entry bound) : expr(expr), bound(bound) {}
};

void Bind(const Var& var, const Range& range, bool override) {
void Bind(const Var& var, const Range& range, bool allow_override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, override);
Update(var, ret, allow_override);
}

void Update(const Var& var, const Entry& info, bool override) {
if (!override) {
void Update(const Var& var, const Entry& info, bool allow_override) {
if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
Expand All @@ -119,8 +119,21 @@ class ConstIntBoundAnalyzer::Impl
var_map_[var] = info;
}

void Update(const Var& var, const ConstIntBound& info, bool override) {
Update(var, MakeBound(info->min_value, info->max_value), override);
Entry VisitExpr_(const LetNode* op) final {
auto it = var_map_.find(op->var);
// if the var has not been binded, update the info.
if (it == var_map_.end()) {
var_map_[op->var] = this->VisitExpr(op->value);
Entry ret = VisitExpr(op->body);
var_map_.erase(op->var);
return ret;
} else {
return VisitExpr(op->body);
}
}

void Update(const Var& var, const ConstIntBound& info, bool allow_override) {
Update(var, MakeBound(info->min_value, info->max_value), allow_override);
}

// Override visitor behaviors
Expand Down Expand Up @@ -558,12 +571,12 @@ ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr, BoundMapTy
return ConstIntBound(ret.min_value, ret.max_value);
}

void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool override) {
impl_->Update(var, info, override);
void ConstIntBoundAnalyzer::Update(const Var& var, const ConstIntBound& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
impl_->Bind(var, range, override);
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool allow_override) {
impl_->Bind(var, range, allow_override);
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
Expand Down
21 changes: 17 additions & 4 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
public:
explicit Impl(Analyzer* parent) : parent_(parent) {}

void Update(const Var& var, const ModularSet& info, bool override) {
if (!override) {
void Update(const Var& var, const ModularSet& info, bool allow_override) {
if (!allow_override) {
auto it = var_map_.find(var);
if (it != var_map_.end()) {
CHECK(it->second == info) << "Trying to update var \'" << var << "\'"
Expand Down Expand Up @@ -118,6 +118,19 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
// Override visitor behaviors
Entry VisitExprDefault_(const Object* op) final { return Everything(); }

Entry VisitExpr_(const LetNode* op) final {
auto it = var_map_.find(op->var);
// if the var has not been binded, update the info.
if (it == var_map_.end()) {
var_map_[op->var] = this->VisitExpr(op->value);
Entry ret = VisitExpr(op->body);
var_map_.erase(op->var);
return ret;
} else {
return VisitExpr(op->body);
}
}

Entry VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }

Entry VisitExpr_(const IntImmNode* op) final { return Entry(0, op->value); }
Expand Down Expand Up @@ -315,8 +328,8 @@ ModularSet ModularSetAnalyzer::operator()(const PrimExpr& expr) {
return ModularSet(ret.coeff, ret.base);
}

void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool override) {
impl_->Update(var, info, override);
void ModularSetAnalyzer::Update(const Var& var, const ModularSet& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

std::function<void()> ModularSetAnalyzer::EnterConstraint(const PrimExpr& constraint) {
Expand Down
16 changes: 12 additions & 4 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1519,7 +1519,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
op = ret.as<CallNode>();
if (op == nullptr) return ret;

if (op->op.same_as(tir::builtin::likely()) && is_const(op->args[0])) {
if (op->op.same_as(tir::builtin::likely()) && is_const_int(op->args[0])) {
return op->args[0];
} else if (op->op.same_as(tir::builtin::shift_right())) {
if (op->args[0].as<IntImmNode>() && op->args[1].as<IntImmNode>()) {
Expand Down Expand Up @@ -1559,9 +1559,17 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CastNode* op) {
return cast(op->dtype, op->value);
}

bool RewriteSimplifier::Impl::CanInlineLet(const LetNode* op) {
// Only inline trivial bindings to avoid deep expression explosion
// when we need let to construct complicated expressions.
if (is_const_number(op->value)) return true;
if (op->value.as<VarNode>()) return true;
return false;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value);
if (!tir::HasSideEffect(value)) {
if (CanInlineLet(op)) {
// it is fine to discard the let binding
// because the value will always be inlined in the simplifier.
analyzer_->Bind(op->var, value);
Expand All @@ -1587,8 +1595,8 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) {
return res;
}

void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool override) {
impl_->Update(var, info, override);
void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_override) {
impl_->Update(var, info, allow_override);
}

std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) {
Expand Down
7 changes: 7 additions & 0 deletions src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,13 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
*/
CompareResult TryCompare(const PrimExpr& x, int64_t val);

/*!
* \brief Internal function to check whether or not to inline let.
* \param op The let expr.
* \return The inline decision.
*/
bool CanInlineLet(const LetNode* op);

private:
// Whether x >= val
bool CanProveGreaterEqual(const PrimExpr& x, int64_t val) {
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ void CodeGenHybrid::VisitStmt_(const ForNode* op) {

bool is_noop(const Stmt& stmt) {
if (!stmt.defined()) return true;
if (auto eval = stmt.as<EvaluateNode>()) return is_const(eval->value);
if (auto eval = stmt.as<EvaluateNode>()) return is_const_int(eval->value);
return false;
}

Expand Down Expand Up @@ -409,7 +409,7 @@ void CodeGenHybrid::VisitStmt_(const SeqStmtNode* op) {
}

void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) {
if (is_const(op->value)) return;
if (is_const_int(op->value)) return;
std::string str = PrintExpr(op->value);
if (!str.empty()) stream << str << "\n";
}
Expand Down
Loading