Skip to content

Commit

Permalink
[Arith][TIR] IntSetAnalyzer, delay intersection of IntSet until use (#…
Browse files Browse the repository at this point in the history
…12821)

Follow-up from #11970, to improve
performance.  In the initial implementation, the `analyzer->int_set`
would compute the intersection of all scope-based constraints when
entering the scope, even if they weren't actually used.  This commit
delays the call to `Intersect` until required, following the same
behavior as `ConstIntBound`.
  • Loading branch information
Lunderberg authored Sep 19, 2022
1 parent 2af9b90 commit e30ac71
Showing 1 changed file with 52 additions and 74 deletions.
126 changes: 52 additions & 74 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,13 @@ using namespace tir;
// We might use better set analysis in the future to replace the intervalset.
class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
public:
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false)
: analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
const std::vector<std::pair<Var, IntSet>>* dom_constraints = nullptr,
bool eval_vec = false)
: analyzer_(analyzer),
dom_map_(dom_map),
dom_constraints_(dom_constraints),
eval_vec_(eval_vec) {}

IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
// evaluate and relax the set
Expand All @@ -383,18 +388,40 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {

IntervalSet VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);

Array<IntSet> values;
if (dom_constraints_) {
for (const auto& constraint : *dom_constraints_) {
if (var.same_as(constraint.first)) {
values.push_back(constraint.second);
}
}
}

auto it = dom_map_.find(var);
if (it != dom_map_.end()) {
IntervalSet res = ToIntervalSet((*it).second);
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
// in case the domain contains variables to be relaxed.
return Eval(res);
} else {
values.push_back((*it).second);
}

if (values.empty()) {
return IntervalSet::SinglePoint(var);
}

IntSet intersection = [&]() {
if (values.size() == 1) {
return values.front();
} else {
return Intersect(values);
}
}();

IntervalSet res = ToIntervalSet(intersection);
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
return res;
}
// recursively evaluate mapped result
// in case the domain contains variables to be relaxed.
return Eval(res);
}

IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
Expand Down Expand Up @@ -517,6 +544,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
// analyzer
Analyzer* analyzer_;
const Map<Var, IntSet>& dom_map_;
const std::vector<std::pair<Var, IntSet>>* dom_constraints_;
bool eval_vec_{false};
};

Expand All @@ -529,7 +557,7 @@ class IntSetAnalyzer::Impl {
}

IntSet Eval(const PrimExpr& expr) const {
return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr);
return IntervalSetEvaluator(analyzer_, dom_map_, &dom_constraints_, true).Eval(expr);
}

void Bind(const Var& var, const Range& range, bool allow_override) {
Expand All @@ -541,10 +569,6 @@ class IntSetAnalyzer::Impl {
std::function<void()> EnterConstraint(const PrimExpr& constraint);

private:
// Get the current variable bounds, including both global bounds and
// scope-dependent bounds.
Map<Var, IntSet> GetCurrentBounds() const;

// Utility function to split a boolean condition into the domain
// bounds implied by that condition.
static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);
Expand All @@ -556,9 +580,11 @@ class IntSetAnalyzer::Impl {
// ranges)
Map<Var, IntSet> dom_map_;

// Map of variables to implicit scope-dependent bounds (e.g. inside
// the body of an if-statement)
Map<Var, IntSet> constraints_;
// List of implicit scope-dependent bounds (e.g. inside the body of
// an if-statement). Maintained as a list of constraints, rather
// than as a `Map<Var,IntSet>`, to avoid computing an Intersection
// until required.
std::vector<std::pair<Var, IntSet>> dom_constraints_;
};

IntSetAnalyzer::IntSetAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {}
Expand Down Expand Up @@ -603,29 +629,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
Update(var, Eval(expr), can_override);
}

Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
// If either constraints_ or dom_map_ is empty, return the other to
// avoid constructing a new map.
if (constraints_.empty()) {
return dom_map_;
} else if (dom_map_.empty()) {
return constraints_;
}

// If neither is empty, construct a merged domain map with
// information from both sources.
Map<Var, IntSet> merged = dom_map_;
for (const auto& pair : constraints_) {
auto it = merged.find(pair.first);
if (it == merged.end()) {
merged.Set(pair.first, pair.second);
} else {
merged.Set(pair.first, Intersect({pair.second, (*it).second}));
}
}
return merged;
}

std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
const PrimExpr& constraint) {
PVar<Var> x;
Expand Down Expand Up @@ -665,41 +668,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
}

std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) {
Map<Var, IntSet> cached_values;

auto bounds = DetectBoundInfo(constraint);

if (bounds.size() == 0) return nullptr;

// Collect the current values of each var that is changes by this
// constraint.
for (const auto& pair : bounds) {
auto it = constraints_.find(pair.first);
if (it == constraints_.end()) {
cached_values.Set(pair.first, IntSet());
} else {
cached_values.Set(pair.first, (*it).second);
}
}

// Update all constraints
for (const auto& pair : bounds) {
auto it = constraints_.find(pair.first);
if (it == constraints_.end()) {
constraints_.Set(pair.first, pair.second);
} else {
constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
}
}

auto frecover = [cached_values, this]() {
for (const auto& it : cached_values) {
if (it.second.defined()) {
constraints_.Set(it.first, it.second);
} else {
constraints_.erase(it.first);
}
}
size_t old_size = dom_constraints_.size();
dom_constraints_.insert(dom_constraints_.end(), bounds.begin(), bounds.end());
size_t new_size = dom_constraints_.size();
auto frecover = [old_size, new_size, this]() {
ICHECK_EQ(dom_constraints_.size(), new_size);
dom_constraints_.resize(old_size);
};
return frecover;
}
Expand Down Expand Up @@ -960,13 +938,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&

IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
Analyzer ana;
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e);
}

IntSet IntSet::Vector(PrimExpr x) {
Analyzer ana;
Map<Var, IntSet> dmap;
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
}

IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
Expand Down

0 comments on commit e30ac71

Please sign in to comment.