@@ -51,11 +51,8 @@ bool isReductionInitExpr(const Expr* expr) {
51
51
52
52
} // namespace
53
53
54
- void UnrollPass::registerReplace (
55
- Expr* reference,
56
- Expr* new_expr,
57
- kir::Scope* scope) {
58
- kir::ExprMutator::registerReplace (reference, new_expr, scope);
54
+ void UnrollPass::registerReplace (Expr* reference, Expr* new_expr) {
55
+ kir::ExprMutator::registerReplace (reference, new_expr);
59
56
GpuLower::current ()->propagateExprInfo (reference, new_expr);
60
57
}
61
58
@@ -115,7 +112,7 @@ void UnrollPass::handle(Expr* expr) {
115
112
expr_with_predicate = ShiftPredicateInserter::insert (
116
113
expr, for_loops_, thread_pred, unswitched_loop_);
117
114
if (expr_with_predicate != expr) {
118
- registerReplace (expr, expr_with_predicate, &for_loops_. back ()-> body () );
115
+ registerReplace (expr, expr_with_predicate);
119
116
}
120
117
return ;
121
118
}
@@ -137,7 +134,7 @@ void UnrollPass::handle(Expr* expr) {
137
134
: IrBuilder::create<kir::Predicate>(
138
135
PredicateType::Inline, expr, thread_pred);
139
136
expr_with_predicate = expr_with_predicate->withPredicate (pred);
140
- registerReplace (expr, expr_with_predicate, &for_loops_. back ()-> body () );
137
+ registerReplace (expr, expr_with_predicate);
141
138
return ;
142
139
}
143
140
@@ -160,27 +157,21 @@ void UnrollPass::handle(Expr* expr) {
160
157
161
158
if (lower_utils::supportInlinePredicate (expr)) {
162
159
expr_with_predicate = expr_with_predicate->withPredicate (pred);
163
- registerReplace (expr, expr_with_predicate, &for_loops_. back ()-> body () );
160
+ registerReplace (expr, expr_with_predicate);
164
161
return ;
165
162
}
166
163
167
164
// If we need a predicate, put expr inside an if then else
168
165
kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
169
- if (for_loops_.empty ()) {
170
- // Special handling for top level output expressions that still
171
- // need predicates. One motivating example is a reduction op that
172
- // reduces to a scalar (issue #491)
173
- kir::ExprMutator::registerReplace (expr, inline_ite, nullptr );
174
- } else {
175
- kir::ExprMutator::registerReplace (
176
- expr, inline_ite, &for_loops_.back ()->body ());
177
- }
166
+ kir::ExprMutator::registerReplace (expr, inline_ite);
178
167
if (expr != expr_with_predicate) {
179
168
GpuLower::current ()->propagateExprInfo (expr, expr_with_predicate);
180
169
}
181
170
inline_ite->thenBody ().push_back (expr_with_predicate);
182
171
} else if (auto for_loop = dynamic_cast <kir::ForLoop*>(expr)) {
183
172
handle (for_loop);
173
+ } else if (auto ite = dynamic_cast <kir::IfThenElse*>(expr)) {
174
+ kir::ExprMutator::handle (ite);
184
175
}
185
176
}
186
177
@@ -196,6 +187,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
196
187
// normal.
197
188
if (!is_unroll || !look_for_unroll_) {
198
189
for_loops_.push_back (fl);
190
+ scope_.push_back (&fl->body ());
191
+ scope_exprs_.push_back (fl);
199
192
200
193
// Make copy of exprs because we replace them inplace in fl
201
194
const auto exprs_copy = fl->body ().exprs ();
@@ -208,6 +201,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
208
201
}
209
202
210
203
for_loops_.pop_back ();
204
+ scope_.pop_back ();
205
+ scope_exprs_.pop_back ();
211
206
return ;
212
207
}
213
208
@@ -217,38 +212,39 @@ void UnrollPass::handle(kir::ForLoop* fl) {
217
212
218
213
// Get the loop nest for the unrolled path
219
214
kir::ForLoop* unrolled_loop_nest = cloneLoopNest (fl);
215
+ unroll_ite->thenBody ().push_back (unrolled_loop_nest);
220
216
221
217
// Thread predicates are not removed from the expressions. Visit
222
218
// each expression to attach kir::Predicate.
219
+ scope_.push_back (&unroll_ite->thenBody ());
220
+ scope_exprs_.push_back (unroll_ite);
223
221
unswitched_loop_ = true ;
224
222
look_for_unroll_ = false ;
225
223
handle (unrolled_loop_nest);
226
224
unswitched_loop_ = false ;
227
225
look_for_unroll_ = true ;
228
-
229
- unroll_ite-> thenBody (). push_back (unrolled_loop_nest );
226
+ scope_. pop_back ();
227
+ scope_exprs_. pop_back ( );
230
228
231
229
// Loop nest for inlined path
232
230
kir::ForLoop* inlined_loop = cloneLoopNest (fl);
233
231
234
232
// Add inline predicates for inlined loop nest
233
+ scope_.push_back (&unroll_ite->elseBody ());
234
+ scope_exprs_.push_back (unroll_ite);
235
235
look_for_unroll_ = false ;
236
236
non_trivial_pred_found_ = false ;
237
237
handle (inlined_loop);
238
238
look_for_unroll_ = true ;
239
+ scope_.pop_back ();
240
+ scope_exprs_.pop_back ();
239
241
if (!non_trivial_pred_found_) {
240
- kir::ExprMutator::registerReplace (
241
- fl,
242
- inlined_loop,
243
- for_loops_.empty () ? nullptr : &for_loops_.back ()->body ());
242
+ kir::ExprMutator::registerReplace (fl, inlined_loop);
244
243
} else {
245
244
if (!canOmitElseClause (fl)) {
246
245
unroll_ite->elseBody ().push_back (inlined_loop);
247
246
}
248
- kir::ExprMutator::registerReplace (
249
- fl,
250
- unroll_ite,
251
- for_loops_.empty () ? nullptr : &for_loops_.back ()->body ());
247
+ kir::ExprMutator::registerReplace (fl, unroll_ite);
252
248
}
253
249
}
254
250
0 commit comments