Skip to content

Commit

Permalink
Cache object refs in loop partitioner instead of object pointers (#6004)
Browse files Browse the repository at this point in the history
* Cache object refs in loop partitioner instead of object pointers

The loop partitioner modifies the IR, which can cause TIR objects to
become dead and be destroyed. To avoid working on junk data, cache
object references instead of raw object pointers.

* Fix format/lint errors
  • Loading branch information
Krzysztof Parzyszek authored Jul 8, 2020
1 parent 1cd56da commit 2875e4c
Showing 1 changed file with 41 additions and 30 deletions.
71 changes: 41 additions & 30 deletions src/tir/transforms/loop_partition.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,27 @@ using arith::DeduceBound;
using arith::Intersect;
using arith::IntSet;

// Key of a partition entry: a condition paired with the boolean value it is
// proven to take.
// NOTE(review): this listing shows the removed and the added diff line of each
// changed statement back to back (raw `const Object*` key vs. `PrimExpr` key);
// presumably only one line of each pair exists in the actual file.
using PartitionKey = std::pair<const Object*, bool>;
using PartitionKey = std::pair<PrimExpr, bool>;
// Hash functor for PartitionKey, for use with unordered containers.
struct PartitionKeyHash {
// Hashes the condition (k.first) and the flag (k.second) separately and
// combines the two results with XOR.
std::size_t operator()(PartitionKey const& k) const noexcept {
std::size_t h1 = std::hash<const Object*>{}(k.first);
std::size_t h1 = ObjectPtrHash{}(k.first); // NOLINT(whitespace/braces)
std::size_t h2 = std::hash<bool>{}(k.second);
return h1 ^ h2;
}
};

// Equality functor for PartitionKey, the counterpart of PartitionKeyHash.
// Two keys are equal when their boolean flags match and ObjectPtrEqual
// reports their condition references as equal.
struct PartitionKeyEqual {
  bool operator()(const PartitionKey& k1, const PartitionKey& k2) const {
    // Compare the cheap flag first; differing flags can never be equal keys.
    if (k1.second != k2.second) {
      return false;
    }
    // NOLINTNEXTLINE(whitespace/braces)
    return ObjectPtrEqual{}(k1.first, k2.first);
  }
};

// Each mapping (cond, cond_value) -> interval represents the fact that
// condition cond is proven to have value cond_value (true or false) in interval.
// NOTE(review): the two `Partition` lines below are the removed and the added
// side of one diff change; the added one additionally supplies
// PartitionKeyEqual as the key-equality functor.
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash>;
using Partition = std::unordered_map<PartitionKey, IntSet, PartitionKeyHash, PartitionKeyEqual>;

// Set of PrimExpr references, hashed with ObjectPtrHash and compared with
// ObjectPtrEqual.
using ExpressionSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;

bool ExprUseVars(PrimExpr expr, const std::unordered_set<const VarNode*>& vars) {
bool success = false;
Expand Down Expand Up @@ -101,7 +110,7 @@ class CandidateSelector final : public StmtExprVisitor {
record_.insert({var, false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var) && !no_split_) {
candidates.insert(op);
candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var);
} else {
Expand All @@ -119,7 +128,7 @@ class CandidateSelector final : public StmtExprVisitor {
record_.insert({var.get(), false});
StmtExprVisitor::VisitStmt_(op);
if (record_.at(var.get()) && !no_split_) {
candidates.insert(op);
candidates.insert(GetRef<Stmt>(op));
}
record_.erase(var.get());
return;
Expand Down Expand Up @@ -160,7 +169,7 @@ class CandidateSelector final : public StmtExprVisitor {
}
}

std::unordered_set<const Object*> candidates;
std::unordered_set<Stmt, ObjectPtrHash, ObjectPtrEqual> candidates;

private:
bool in_likely_{false};
Expand Down Expand Up @@ -224,14 +233,14 @@ class PartitionFinder : public StmtExprVisitor {
IntSet interval = DeduceBound(current_var_, cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is true within interval
partitions[{cond.get(), true}] = interval;
partitions[{cond, true}] = interval;
}
PrimExpr inverse_cond = InverseCond(cond);
if (inverse_cond.defined()) {
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
if (!interval.IsNothing()) {
// cond is false within interval
partitions[{cond.get(), false}] = interval;
partitions[{cond, false}] = interval;
}
}
}
Expand Down Expand Up @@ -276,25 +285,25 @@ class PartitionFinder : public StmtExprVisitor {
// Replace the set of conditions given by ps with cond_value (true or false).
// NOTE(review): this listing shows removed and added diff lines back to back
// (sets of raw `const Object*` vs. ExpressionSet); presumably only one line of
// each such pair exists in the actual file.
class ConditionEliminator : public StmtExprMutator {
public:
// ps: the conditions to replace (copied into ps_).
// cond_value: the constant substituted for each of them; defaults to true.
explicit ConditionEliminator(const std::unordered_set<const Object*>& ps, bool cond_value = true)
explicit ConditionEliminator(const ExpressionSet& ps, bool cond_value = true)
: ps_(ps), cond_value_(cond_value) {}

// If e is a member of ps_, return the result of visiting const_true() or
// const_false() (selected by cond_value_); otherwise defer to the base
// StmtExprMutator::VisitExpr.
PrimExpr VisitExpr(const PrimExpr& e) final {
if (ps_.find(e.get()) != ps_.end()) {
if (ps_.find(e) != ps_.end()) {
return VisitExpr(cond_value_ ? const_true() : const_false());
}
return StmtExprMutator::VisitExpr(e);
}

private:
// Conditions to be replaced; held by value, not by reference.
std::unordered_set<const Object*> ps_;
ExpressionSet ps_;
// Constant value substituted for every matched condition.
bool cond_value_;
};

// Insert the partition branch at the innermost thread scope
class ThreadPartitionInserter : public StmtMutator {
public:
explicit ThreadPartitionInserter(const std::unordered_set<const Object*>& ps, PrimExpr cond)
explicit ThreadPartitionInserter(const ExpressionSet& ps, PrimExpr cond)
: ps_(ps), cond_(cond), innermost_thread_scope_(false) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
Expand All @@ -316,7 +325,7 @@ class ThreadPartitionInserter : public StmtMutator {
}

private:
const std::unordered_set<const Object*>& ps_;
const ExpressionSet& ps_;
PrimExpr cond_;
bool innermost_thread_scope_;
};
Expand All @@ -334,9 +343,9 @@ class LoopPartitioner : public StmtMutator {
}

Stmt VisitStmt_(const ForNode* op) final {
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, GetRef<Stmt>(op), op->loop_var, op->min, op->min + op->extent - 1,
op->body, false);
auto fs = GetRef<Stmt>(op);
if (selector.candidates.count(fs)) {
Stmt s = TryPartition(fs, op->loop_var, op->min, op->min + op->extent - 1, op->body, false);
if (s.defined()) return s;
}

Expand All @@ -356,8 +365,9 @@ class LoopPartitioner : public StmtMutator {
const IterVarNode* iv = op->node.as<IterVarNode>();
CHECK(iv);
Var var = iv->var;
if (selector.candidates.count(op)) {
Stmt s = TryPartition(op, GetRef<Stmt>(op), var, 0, op->value - 1, op->body, true);
auto as = GetRef<Stmt>(op);
if (selector.candidates.count(as)) {
Stmt s = TryPartition(as, var, 0, op->value - 1, op->body, true);
if (s.defined()) return s;
}

Expand All @@ -378,11 +388,12 @@ class LoopPartitioner : public StmtMutator {
}

private:
Stmt TryPartition(const Object* op, const Stmt& stmt, Var var, PrimExpr min, PrimExpr max,
Stmt body, bool partition_thread_scope);
Stmt TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope);

std::pair<IntSet, std::unordered_set<const Object*>> GetIntervalAndCondset(
const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value);
std::pair<IntSet, ExpressionSet> GetIntervalAndCondset(const Partition& partitions,
const arith::IntervalSet& for_interval,
bool cond_value);

inline Stmt MakeFor(const Object* op, PrimExpr extent, Stmt body);

Expand All @@ -395,10 +406,10 @@ class LoopPartitioner : public StmtMutator {

// Returns an interval (in the first component) in which all the conditions
// given in the second component provably have value given by cond_value
std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetIntervalAndCondset(
std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
const Partition& partitions, const arith::IntervalSet& for_interval, bool cond_value) {
Array<IntSet> sets;
std::unordered_set<const Object*> cond_set;
ExpressionSet cond_set;

for (const auto& kv : partitions) {
if (kv.first.second == cond_value) {
Expand Down Expand Up @@ -460,8 +471,8 @@ std::pair<IntSet, std::unordered_set<const Object*>> LoopPartitioner::GetInterva
* which will eventually be simplified to empty code. And because only one loop was generated
* from loop 2 we stop recursing.
*/
Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var, PrimExpr min,
PrimExpr max, Stmt body, bool partition_thread_scope) {
Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, PrimExpr max, Stmt body,
bool partition_thread_scope) {
using namespace arith;
// include hint of var.
hint_map_.insert({var.get(), IntSet::Interval(min, max)});
Expand All @@ -475,7 +486,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
arith::IntervalSet for_interval(min, max);
bool cond_value;
IntSet middle_interval;
std::unordered_set<const Object*> cond_set;
ExpressionSet cond_set;
// find an interval in which all conditions on var are true
std::tie(middle_interval, cond_set) =
GetIntervalAndCondset(finder.partitions, for_interval, true);
Expand Down Expand Up @@ -516,7 +527,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
}
if (!partition_thread_scope) {
Stmt pre_body = Substitute(body, {{Var{var}, var + min}});
pre_stmt = MakeFor(node, body_begin - min, pre_body);
pre_stmt = MakeFor(stmt.get(), body_begin - min, pre_body);
}
}
} else {
Expand All @@ -541,7 +552,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
}
if (!partition_thread_scope) {
Stmt post_body = Substitute(body, {{Var{var}, var + post_doubt_begin}});
post_stmt = MakeFor(node, max - post_doubt_begin + 1, post_body);
post_stmt = MakeFor(stmt.get(), max - post_doubt_begin + 1, post_body);
}
}
} else {
Expand All @@ -557,7 +568,7 @@ Stmt LoopPartitioner::TryPartition(const Object* node, const Stmt& stmt, Var var
// [body_begin, post_doubt_begin)
Stmt simplified_body = ConditionEliminator(cond_set, cond_value)(body);
Stmt new_body = Substitute(simplified_body, {{Var{var}, var + body_begin}});
mid_stmt = MakeFor(node, post_doubt_begin - body_begin, new_body);
mid_stmt = MakeFor(stmt.get(), post_doubt_begin - body_begin, new_body);

// Recurse for each non-empty subrange only if there are at least
// two non-empty subranges
Expand Down

0 comments on commit 2875e4c

Please sign in to comment.