Skip to content

Commit

Permalink
[TIR][Arith] Prove conditionals by transitively applying knowns (#12863)
Browse files Browse the repository at this point in the history
This commit adds a new sub-analyzer, `TransitiveComparisonAnalyzer`,
which attempts to apply multiple known comparisons to prove an
unknown.  For example, `a <= b` and `b <= c` imply that `a <= c`.
These simplifications are necessary for simplifying conditionals
resulting from padded layout
transformations (#12261).

While some of these conditions may be proven using
`ConstIntBoundAnalyzer` or `IntSetAnalyzer`, each has some
limitations.  `ConstIntBoundAnalyzer` can only compare against a
constant, `IntSetAnalyzer` internally calls `RewriteSimplifier` which
can result in infinite recursion, and neither can handle not-equal
conditions because it would require tracking multiple intervals per
expression.  Therefore, introducing a new sub-analyzer for these
simplifications.

* Change mutable reference to mutable pointer

* Remove nullptr default on Impl unique_ptr

In g++ 7, defining a default constructor attempts to define the
destructor, which fails because `Impl` is an incomplete type.  As far
as I should tell, the destructor should only be defined at the point
where `~TransitiveComparisonAnalyzer` is defined, at which point
`Impl` has a full definition.  This issue does not occur in g++ 10.

* Require opt-in for CPU-intensive simplifications

* Document the intent of using bitflags

* Rename "Feature" to "Extension"

* Use TVM_DLL on new public member functions

* Remove duplicate BaseBeforeAfter.transform definition

* Explicitly enable extension for unit tests that require it

* Fix accidentally duplicate test case

* Improve TryCompareFromLHS documentation

* Update wording to distinguish `knowns_` and `scoped_knowns_`

* Better documentation for Key enum

* Document the normalization of LT/GT

* Removed unused PrimExpr temp

* Call out modifications of the `compared_to_x` contents

* Pointed to `Comparison::Comparison` for normalization details

* Updated to clarify right/RHS.

* Rename TryCompareFromLHS to DFSFromLHS
  • Loading branch information
Lunderberg authored Oct 7, 2022
1 parent 7804a98 commit fc333f9
Show file tree
Hide file tree
Showing 8 changed files with 1,172 additions and 28 deletions.
114 changes: 113 additions & 1 deletion include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,41 @@ class RewriteSimplifier {
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
std::function<void()> EnterConstraint(const PrimExpr& constraint);
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);

/*! \brief Flags to enable more computationally-intensive simplifications
*
* These simplifications may be required for specific schedules, but
* would impose too high a compile-time cost to enable by default.
* They can be enabled on an as-needed basis by calling
* `RewriteSimplifier::SetEnabledExtensions` prior to using
* `RewriteSimplifier::operator()`.
*
* Flags are defined as powers of two to allow future expansion. To
* enable multiple extensions, a user should pass a bitwise OR of the
* flags for each desired extension.
*/
enum Extension {
// No extensions enabled
kNone = 0,

/* When simplifying an inequality, attempt to use scope-based knowns.
*
* Example:
* if_then_else(i<j && j<k, i<k, false) => if_then_else(i<j && j<k, true, false)
*/
kTransitivelyProveInequalities = (1 << 0),
};

/*! \brief Enable an optional extension or extensions
*
* \param flags A bitwise OR of all optional extensions that should
* be enabled.
*/
TVM_DLL void SetEnabledExtensions(Extension flags);

/*! \brief Return the currently enabled extensions */
TVM_DLL Extension GetEnabledExtensions() const;

private:
friend class Analyzer;
Expand Down Expand Up @@ -317,6 +351,82 @@ class CanonicalSimplifier {
Impl* impl_;
};

/*! \brief Structure for representing result of known
*
* Values are assigned to allow these flags to be used in bitwise
* operations.
*/
enum class CompareResult : int {
kInconsistent = 0,
kEQ = 1,
kLT = 2,
kLE = 3,
kGT = 4,
kGE = 5,
kNE = 6,
kUnknown = 7
};

inline constexpr CompareResult operator&(CompareResult lhs, CompareResult rhs) {
return CompareResult(static_cast<int>(lhs) & static_cast<int>(rhs));
}
inline constexpr CompareResult operator|(CompareResult lhs, CompareResult rhs) {
return CompareResult(static_cast<int>(lhs) | static_cast<int>(rhs));
}

/*!
* \brief Using previously specified knowns, compare the expressions provided
*
* Given known expressions [(a OP b), (b OP c), ..., (y OP z)], search
* for a known result for `(a OP z)`.
*/
class TransitiveComparisonAnalyzer {
public:
/* \brief Using previously specified knowns, compare the expressions provided
*
* \param lhs The left-hand side of the comparison
*
* \param rhs The right-hand side of the comparison
*
* \return The most specific result that can be proven about the
* comparison. If nothing can be proven, returns kUnknown.
*/
TVM_DLL CompareResult TryCompare(const PrimExpr& lhs, const PrimExpr& rhs);

/*! \brief Bind a variable as being equal to a known expression
*
* \param var The variable of interest.
* \param expr The bound expression
* \param allow_override Whether to allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false);

/*! \brief Bind a variable as being within a specified range
*
* \param var The variable of interest.
* \param range The known range
* \param allow_override Whether to allow override of existing information.
*/
TVM_DLL void Bind(const Var& var, const Range& range, bool allow_override = false);

/*!
* \brief Update the internal state to enter constraint.
* \param constraint A constraint expression.
*
* \return an exit function that must be called to cleanup the constraint can be nullptr.
*/
TVM_DLL std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
friend class Analyzer;
friend class ConstraintContext;
TransitiveComparisonAnalyzer();
TVM_DLL ~TransitiveComparisonAnalyzer();
class Impl;
/*! \brief Internal impl */
std::unique_ptr<Impl> impl_;
};

/*!
* \brief Constraint context.
*
Expand Down Expand Up @@ -437,6 +547,8 @@ class TVM_DLL Analyzer {
CanonicalSimplifier canonical_simplify;
/*! \brief sub-analyzer: int set */
IntSetAnalyzer int_set;
/*! \brief sub-analyzer transitive comparisons */
TransitiveComparisonAnalyzer transitive_comparisons;
/*! \brief constructor */
Analyzer();
/*!
Expand Down
3 changes: 3 additions & 0 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {
this->rewrite_simplify.Update(var, new_expr, allow_override);
this->canonical_simplify.Update(var, new_expr, allow_override);
this->int_set.Update(var, this->int_set(new_expr), allow_override);
this->transitive_comparisons.Bind(var, expr, allow_override);
}

void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -54,6 +55,7 @@ void Analyzer::Bind(const Var& var, const Range& range, bool allow_override) {
} else {
this->const_int_bound.Bind(var, range, allow_override);
this->int_set.Bind(var, range, allow_override);
this->transitive_comparisons.Bind(var, range, allow_override);
}
// skip modular_set
// skip rewrite simplify
Expand All @@ -72,6 +74,7 @@ void ConstraintContext::EnterWithScope() {
recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_));
recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_));
}

void ConstraintContext::ExitWithScope() {
Expand Down
10 changes: 6 additions & 4 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
lhs.CopyOnWrite()->AddToSelf(pconst->value / cval);
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (TryCompare(temp, cval) != kLT) {
if (TryCompare(temp, cval) != CompareResult::kLT) {
lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kTruncDiv), 1);
}
}
Expand Down Expand Up @@ -945,7 +945,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
lhs.CopyOnWrite()->AddToSelf(floordiv(pconst->value, cval));
} else {
// if 0 <= extra < cval, it means the extra can be eliminated.
if (!(TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0))) {
if (!(TryCompare(temp, cval) == CompareResult::kLT &&
analyzer_->CanProveGreaterEqual(temp, 0))) {
lhs.CopyOnWrite()->AddToSelf(SplitDivConst(ToSplitExpr(temp), cval, kFloorDiv), 1);
}
}
Expand Down Expand Up @@ -1052,7 +1053,7 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
return truncmod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT) {
if (TryCompare(temp, cval) == CompareResult::kLT) {
return temp;
} else {
// contonue to use logic below.
Expand Down Expand Up @@ -1113,7 +1114,8 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
return floormod(temp, c1.Eval());
} else {
// If temp < cval && temp >=0 then can remove the mod.
if (TryCompare(temp, cval) == kLT && analyzer_->CanProveGreaterEqual(temp, 0)) {
if (TryCompare(temp, cval) == CompareResult::kLT &&
analyzer_->CanProveGreaterEqual(temp, 0)) {
return temp;
} else {
// contonue to use logic below.
Expand Down
79 changes: 61 additions & 18 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,42 +71,70 @@ using namespace tir;
// handled by CanonicalSimplifier.
//

CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, const PrimExpr& y) {
CompareResult output = CompareResult::kUnknown;

auto is_finished = [&output]() {
return output == CompareResult::kEQ || output == CompareResult::kLT ||
output == CompareResult::kGT;
};

output = CompareResult(output & TryCompareUsingConstIntBounds(x, y));

if (is_finished()) return output;

if (enabled_extensions_ & kTransitivelyProveInequalities) {
output = CompareResult(output & TryCompareUsingKnownInequalities(x, y));
}

return output;
}

CompareResult RewriteSimplifier::Impl::TryCompareUsingConstIntBounds(const PrimExpr& x,
const PrimExpr y) {
return TryCompare(x - y, 0);
}

CompareResult RewriteSimplifier::Impl::TryCompareUsingKnownInequalities(const PrimExpr& x,
const PrimExpr& y) {
return analyzer_->transitive_comparisons.TryCompare(x, y);
}

// try to prove x equals val
RewriteSimplifier::Impl::CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x,
int64_t val) {
CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val) {
PrimExpr diff = this->VisitExpr(x);
if (const auto* ptr = diff.as<IntImmNode>()) {
if (ptr->value == val) {
return kEQ;
return CompareResult::kEQ;
} else if (ptr->value > val) {
return kGT;
return CompareResult::kGT;
} else if (ptr->value < val) {
return kLT;
return CompareResult::kLT;
}
}
ConstIntBound dbound = analyzer_->const_int_bound(diff);
if (dbound->min_value == val && dbound->max_value == val) {
return kEQ;
return CompareResult::kEQ;
}
if (dbound->min_value > val) {
return kGT;
return CompareResult::kGT;
}
if (dbound->max_value < val) {
return kLT;
return CompareResult::kLT;
}
if (dbound->min_value >= val) {
return kGE;
return CompareResult::kGE;
}
if (dbound->max_value <= val) {
return kLE;
return CompareResult::kLE;
}
if (val == 0) {
ModularSet dmod = analyzer_->modular_set(diff);
if (dmod->base != 0) {
return kNE;
return CompareResult::kNE;
}
}
return kUnknown;
return CompareResult::kUnknown;
}

void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) {
Expand Down Expand Up @@ -254,6 +282,12 @@ std::function<void()> RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c
return frecover;
}

void RewriteSimplifier::Impl::SetEnabledExtensions(Extension flags) { enabled_extensions_ = flags; }

RewriteSimplifier::Extension RewriteSimplifier::Impl::GetEnabledExtensions() const {
return enabled_extensions_;
}

PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode* op) {
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<SubNode>();
Expand Down Expand Up @@ -1333,10 +1367,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const EQNode* op) {
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kEQ) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kEQ) {
return make_const(op->dtype, true);
} else if (result == kNE || result == kGT || result == kLT) {
} else if (result == CompareResult::kNE || result == CompareResult::kGT ||
result == CompareResult::kLT) {
return make_const(op->dtype, false);
}
TVM_TRY_REWRITE(x - c1 == 0, x == c1);
Expand Down Expand Up @@ -1382,11 +1417,12 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
}

if (IsIndexType(op->a.dtype())) {
CompareResult result = TryCompare(op->a - op->b, 0);
if (result == kLT) {
CompareResult result = TryCompare(op->a, op->b);
if (result == CompareResult::kLT) {
return make_const(op->dtype, true);
}
if (result == kEQ || result == kGT || result == kGE) {
if (result == CompareResult::kEQ || result == CompareResult::kGT ||
result == CompareResult::kGE) {
return make_const(op->dtype, false);
}

Expand Down Expand Up @@ -1742,6 +1778,13 @@ std::function<void()> RewriteSimplifier::EnterConstraint(const PrimExpr& constra
return impl_->EnterConstraint(constraint);
}

void RewriteSimplifier::SetEnabledExtensions(Extension flags) {
impl_->SetEnabledExtensions(flags);
}
RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const {
return impl_->GetEnabledExtensions();
}

RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {}

RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
Expand Down
Loading

1 comment on commit fc333f9

@MrJungle1
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, I found that I updated the tvm version from 0.10.0 to 0.11.1, and the time of auto_schedule tune became longer. It is located that this commit caused it. Have you encountered this problem?

Please sign in to comment.