diff --git a/third_party/nvfuser/csrc/lower_unroll.cpp b/third_party/nvfuser/csrc/lower_unroll.cpp index 09df8d02ed91..516aec4ed2e9 100644 --- a/third_party/nvfuser/csrc/lower_unroll.cpp +++ b/third_party/nvfuser/csrc/lower_unroll.cpp @@ -51,11 +51,8 @@ bool isReductionInitExpr(const Expr* expr) { } // namespace -void UnrollPass::registerReplace( - Expr* reference, - Expr* new_expr, - kir::Scope* scope) { - kir::ExprMutator::registerReplace(reference, new_expr, scope); +void UnrollPass::registerReplace(Expr* reference, Expr* new_expr) { + kir::ExprMutator::registerReplace(reference, new_expr); GpuLower::current()->propagateExprInfo(reference, new_expr); } @@ -115,7 +112,7 @@ void UnrollPass::handle(Expr* expr) { expr_with_predicate = ShiftPredicateInserter::insert( expr, for_loops_, thread_pred, unswitched_loop_); if (expr_with_predicate != expr) { - registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + registerReplace(expr, expr_with_predicate); } return; } @@ -137,7 +134,7 @@ void UnrollPass::handle(Expr* expr) { : IrBuilder::create( PredicateType::Inline, expr, thread_pred); expr_with_predicate = expr_with_predicate->withPredicate(pred); - registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + registerReplace(expr, expr_with_predicate); return; } @@ -160,27 +157,21 @@ void UnrollPass::handle(Expr* expr) { if (lower_utils::supportInlinePredicate(expr)) { expr_with_predicate = expr_with_predicate->withPredicate(pred); - registerReplace(expr, expr_with_predicate, &for_loops_.back()->body()); + registerReplace(expr, expr_with_predicate); return; } // If we need a predicate, put expr inside an if then else kir::IfThenElse* inline_ite = IrBuilder::create(pred); - if (for_loops_.empty()) { - // Special handling for top level output expressions that still - // need predicates. One motivating example is a reduction op that - // reduces to a scalar (issue #491) - kir::ExprMutator::registerReplace(expr, inline_ite, nullptr); - } else { - kir::ExprMutator::registerReplace( - expr, inline_ite, &for_loops_.back()->body()); - } + kir::ExprMutator::registerReplace(expr, inline_ite); if (expr != expr_with_predicate) { GpuLower::current()->propagateExprInfo(expr, expr_with_predicate); } inline_ite->thenBody().push_back(expr_with_predicate); } else if (auto for_loop = dynamic_cast(expr)) { handle(for_loop); + } else if (auto ite = dynamic_cast(expr)) { + kir::ExprMutator::handle(ite); } } @@ -196,6 +187,8 @@ void UnrollPass::handle(kir::ForLoop* fl) { // normal. if (!is_unroll || !look_for_unroll_) { for_loops_.push_back(fl); + scope_.push_back(&fl->body()); + scope_exprs_.push_back(fl); // Make copy of exprs because we replace them inplace in fl const auto exprs_copy = fl->body().exprs(); @@ -208,6 +201,8 @@ void UnrollPass::handle(kir::ForLoop* fl) { } for_loops_.pop_back(); + scope_.pop_back(); + scope_exprs_.pop_back(); return; } @@ -217,38 +212,39 @@ void UnrollPass::handle(kir::ForLoop* fl) { // Get the loop nest for the unrolled path kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl); + unroll_ite->thenBody().push_back(unrolled_loop_nest); // Thread predicates are not removed from the expressions. Visit // each expression to attach kir::Predicate. + scope_.push_back(&unroll_ite->thenBody()); + scope_exprs_.push_back(unroll_ite); unswitched_loop_ = true; look_for_unroll_ = false; handle(unrolled_loop_nest); unswitched_loop_ = false; look_for_unroll_ = true; - - unroll_ite->thenBody().push_back(unrolled_loop_nest); + scope_.pop_back(); + scope_exprs_.pop_back(); // Loop nest for inlined path kir::ForLoop* inlined_loop = cloneLoopNest(fl); // Add inline predicates for inlined loop nest + scope_.push_back(&unroll_ite->elseBody()); + scope_exprs_.push_back(unroll_ite); look_for_unroll_ = false; non_trivial_pred_found_ = false; handle(inlined_loop); look_for_unroll_ = true; + scope_.pop_back(); + scope_exprs_.pop_back(); if (!non_trivial_pred_found_) { - kir::ExprMutator::registerReplace( - fl, - inlined_loop, - for_loops_.empty() ? nullptr : &for_loops_.back()->body()); + kir::ExprMutator::registerReplace(fl, inlined_loop); } else { if (!canOmitElseClause(fl)) { unroll_ite->elseBody().push_back(inlined_loop); } - kir::ExprMutator::registerReplace( - fl, - unroll_ite, - for_loops_.empty() ? nullptr : &for_loops_.back()->body()); + kir::ExprMutator::registerReplace(fl, unroll_ite); } } diff --git a/third_party/nvfuser/csrc/lower_unroll.h b/third_party/nvfuser/csrc/lower_unroll.h index ffcfe9e95629..d6e38f86d681 100644 --- a/third_party/nvfuser/csrc/lower_unroll.h +++ b/third_party/nvfuser/csrc/lower_unroll.h @@ -59,7 +59,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { static bool canOmitElseClause(kir::ForLoop* fl); private: - void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope); + void registerReplace(Expr* reference, Expr* new_expr); // Generate the for Expr replacement map UnrollPass(const std::vector& exprs); @@ -68,7 +68,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { return expr_replacement_map_; } - using OptOutDispatch::handle; + using kir::ExprMutator::handle; void handle(kir::ForLoop* fl) final;