From 9dd0225a70c4d5dabd8112a1b0262005a81916be Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 29 Oct 2022 00:59:37 -0500 Subject: [PATCH] [TIR] Use Optional for IfThenElseNode::else_case (#13218) This parameter is nullable for cases where the else block isn't present. Previously, it was represented as a `Stmt` holding `nullptr`, because `IfThenElse` (https://github.com/apache/tvm/pull/3533) predates the `Optional` utility (https://github.com/apache/tvm/pull/5314). This commit updates to use `Optional` instead, and updates all usages of `else_case`. --- include/tvm/tir/stmt.h | 4 ++-- src/arith/ir_mutator_with_analyzer.cc | 12 +++++------- src/arith/ir_visitor_with_analyzer.cc | 4 ++-- src/contrib/hybrid/codegen_hybrid.cc | 4 ++-- src/printer/tir_text_printer.cc | 4 ++-- src/printer/tvmscript_printer.cc | 4 ++-- src/target/llvm/codegen_llvm.cc | 4 ++-- src/target/source/codegen_c.cc | 4 ++-- src/target/spirv/codegen_spirv.cc | 4 ++-- src/target/stackvm/codegen_stackvm.cc | 4 ++-- src/tir/analysis/block_access_region_detector.cc | 4 ++-- src/tir/analysis/estimate_flops.cc | 4 ++-- src/tir/ir/stmt.cc | 4 ++-- src/tir/ir/stmt_functor.cc | 10 +++++----- src/tir/transforms/common_subexpr_elim_tools.cc | 6 +++--- src/tir/transforms/compact_buffer_region.cc | 4 ++-- src/tir/transforms/coproc_sync.cc | 4 ++-- src/tir/transforms/inject_virtual_thread.cc | 6 +++--- src/tir/transforms/ir_utils.cc | 2 +- src/tir/transforms/lift_attr_scope.cc | 4 ++-- src/tir/transforms/profile_instrumentation.cc | 4 ++-- src/tir/transforms/remove_no_op.cc | 4 ++-- src/tir/transforms/simplify.cc | 4 ++-- src/tir/transforms/storage_access.cc | 4 ++-- src/tir/transforms/vectorize_loop.cc | 6 +++--- 25 files changed, 58 insertions(+), 60 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index e16d773f02b3..e0e191b282e5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -824,7 +824,7 @@ class IfThenElseNode : public StmtNode { /*! \brief The branch to be executed when condition is true. */ Stmt then_case; /*! \brief The branch to be executed when condition is false, can be null. */ - Stmt else_case; + Optional else_case; void VisitAttrs(AttrVisitor* v) { v->Visit("condition", &condition); @@ -854,7 +854,7 @@ class IfThenElseNode : public StmtNode { */ class IfThenElse : public Stmt { public: - TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case = Stmt(), + TVM_DLL IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case = NullOpt, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(IfThenElse, Stmt, IfThenElseNode); diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 9cae3b7a6ac8..199f06191e4e 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -71,21 +71,19 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { } } - Stmt then_case, else_case; + Stmt then_case; + Optional else_case; { With ctx(analyzer_, real_condition); then_case = this->VisitStmt(op->then_case); } - if (op->else_case.defined()) { + if (op->else_case) { With ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition))); - else_case = this->VisitStmt(op->else_case); + else_case = this->VisitStmt(op->else_case.value()); } if (is_one(real_condition)) return then_case; if (is_zero(real_condition)) { - if (else_case.defined()) { - return else_case; - } - return Evaluate(0); + return else_case.value_or(Evaluate(0)); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 75ae22ef9915..e7cf3ea7eadd 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -58,9 +58,9 @@ void IRVisitorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { With constraint(&analyzer_, real_condition); this->VisitStmt(op->then_case); } - if (op->else_case.defined()) { + if (op->else_case) { With constraint(&analyzer_, analyzer_.rewrite_simplify(Not(real_condition))); - this->VisitStmt(op->else_case); + this->VisitStmt(op->else_case.value()); } } diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 79c9e567b459..687da61fa019 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -381,11 +381,11 @@ void CodeGenHybrid::VisitStmt_(const IfThenElseNode* op) { PrintStmt(op->then_case); indent_ -= tab_; - if (!is_noop(op->else_case)) { + if (op->else_case && !is_noop(op->else_case.value())) { PrintIndent(); stream << "else:\n"; indent_ += tab_; - PrintStmt(op->else_case); + PrintStmt(op->else_case.value()); indent_ -= tab_; } } diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index cdfc8fd318fd..e50559ac10ff 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -572,8 +572,8 @@ Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { Doc TIRTextPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << PrintBody(op->then_case); - if (!is_one(op->condition) && op->else_case.defined()) { - doc << " else" << PrintBody(op->else_case); + if (!is_one(op->condition) && op->else_case) { + doc << " else" << PrintBody(op->else_case.value()); } return doc; } diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 39eb245f3ad9..d8d5d89be0a4 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -1244,9 +1244,9 @@ Doc TVMScriptPrinter::VisitStmt_(const IfThenElseNode* op) { Doc doc; doc << "if " << Print(op->condition) << ":"; doc << Doc::Indent(4, Doc::NewLine() << PrintBody(op->then_case)); - if (!is_one(op->condition) && op->else_case.defined()) { + if (!is_one(op->condition) && op->else_case) { doc << Doc::NewLine(); - doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case)); + doc << "else:" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->else_case.value())); } return doc; } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index ea8a5ff5106a..87479ec74237 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1759,14 +1759,14 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { llvm::LLVMContext* ctx = llvm_target_->GetContext(); auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); auto* end_block = llvm::BasicBlock::Create(*ctx, "if_end", function_); - if (op->else_case.defined()) { + if (op->else_case) { auto* else_block = llvm::BasicBlock::Create(*ctx, "if_else", function_); builder_->CreateCondBr(cond, then_block, else_block); builder_->SetInsertPoint(then_block); this->VisitStmt(op->then_case); builder_->CreateBr(end_block); builder_->SetInsertPoint(else_block); - this->VisitStmt(op->else_case); + this->VisitStmt(op->else_case.value()); builder_->CreateBr(end_block); } else { builder_->CreateCondBr(cond, then_block, end_block, md_very_likely_branch_); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index b69f76914495..66c92181c126 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -936,11 +936,11 @@ void CodeGenC::VisitStmt_(const IfThenElseNode* op) { PrintStmt(op->then_case); this->EndScope(then_scope); - if (op->else_case.defined()) { + if (op->else_case) { PrintIndent(); stream << "} else {\n"; int else_scope = BeginScope(); - PrintStmt(op->else_case); + PrintStmt(op->else_case.value()); this->EndScope(else_scope); } PrintIndent(); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 4f875e955576..c291a478dd3f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -628,7 +628,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { spirv::Value cond = MakeValue(op->condition); spirv::Label then_label = builder_->NewLabel(); spirv::Label merge_label = builder_->NewLabel(); - if (op->else_case.defined()) { + if (op->else_case) { spirv::Label else_label = builder_->NewLabel(); builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); builder_->MakeInst(spv::OpBranchConditional, cond, then_label, else_label); @@ -638,7 +638,7 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { builder_->MakeInst(spv::OpBranch, merge_label); // else block builder_->StartLabel(else_label); - this->VisitStmt(op->else_case); + this->VisitStmt(op->else_case.value()); builder_->MakeInst(spv::OpBranch, merge_label); } else { builder_->MakeInst(spv::OpSelectionMerge, merge_label, spv::SelectionControlMaskNone); diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 80a5c4bfde6a..eac9ad849419 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -475,13 +475,13 @@ void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) { int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0); this->PushOp(StackVM::POP); this->Push(op->then_case); - if (op->else_case.defined()) { + if (op->else_case) { int64_t label_then_jump = this->GetPC(); int64_t then_jump = this->PushOp(StackVM::RJUMP, 0); int64_t else_begin = this->GetPC(); this->SetOperand(else_jump, else_begin - label_ejump); this->PushOp(StackVM::POP); - this->Push(op->else_case); + this->Push(op->else_case.value()); int64_t if_end = this->GetPC(); this->SetOperand(then_jump, if_end - label_then_jump); } else { diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index c65a422ed3d0..e9bff1b6fdee 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -173,10 +173,10 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } - if (op->else_case.defined()) { + if (op->else_case) { // Visit else branch With ctx(op->condition, &dom_map_, &hint_map_, false); - StmtExprVisitor::VisitStmt(op->else_case); + StmtExprVisitor::VisitStmt(op->else_case.value()); } } diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index d8faf9bd1362..d158a001b2d8 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -148,8 +148,8 @@ class FlopEstimator : private ExprFunctor, TResult VisitStmt_(const IfThenElseNode* branch) override { TResult cond = VisitExpr(branch->condition); - if (branch->else_case.defined()) { - cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); + if (branch->else_case) { + cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case.value())); } else { cond += VisitStmt(branch->then_case); } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 8f2a7b4ffe5b..a8d8936c905a 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -641,7 +641,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // IfThenElse -IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Stmt else_case, Span span) { +IfThenElse::IfThenElse(PrimExpr condition, Stmt then_case, Optional else_case, Span span) { ICHECK(condition.defined()); ICHECK(then_case.defined()); // else_case may be null. @@ -670,7 +670,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->then_case); p->indent -= 2; - if (!op->else_case.defined()) { + if (!op->else_case) { break; } diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 59630d34c38e..e445432e5b6f 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -86,8 +86,8 @@ void StmtVisitor::VisitStmt_(const BufferRealizeNode* op) { void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { this->VisitExpr(op->condition); this->VisitStmt(op->then_case); - if (op->else_case.defined()) { - this->VisitStmt(op->else_case); + if (op->else_case) { + this->VisitStmt(op->else_case.value()); } } @@ -352,9 +352,9 @@ Stmt StmtMutator::VisitStmt_(const DeclBufferNode* op) { Stmt StmtMutator::VisitStmt_(const IfThenElseNode* op) { PrimExpr condition = this->VisitExpr(op->condition); Stmt then_case = this->VisitStmt(op->then_case); - Stmt else_case; - if (op->else_case.defined()) { - else_case = this->VisitStmt(op->else_case); + Optional else_case = NullOpt; + if (op->else_case) { + else_case = this->VisitStmt(op->else_case.value()); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) { diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index 39d7a750a99c..130004c51cd8 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -434,9 +434,9 @@ void ComputationsDoneBy::VisitStmt_(const IfThenElseNode* op) { table_of_computations_.clear(); ComputationTable computations_done_by_else; - if (op->else_case.defined()) { - // And finally calls the VisitStmt() method on the `then_case` child - VisitStmt(op->else_case); + if (op->else_case) { + // And finally calls the VisitStmt() method on the `else_case` child + VisitStmt(op->else_case.value()); computations_done_by_else = table_of_computations_; table_of_computations_.clear(); } diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index 249b8cca77b0..b517150ce9f4 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -184,10 +184,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { With ctx(op->condition, &dom_map_, &hint_map_, true); StmtExprVisitor::VisitStmt(op->then_case); } - if (op->else_case.defined()) { + if (op->else_case) { // Visit else branch With ctx(op->condition, &dom_map_, &hint_map_, false); - StmtExprVisitor::VisitStmt(op->else_case); + StmtExprVisitor::VisitStmt(op->else_case.value()); } } diff --git a/src/tir/transforms/coproc_sync.cc b/src/tir/transforms/coproc_sync.cc index 1b1cabeadb71..69913f4bd604 100644 --- a/src/tir/transforms/coproc_sync.cc +++ b/src/tir/transforms/coproc_sync.cc @@ -417,8 +417,8 @@ class CoProcInstDepDetector : public StmtVisitor { first_state_.clear(); last_state_.clear(); } - if (op->else_case.defined()) { - this->VisitStmt(op->else_case); + if (op->else_case) { + this->VisitStmt(op->else_case.value()); if (last_state_.node != nullptr) { curr_state.node = op; MatchFixEnterPop(first_state_); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index f49b6b2ace8e..a1ebdcef9855 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -360,11 +360,11 @@ class VTInjector : public arith::IRMutatorWithAnalyzer { visit_touched_var_ = false; ICHECK_EQ(max_loop_depth_, 0); Stmt then_case = this->VisitStmt(op->then_case); - Stmt else_case; - if (op->else_case.defined()) { + Optional else_case = NullOpt; + if (op->else_case) { int temp = max_loop_depth_; max_loop_depth_ = 0; - else_case = this->VisitStmt(op->else_case); + else_case = this->VisitStmt(op->else_case.value()); max_loop_depth_ = std::max(temp, max_loop_depth_); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index b7e3e01f7506..6893aecc4d00 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -56,7 +56,7 @@ Stmt MergeNest(const std::vector& nest, Stmt body) { } else if (const auto* ite = s.as()) { auto n = make_object(*ite); ICHECK(is_no_op(n->then_case)); - ICHECK(!n->else_case.defined()); + ICHECK(!n->else_case); n->then_case = body; body = Stmt(n); } else if (const auto* seq = s.as()) { diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index 40d152b3b3b6..272e16d40d97 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -122,7 +122,7 @@ class AttrScopeLifter : public StmtMutator { } Stmt VisitStmt_(const IfThenElseNode* op) final { - if (!op->else_case.defined()) { + if (!op->else_case) { return StmtMutator::VisitStmt_(op); } Stmt then_case = this->VisitStmt(op->then_case); @@ -130,7 +130,7 @@ class AttrScopeLifter : public StmtMutator { PrimExpr first_value; std::swap(first_node, attr_node_); std::swap(first_value, attr_value_); - Stmt else_case = this->VisitStmt(op->else_case); + Stmt else_case = this->VisitStmt(op->else_case.value()); if (attr_node_.defined() && attr_value_.defined() && first_node.defined() && first_value.defined() && attr_node_.same_as(first_node) && ValueSame(attr_value_, first_value)) { diff --git a/src/tir/transforms/profile_instrumentation.cc b/src/tir/transforms/profile_instrumentation.cc index 5f52fc6630bc..68d5b0a204d5 100644 --- a/src/tir/transforms/profile_instrumentation.cc +++ b/src/tir/transforms/profile_instrumentation.cc @@ -110,8 +110,8 @@ class LoopAnalyzer : public StmtExprVisitor { } else if (stmt->IsInstance()) { const IfThenElseNode* n = stmt.as(); unsigned height = TraverseLoop(n->then_case, parent_depth, has_parallel); - if (n->else_case.defined()) { - height = std::max(height, TraverseLoop(n->else_case, parent_depth, has_parallel)); + if (n->else_case) { + height = std::max(height, TraverseLoop(n->else_case.value(), parent_depth, has_parallel)); } return height; } else if (stmt->IsInstance()) { diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 8728817aad57..41250408a7f2 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -69,8 +69,8 @@ class NoOpRemover : public StmtMutator { Stmt VisitStmt_(const IfThenElseNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - if (op->else_case.defined()) { - if (is_no_op(op->else_case)) { + if (op->else_case) { + if (is_no_op(op->else_case.value())) { if (is_no_op(op->then_case)) { return MakeEvaluate(op->condition); } else { diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index b6e3581aa614..1dbf9e688027 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -139,8 +139,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { if (const int64_t* as_int = as_const_int(cond)) { if (*as_int) { return this->VisitStmt(op->then_case); - } else if (op->else_case.defined()) { - return this->VisitStmt(op->else_case); + } else if (op->else_case) { + return this->VisitStmt(op->else_case.value()); } else { return Evaluate(0); } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 4f19f708880c..8729ab1ed296 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -187,9 +187,9 @@ void StorageAccessVisitor::VisitStmt_(const IfThenElseNode* op) { s.stmt = op; s.access = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); - if (op->else_case.defined()) { + if (op->else_case) { scope_.push_back(std::vector()); - this->VisitStmt(op->else_case); + this->VisitStmt(op->else_case.value()); auto v = Summarize(std::move(scope_.back()), nullptr); scope_.pop_back(); s.access.insert(s.access.end(), v.begin(), v.end()); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 3cc17847e69b..8efed83ccdf1 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -490,9 +490,9 @@ class Vectorizer : public StmtMutator, public ExprFunctor(op)); } Stmt then_case = this->VisitStmt(op->then_case); - Stmt else_case; - if (op->else_case.defined()) { - else_case = this->VisitStmt(op->else_case); + Optional else_case = NullOpt; + if (op->else_case) { + else_case = this->VisitStmt(op->else_case.value()); } if (condition.same_as(op->condition) && then_case.same_as(op->then_case) && else_case.same_as(op->else_case)) {