Skip to content

Commit

Permalink
[3] Review comment handled
Browse files Browse the repository at this point in the history
  • Loading branch information
ANSHUMAN TRIPATHY authored and ANSHUMAN TRIPATHY committed Jul 25, 2020
1 parent 73e2d02 commit cc36ae7
Showing 1 changed file with 68 additions and 78 deletions.
146 changes: 68 additions & 78 deletions src/tir/transforms/hoist_if_then_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ class HoistCandidateSelector final : public StmtExprVisitor {
HoistCandidateSelector() { InitRecorder(); }

void VisitStmt_(const ForNode* op) final {
// Check if it is first for loop, then start the recorder
if (!RecordingComplete()) {
StartOrAddRecord(op);
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(op);
// If already recording complete,
// then stop tracing
if (RecordingComplete()) {
return;
}

// Check if it is first for loop, then start the recorder
StartOrAddRecord(op);
StmtExprVisitor::VisitStmt_(op);
RemoveRecord(op);
}

void VisitStmt_(const SeqStmtNode* op) final {
Expand All @@ -135,66 +136,65 @@ class HoistCandidateSelector final : public StmtExprVisitor {
}

void VisitStmt_(const IfThenElseNode* op) final {
if (IsRecordingOn()) {
is_if_cond = true;
StmtExprVisitor::VisitExpr(op->condition);
is_if_cond = false;

if (CheckValidIf()) {
// Check corresponding for loop
bool match_found = false;
size_t match_for_loop_pos = 0;
for (auto var : if_var_list_) {
for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
if (ordered_for_list_[i] == var_for_map_[var]) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
match_found = true;
break;
if (!IsRecordingOn()) {
StmtExprVisitor::VisitStmt_(op);
return;
}

is_if_cond_ = true;
StmtExprVisitor::VisitExpr(op->condition);
is_if_cond_ = false;

if (CheckValidIf()) {
// Check corresponding for loop
bool match_found = false;
size_t match_for_loop_pos = 0;
for (auto var : if_var_list_) {
for (size_t i = 0; i < ordered_for_list_.size() - 1; ++i) {
if (ordered_for_list_[i] == var_for_map_[var]) {
if (match_for_loop_pos < i) {
match_for_loop_pos = i;
}
match_found = true;
break;
}
}
// If none of the for loop has the matching loop variable as if condition,
// then the if node need to be hoisted on top of all, provided no parent loop exists.
int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;

// Check if target for loop is not the parent of current if node
if (!IsParentForLoop(target_for_pos)) {
StopAndAddRecord(ordered_for_list_[target_for_pos], op);
}
}
// If none of the for loop has the matching loop variable as if condition,
// then the if node need to be hoisted on top of all, provided no parent loop exists.
int target_for_pos = match_found ? match_for_loop_pos + 1 : 0;

// Check if target for loop is not the parent of current if node
if (!IsParentForLoop(target_for_pos)) {
StopAndAddRecord(ordered_for_list_[target_for_pos], op);
if_var_list_.clear();
return;
}
}
if_var_list_.clear();
StmtExprVisitor::VisitStmt_(op);
StopRecording();
return;
}

StmtExprVisitor::VisitStmt_(op);
}

void VisitExpr_(const VarNode* op) final {
if (is_if_cond) {
if (is_if_cond_) {
if_var_list_.emplace_back(op);
}
}

HoistForIfTuple hoist_for_if_recorder;

void ResetRecorder() {
if (is_recorder_on) {
if (is_recorder_on_) {
CHECK_GT(ordered_for_list_.size(), 0);
is_recorder_on = false;
is_recorder_on_ = false;
}
ordered_for_list_.clear();
var_for_map_.clear();
hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr);
}

bool RecordingComplete() {
if (std::get<0>(hoist_for_if_recorder)) return true;
return false;
}
bool RecordingComplete() { return std::get<0>(hoist_for_if_recorder); }

const ForNode* GetTargetForNode() { return std::get<1>(hoist_for_if_recorder); }

Expand All @@ -204,13 +204,7 @@ class HoistCandidateSelector final : public StmtExprVisitor {
bool CheckValidIf() {
// If no if var list is present, then all the condition vars are possibly from AttrStmt, so stop
// hoisting
if (if_var_list_.size() == 0) {
return false;
}
if (CheckAttrVar()) {
return false;
}
return true;
return ((!if_var_list_.empty()) && (!CheckAttrVar()));
}

bool IsParentForLoop(int loop_pos) {
Expand All @@ -234,12 +228,12 @@ class HoistCandidateSelector final : public StmtExprVisitor {

void InitRecorder() { hoist_for_if_recorder = std::make_tuple(false, nullptr, nullptr); }

void StopRecording() { is_recorder_on = false; }
void StopRecording() { is_recorder_on_ = false; }

bool IsRecordingOn() { return is_recorder_on; }
bool IsRecordingOn() { return is_recorder_on_; }

void StartOrAddRecord(const ForNode* op) {
is_recorder_on = true;
is_recorder_on_ = true;
if (!var_for_map_.count(op->loop_var.get())) {
var_for_map_.insert({op->loop_var.get(), op});
}
Expand Down Expand Up @@ -283,33 +277,30 @@ class HoistCandidateSelector final : public StmtExprVisitor {
}

std::vector<const ForNode*> ordered_for_list_;

std::vector<const VarNode*> if_var_list_;

std::unordered_set<const VarNode*> attr_var_list_;

VarForMap var_for_map_;

bool is_if_cond{false};
bool is_recorder_on{false};
bool is_if_cond_{false};
bool is_recorder_on_{false};
};

class IfThenElseHoister : public StmtMutator {
public:
IfThenElseHoister() : hoist_selector(HoistCandidateSelector()) {}
IfThenElseHoister() : hoist_selector_(HoistCandidateSelector()) {}

Stmt VisitAndMutate(Stmt stmt) {
hoist_selector(stmt);
hoist_selector_(stmt);
Stmt stmt_copy = std::move(stmt);

while (hoist_selector.RecordingComplete()) {
target_for = hoist_selector.GetTargetForNode();
target_if = hoist_selector.GetTargetIfNode();
while (hoist_selector_.RecordingComplete()) {
target_for_ = hoist_selector_.GetTargetForNode();
target_if_ = hoist_selector_.GetTargetIfNode();

stmt_copy = operator()(stmt_copy);

hoist_selector.ResetRecorder();
hoist_selector(stmt_copy);
hoist_selector_.ResetRecorder();
hoist_selector_(stmt_copy);
}

// Support SSA Form
Expand All @@ -318,24 +309,24 @@ class IfThenElseHoister : public StmtMutator {
}

Stmt VisitStmt_(const ForNode* op) final {
if ((!is_updating) && (target_for == op)) {
is_updating = true;
is_then_case = true;
if ((!is_updating_) && (target_for_ == op)) {
is_updating_ = true;
is_then_case_ = true;
Stmt then_case = StmtMutator::VisitStmt_(op);
is_then_case = false;
is_then_case_ = false;
Stmt else_case = Stmt();
if (target_if->else_case.defined()) {
if (target_if_->else_case.defined()) {
else_case = StmtMutator::VisitStmt_(op);
}
is_updating = false;
return IfThenElse(target_if->condition, then_case, else_case);
is_updating_ = false;
return IfThenElse(target_if_->condition, then_case, else_case);
}
return StmtMutator::VisitStmt_(op);
}

Stmt VisitStmt_(const IfThenElseNode* op) final {
if (is_updating && (target_if == op)) {
if (is_then_case) {
if (is_updating_ && (target_if_ == op)) {
if (is_then_case_) {
return StmtMutator::VisitStmt(op->then_case);
} else if (op->else_case.defined()) {
return StmtMutator::VisitStmt(op->else_case);
Expand All @@ -344,13 +335,12 @@ class IfThenElseHoister : public StmtMutator {
return StmtMutator::VisitStmt_(op);
}

const ForNode* target_for;
const IfThenElseNode* target_if;

private:
bool is_updating{false};
bool is_then_case{false};
HoistCandidateSelector hoist_selector;
bool is_updating_{false};
bool is_then_case_{false};
HoistCandidateSelector hoist_selector_;
const ForNode* target_for_;
const IfThenElseNode* target_if_;
};

Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoister().VisitAndMutate(stmt); }
Expand Down

0 comments on commit cc36ae7

Please sign in to comment.