From 804208ce5f16724483f38b08eafe51908ebdbde1 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 3 Aug 2022 10:39:17 -0500 Subject: [PATCH] [Arith][TIR] IntSetAnalyzer, delay intersection of IntSet until use Follow-up from https://github.com/apache/tvm/pull/11970, to improve performance. In the initial implementation, the `analyzer->int_set` would compute the intersection of all scope-based constraints when entering the scope, even if they weren't actually used. This commit delays the call to `Intersect` until required, following the same behavior as `ConstIntBound`. --- src/arith/int_set.cc | 126 ++++++++++++++++++------------------------- 1 file changed, 51 insertions(+), 75 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 18aaa8875b37f..eddadf39ee47b 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -362,8 +362,13 @@ using namespace tir; // We might use better set analysis in the future to replace the intervalset. class IntervalSetEvaluator : public ExprFunctor { public: - IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, bool eval_vec = false) - : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {} + IntervalSetEvaluator(Analyzer* analyzer, const Map& dom_map, + const std::vector>& dom_constraints = {}, + bool eval_vec = false) + : analyzer_(analyzer), + dom_map_(dom_map), + dom_constraints_(dom_constraints), + eval_vec_(eval_vec) {} IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); } // evaluate and relax the set @@ -383,18 +388,38 @@ class IntervalSetEvaluator : public ExprFunctor { IntervalSet VisitExpr_(const VarNode* op) final { Var var = GetRef(op); + + Array values; + for (const auto& constraint : dom_constraints_) { + if (var.same_as(constraint.first)) { + values.push_back(constraint.second); + } + } + auto it = dom_map_.find(var); if (it != dom_map_.end()) { - IntervalSet res = ToIntervalSet((*it).second); - if (res->min_value.same_as(var) && res->max_value.same_as(var)) { - return res; - } - // recursively evaluate mapped result - // in case the domain contains variables to be relaxed. - return Eval(res); - } else { + values.push_back((*it).second); + } + + if (values.empty()) { return IntervalSet::SinglePoint(var); } + + IntSet intersection = [&]() { + if (values.size() == 1) { + return values.front(); + } else { + return Intersect(values); + } + }(); + + IntervalSet res = ToIntervalSet(intersection); + if (res->min_value.same_as(var) && res->max_value.same_as(var)) { + return res; + } + // recursively evaluate mapped result + // in case the domain contains variables to be relaxed. + return Eval(res); } IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_(op); } @@ -517,6 +542,7 @@ class IntervalSetEvaluator : public ExprFunctor { // analyzer Analyzer* analyzer_; const Map& dom_map_; + const std::vector>& dom_constraints_; bool eval_vec_{false}; }; @@ -525,11 +551,11 @@ class IntSetAnalyzer::Impl { explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {} IntSet Eval(const PrimExpr& expr, const Map& dom_map) const { - return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr); + return IntervalSetEvaluator(analyzer_, dom_map, {}).Eval(expr); } IntSet Eval(const PrimExpr& expr) const { - return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr); + return IntervalSetEvaluator(analyzer_, dom_map_, constraints_, true).Eval(expr); } void Bind(const Var& var, const Range& range, bool allow_override) { @@ -543,10 +569,6 @@ class IntSetAnalyzer::Impl { std::function SuppressConstraints(); private: - // Get the current variable bounds, including both global bounds and - // scope-dependent bounds. - Map GetCurrentBounds() const; - // Utility function to split a boolean condition into the domain // bounds implied by that condition. static std::vector> DetectBoundInfo(const PrimExpr& cond); @@ -558,9 +580,11 @@ class IntSetAnalyzer::Impl { // ranges) Map dom_map_; - // Map of variables to implicit scope-dependent bounds (e.g. inside - // the body of an if-statement) - Map constraints_; + // List of implicit scope-dependent bounds (e.g. inside the body of + // an if-statement). Maintained as a list of constraints, rather + // than as a `Map`, to avoid computing an Intersection + // until required. + std::vector> constraints_; // Whether scope-based analysis should be temporarily disabled bool use_scoped_constraints_{true}; @@ -608,29 +632,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o Update(var, Eval(expr), can_override); } -Map IntSetAnalyzer::Impl::GetCurrentBounds() const { - // If either constraints_ or dom_map_ is empty, return the other to - // avoid constructing a new map. - if (constraints_.empty() || !use_scoped_constraints_) { - return dom_map_; - } else if (dom_map_.empty()) { - return constraints_; - } - - // If neither is empty, construct a merged domain map with - // information from both sources. - Map merged = dom_map_; - for (const auto& pair : constraints_) { - auto it = merged.find(pair.first); - if (it == merged.end()) { - merged.Set(pair.first, pair.second); - } else { - merged.Set(pair.first, Intersect({pair.second, (*it).second})); - } - } - return merged; -} - std::vector> IntSetAnalyzer::Impl::DetectBoundInfo( const PrimExpr& constraint) { PVar x; @@ -672,41 +673,16 @@ std::function IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint std::function IntSetAnalyzer::SuppressConstraints() { return impl_->SuppressConstraints(); } std::function IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) { - Map cached_values; - auto bounds = DetectBoundInfo(constraint); if (bounds.size() == 0) return nullptr; - // Collect the current values of each var that is changes by this - // constraint. - for (const auto& pair : bounds) { - auto it = constraints_.find(pair.first); - if (it == constraints_.end()) { - cached_values.Set(pair.first, IntSet()); - } else { - cached_values.Set(pair.first, (*it).second); - } - } - - // Update all constraints - for (const auto& pair : bounds) { - auto it = constraints_.find(pair.first); - if (it == constraints_.end()) { - constraints_.Set(pair.first, pair.second); - } else { - constraints_.Set(pair.first, Intersect({pair.second, (*it).second})); - } - } - - auto frecover = [cached_values, this]() { - for (const auto& it : cached_values) { - if (it.second.defined()) { - constraints_.Set(it.first, it.second); - } else { - constraints_.erase(it.first); - } - } + size_t old_size = constraints_.size(); + constraints_.insert(constraints_.end(), bounds.begin(), bounds.end()); + size_t new_size = constraints_.size(); + auto frecover = [old_size, new_size, this]() { + ICHECK_EQ(constraints_.size(), new_size); + constraints_.resize(old_size); }; return frecover; } @@ -975,13 +951,13 @@ Map ConvertDomMap(const std::unordered_map& IntSet EvalSet(PrimExpr e, const Map& dom_map) { Analyzer ana; - return IntervalSetEvaluator(&ana, dom_map, false).Eval(e); + return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e); } IntSet IntSet::Vector(PrimExpr x) { Analyzer ana; Map dmap; - return IntervalSetEvaluator(&ana, dmap, true).Eval(x); + return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x); } IntSet EvalSet(PrimExpr e, const Map& dom_map) {