Skip to content

Commit

Permalink
[TIR] Use Optional<Stmt> for IfThenElseNode::else_case (#13218)
Browse files Browse the repository at this point in the history
This parameter is nullable for cases where the else block isn't
present.  Previously, it was represented as a `Stmt` holding
`nullptr`, because
`IfThenElse` (#3533) predates the
`Optional` utility (#5314).  This
commit updates to use `Optional<Stmt>` instead, and updates all usages
of `else_case`.
  • Loading branch information
Lunderberg authored Oct 29, 2022
1 parent da76587 commit 9dd0225
Show file tree
Hide file tree
Showing 25 changed files with 58 additions and 60 deletions.
4 changes: 2 additions & 2 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ class IfThenElseNode : public StmtNode {
/*! \brief The branch to be executed when condition is true. */
Stmt then_case;
/*! \brief The branch to be executed when condition is false, can be null. */
Stmt else_case;
Optional<Stmt> else_case;

void VisitAttrs(AttrVisitor* v) {
v->Visit("condition", &condition);
Expand Down Expand Up @@ -854,7 +854,7 @@ class IfThenElseNode : public StmtNode {
*/
class IfThenElse : public Stmt {
public:
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(),
TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case = NullOpt,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode);
Expand Down
12 changes: 5 additions & 7 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,21 +71,19 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
}
}

Stmt then_case, else_case;
Stmt then_case;
Optional<Stmt> else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
else_case = this->VisitStmt(op->else_case);
else_case = this->VisitStmt(op->else_case.value());
}
if (is_one(real_condition)) return then_case;
if (is_zero(real_condition)) {
if (else_case.defined()) {
return else_case;
}
return Evaluate(0);
return else_case.value_or(Evaluate(0));
}

if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
4 changes: 2 additions & 2 deletions src/arith/ir_visitor_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
With<ConstraintContext> constraint(&analyzer_, real_condition);
this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
With<ConstraintContext> constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition)));
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.cc
Original file line number Diff line number Diff line change
Expand Up @@ -381,11 +381,11 @@ void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) {
PrintStmt(op->then_case);
indent_ -= tab_;

if (!is_noop(op->else_case)) {
if (op->else_case && !is_noop(op->else_case.value())) {
PrintIndent();
stream << "else:\n";
indent_ += tab_;
PrintStmt(op->else_case);
PrintStmt(op->else_case.value());
indent_ -= tab_;
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) {
Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << PrintBody(op->then_case);
if (!is_one(op->condition) && op->else_case.defined()) {
doc << " else" << PrintBody(op->else_case);
if (!is_one(op->condition) && op->else_case) {
doc << " else" << PrintBody(op->else_case.value());
}
return doc;
}
Expand Down
4 changes: 2 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1244,9 +1244,9 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) {
Doc doc;
doc << "if " << Print(op->condition) << ":";
doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case));
if (!is_one(op->condition) && op->else_case.defined()) {
if (!is_one(op->condition) && op->else_case) {
doc << Doc::NewLine();
doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case));
doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case.value()));
}
return doc;
}
Expand Down
4 changes: 2 additions & 2 deletions src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1759,14 +1759,14 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) {
llvm::LLVMContext* ctx = llvm_target_->GetContext();
auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_);
auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_);
if (op->else_case.defined()) {
if (op->else_case) {
auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_);
builder_->CreateCondBr(cond, then_block, else_block);
builder_->SetInsertPoint(then_block);
this->VisitStmt(op->then_case);
builder_->CreateBr(end_block);
builder_->SetInsertPoint(else_block);
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
builder_->CreateBr(end_block);
} else {
builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_);
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -936,11 +936,11 @@ void CodeGenC::VisitStmt_(const IfThenElseNode* op) {
PrintStmt(op->then_case);
this->EndScope(then_scope);

if (op->else_case.defined()) {
if (op->else_case) {
PrintIndent();
stream << "} else {\n";
int else_scope = BeginScope();
PrintStmt(op->else_case);
PrintStmt(op->else_case.value());
this->EndScope(else_scope);
}
PrintIndent();
Expand Down
4 changes: 2 additions & 2 deletions src/target/spirv/codegen_spirv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
spirv::Value cond = MakeValue(op->condition);
spirv::Label then_label = builder_->NewLabel();
spirv::Label merge_label = builder_->NewLabel();
if (op->else_case.defined()) {
if (op->else_case) {
spirv::Label else_label = builder_->NewLabel();
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label);
Expand All @@ -638,7 +638,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
builder_->MakeInst(spv::OpBranch, merge_label);
// else block
builder_->StartLabel(else_label);
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
builder_->MakeInst(spv::OpBranch, merge_label);
} else {
builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone);
Expand Down
4 changes: 2 additions & 2 deletions src/target/stackvm/codegen_stackvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,13 +475,13 @@ void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
this->PushOp(StackVM::POP);
this->Push(op->then_case);
if (op->else_case.defined()) {
if (op->else_case) {
int64_t label_then_jump = this->GetPC();
int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
int64_t else_begin = this->GetPC();
this->SetOperand(else_jump, else_begin - label_ejump);
this->PushOp(StackVM::POP);
this->Push(op->else_case);
this->Push(op->else_case.value());
int64_t if_end = this->GetPC();
this->SetOperand(then_jump, if_end - label_then_jump);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/analysis/block_access_region_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/analysis/estimate_flops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,

TResult VisitStmt_(const IfThenElseNode* branch) override {
TResult cond = VisitExpr(branch->condition);
if (branch->else_case.defined()) {
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
if (branch->else_case) {
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case.value()));
} else {
cond += VisitStmt(branch->then_case);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

// IfThenElse
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case, Span span) {
IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional<Stmt> else_case, Span span) {
ICHECK(condition.defined());
ICHECK(then_case.defined());
// else_case may be null.
Expand Down Expand Up @@ -670,7 +670,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->Print(op->then_case);
p->indent -= 2;

if (!op->else_case.defined()) {
if (!op->else_case) {
break;
}

Expand Down
10 changes: 5 additions & 5 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) {
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition);
this->VisitStmt(op->then_case);
if (op->else_case.defined()) {
this->VisitStmt(op->else_case);
if (op->else_case) {
this->VisitStmt(op->else_case.value());
}
}

Expand Down Expand Up @@ -352,9 +352,9 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) {
Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/common_subexpr_elim_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,9 @@ void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) {
table_of_computations_.clear();

ComputationTable computations_done_by_else;
if (op->else_case.defined()) {
// And finally calls the VisitStmt() method on the `then_case` child
VisitStmt(op->else_case);
if (op->else_case) {
// And finally calls the VisitStmt() method on the `else_case` child
VisitStmt(op->else_case.value());
computations_done_by_else = table_of_computations_;
table_of_computations_.clear();
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/compact_buffer_region.cc
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, true);
StmtExprVisitor::VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
if (op->else_case) {
// Visit else branch
With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, &hint_map_, false);
StmtExprVisitor::VisitStmt(op->else_case);
StmtExprVisitor::VisitStmt(op->else_case.value());
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/coproc_sync.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ class CoProcInstDepDetector : public StmtVisitor {
first_state_.clear();
last_state_.clear();
}
if (op->else_case.defined()) {
this->VisitStmt(op->else_case);
if (op->else_case) {
this->VisitStmt(op->else_case.value());
if (last_state_.node != nullptr) {
curr_state.node = op;
MatchFixEnterPop(first_state_);
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/inject_virtual_thread.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,11 @@ class VTInjector : public arith::IRMutatorWithAnalyzer {
visit_touched_var_ = false;
ICHECK_EQ(max_loop_depth_, 0);
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
int temp = max_loop_depth_;
max_loop_depth_ = 0;
else_case = this->VisitStmt(op->else_case);
else_case = this->VisitStmt(op->else_case.value());
max_loop_depth_ = std::max(temp, max_loop_depth_);
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
Expand Down
2 changes: 1 addition & 1 deletion src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
} else if (const auto* ite = s.as<IfThenElseNode>()) {
auto n = make_object<IfThenElseNode>(*ite);
ICHECK(is_no_op(n->then_case));
ICHECK(!n->else_case.defined());
ICHECK(!n->else_case);
n->then_case = body;
body = Stmt(n);
} else if (const auto* seq = s.as<SeqStmtNode>()) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/lift_attr_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ class AttrScopeLifter : public StmtMutator {
}

Stmt VisitStmt_(const IfThenElseNode* op) final {
if (!op->else_case.defined()) {
if (!op->else_case) {
return StmtMutator::VisitStmt_(op);
}
Stmt then_case = this->VisitStmt(op->then_case);
ObjectRef first_node;
PrimExpr first_value;
std::swap(first_node, attr_node_);
std::swap(first_value, attr_value_);
Stmt else_case = this->VisitStmt(op->else_case);
Stmt else_case = this->VisitStmt(op->else_case.value());
if (attr_node_.defined() && attr_value_.defined() && first_node.defined() &&
first_value.defined() && attr_node_.same_as(first_node) &&
ValueSame(attr_value_, first_value)) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/profile_instrumentation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,8 @@ class LoopAnalyzer : public StmtExprVisitor {
} else if (stmt->IsInstance<IfThenElseNode>()) {
const IfThenElseNode* n = stmt.as<IfThenElseNode>();
unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel);
if (n->else_case.defined()) {
height = std::max(height, TraverseLoop(n->else_case, parent_depth, has_parallel));
if (n->else_case) {
height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel));
}
return height;
} else if (stmt->IsInstance<ForNode>()) {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/remove_no_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class NoOpRemover : public StmtMutator {
Stmt VisitStmt_(const IfThenElseNode* op) final {
Stmt stmt = StmtMutator::VisitStmt_(op);
op = stmt.as<IfThenElseNode>();
if (op->else_case.defined()) {
if (is_no_op(op->else_case)) {
if (op->else_case) {
if (is_no_op(op->else_case.value())) {
if (is_no_op(op->then_case)) {
return MakeEvaluate(op->condition);
} else {
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
if (const int64_t* as_int = as_const_int(cond)) {
if (*as_int) {
return this->VisitStmt(op->then_case);
} else if (op->else_case.defined()) {
return this->VisitStmt(op->else_case);
} else if (op->else_case) {
return this->VisitStmt(op->else_case.value());
} else {
return Evaluate(0);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/storage_access.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,9 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) {
s.stmt = op;
s.access = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
if (op->else_case.defined()) {
if (op->else_case) {
scope_.push_back(std::vector<StmtEntry>());
this->VisitStmt(op->else_case);
this->VisitStmt(op->else_case.value());
auto v = Summarize(std::move(scope_.back()), nullptr);
scope_.pop_back();
s.access.insert(s.access.end(), v.begin(), v.end());
Expand Down
6 changes: 3 additions & 3 deletions src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -490,9 +490,9 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
Stmt else_case;
if (op->else_case.defined()) {
else_case = this->VisitStmt(op->else_case);
Optional<Stmt> else_case = NullOpt;
if (op->else_case) {
else_case = this->VisitStmt(op->else_case.value());
}
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
else_case.same_as(op->else_case)) {
Expand Down

0 comments on commit 9dd0225

Please sign in to comment.