Skip to content

Commit 68ce62e

Browse files
authored
Lower unroll cleanup, make it support IfThenElse (#2496)
1 parent 167718b commit 68ce62e

File tree

2 files changed

+25
-29
lines changed

2 files changed

+25
-29
lines changed

third_party/nvfuser/csrc/lower_unroll.cpp

Lines changed: 23 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,8 @@ bool isReductionInitExpr(const Expr* expr) {
5151

5252
} // namespace
5353

54-
void UnrollPass::registerReplace(
55-
Expr* reference,
56-
Expr* new_expr,
57-
kir::Scope* scope) {
58-
kir::ExprMutator::registerReplace(reference, new_expr, scope);
54+
void UnrollPass::registerReplace(Expr* reference, Expr* new_expr) {
55+
kir::ExprMutator::registerReplace(reference, new_expr);
5956
GpuLower::current()->propagateExprInfo(reference, new_expr);
6057
}
6158

@@ -115,7 +112,7 @@ void UnrollPass::handle(Expr* expr) {
115112
expr_with_predicate = ShiftPredicateInserter::insert(
116113
expr, for_loops_, thread_pred, unswitched_loop_);
117114
if (expr_with_predicate != expr) {
118-
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
115+
registerReplace(expr, expr_with_predicate);
119116
}
120117
return;
121118
}
@@ -137,7 +134,7 @@ void UnrollPass::handle(Expr* expr) {
137134
: IrBuilder::create<kir::Predicate>(
138135
PredicateType::Inline, expr, thread_pred);
139136
expr_with_predicate = expr_with_predicate->withPredicate(pred);
140-
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
137+
registerReplace(expr, expr_with_predicate);
141138
return;
142139
}
143140

@@ -160,27 +157,21 @@ void UnrollPass::handle(Expr* expr) {
160157

161158
if (lower_utils::supportInlinePredicate(expr)) {
162159
expr_with_predicate = expr_with_predicate->withPredicate(pred);
163-
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
160+
registerReplace(expr, expr_with_predicate);
164161
return;
165162
}
166163

167164
// If we need a predicate, put expr inside an if then else
168165
kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
169-
if (for_loops_.empty()) {
170-
// Special handling for top level output expressions that still
171-
// need predicates. One motivating example is a reduction op that
172-
// reduces to a scalar (issue #491)
173-
kir::ExprMutator::registerReplace(expr, inline_ite, nullptr);
174-
} else {
175-
kir::ExprMutator::registerReplace(
176-
expr, inline_ite, &for_loops_.back()->body());
177-
}
166+
kir::ExprMutator::registerReplace(expr, inline_ite);
178167
if (expr != expr_with_predicate) {
179168
GpuLower::current()->propagateExprInfo(expr, expr_with_predicate);
180169
}
181170
inline_ite->thenBody().push_back(expr_with_predicate);
182171
} else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
183172
handle(for_loop);
173+
} else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
174+
kir::ExprMutator::handle(ite);
184175
}
185176
}
186177

@@ -196,6 +187,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
196187
// normal.
197188
if (!is_unroll || !look_for_unroll_) {
198189
for_loops_.push_back(fl);
190+
scope_.push_back(&fl->body());
191+
scope_exprs_.push_back(fl);
199192

200193
// Make copy of exprs because we replace them inplace in fl
201194
const auto exprs_copy = fl->body().exprs();
@@ -208,6 +201,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
208201
}
209202

210203
for_loops_.pop_back();
204+
scope_.pop_back();
205+
scope_exprs_.pop_back();
211206
return;
212207
}
213208

@@ -217,38 +212,39 @@ void UnrollPass::handle(kir::ForLoop* fl) {
217212

218213
// Get the loop nest for the unrolled path
219214
kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);
215+
unroll_ite->thenBody().push_back(unrolled_loop_nest);
220216

221217
// Thread predicates are not removed from the expressions. Visit
222218
// each expression to attach kir::Predicate.
219+
scope_.push_back(&unroll_ite->thenBody());
220+
scope_exprs_.push_back(unroll_ite);
223221
unswitched_loop_ = true;
224222
look_for_unroll_ = false;
225223
handle(unrolled_loop_nest);
226224
unswitched_loop_ = false;
227225
look_for_unroll_ = true;
228-
229-
unroll_ite->thenBody().push_back(unrolled_loop_nest);
226+
scope_.pop_back();
227+
scope_exprs_.pop_back();
230228

231229
// Loop nest for inlined path
232230
kir::ForLoop* inlined_loop = cloneLoopNest(fl);
233231

234232
// Add inline predicates for inlined loop nest
233+
scope_.push_back(&unroll_ite->elseBody());
234+
scope_exprs_.push_back(unroll_ite);
235235
look_for_unroll_ = false;
236236
non_trivial_pred_found_ = false;
237237
handle(inlined_loop);
238238
look_for_unroll_ = true;
239+
scope_.pop_back();
240+
scope_exprs_.pop_back();
239241
if (!non_trivial_pred_found_) {
240-
kir::ExprMutator::registerReplace(
241-
fl,
242-
inlined_loop,
243-
for_loops_.empty() ? nullptr : &for_loops_.back()->body());
242+
kir::ExprMutator::registerReplace(fl, inlined_loop);
244243
} else {
245244
if (!canOmitElseClause(fl)) {
246245
unroll_ite->elseBody().push_back(inlined_loop);
247246
}
248-
kir::ExprMutator::registerReplace(
249-
fl,
250-
unroll_ite,
251-
for_loops_.empty() ? nullptr : &for_loops_.back()->body());
247+
kir::ExprMutator::registerReplace(fl, unroll_ite);
252248
}
253249
}
254250

third_party/nvfuser/csrc/lower_unroll.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
5959
static bool canOmitElseClause(kir::ForLoop* fl);
6060

6161
private:
62-
void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope);
62+
void registerReplace(Expr* reference, Expr* new_expr);
6363

6464
// Generate the for Expr replacement map
6565
UnrollPass(const std::vector<Expr*>& exprs);
@@ -68,7 +68,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
6868
return expr_replacement_map_;
6969
}
7070

71-
using OptOutDispatch::handle;
71+
using kir::ExprMutator::handle;
7272

7373
void handle(kir::ForLoop* fl) final;
7474

0 commit comments

Comments
 (0)