Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 23 additions & 27 deletions third_party/nvfuser/csrc/lower_unroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,8 @@ bool isReductionInitExpr(const Expr* expr) {

} // namespace

void UnrollPass::registerReplace(
Expr* reference,
Expr* new_expr,
kir::Scope* scope) {
kir::ExprMutator::registerReplace(reference, new_expr, scope);
void UnrollPass::registerReplace(Expr* reference, Expr* new_expr) {
kir::ExprMutator::registerReplace(reference, new_expr);
GpuLower::current()->propagateExprInfo(reference, new_expr);
}

Expand Down Expand Up @@ -115,7 +112,7 @@ void UnrollPass::handle(Expr* expr) {
expr_with_predicate = ShiftPredicateInserter::insert(
expr, for_loops_, thread_pred, unswitched_loop_);
if (expr_with_predicate != expr) {
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
registerReplace(expr, expr_with_predicate);
}
return;
}
Expand All @@ -137,7 +134,7 @@ void UnrollPass::handle(Expr* expr) {
: IrBuilder::create<kir::Predicate>(
PredicateType::Inline, expr, thread_pred);
expr_with_predicate = expr_with_predicate->withPredicate(pred);
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
registerReplace(expr, expr_with_predicate);
return;
}

Expand All @@ -160,27 +157,21 @@ void UnrollPass::handle(Expr* expr) {

if (lower_utils::supportInlinePredicate(expr)) {
expr_with_predicate = expr_with_predicate->withPredicate(pred);
registerReplace(expr, expr_with_predicate, &for_loops_.back()->body());
registerReplace(expr, expr_with_predicate);
return;
}

// If we need a predicate, put expr inside an if then else
kir::IfThenElse* inline_ite = IrBuilder::create<kir::IfThenElse>(pred);
if (for_loops_.empty()) {
// Special handling for top level output expressions that still
// need predicates. One motivating example is a reduction op that
// reduces to a scalar (issue #491)
kir::ExprMutator::registerReplace(expr, inline_ite, nullptr);
} else {
kir::ExprMutator::registerReplace(
expr, inline_ite, &for_loops_.back()->body());
}
kir::ExprMutator::registerReplace(expr, inline_ite);
if (expr != expr_with_predicate) {
GpuLower::current()->propagateExprInfo(expr, expr_with_predicate);
}
inline_ite->thenBody().push_back(expr_with_predicate);
} else if (auto for_loop = dynamic_cast<kir::ForLoop*>(expr)) {
handle(for_loop);
} else if (auto ite = dynamic_cast<kir::IfThenElse*>(expr)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing you're adding something in the original PR that introduces IfThenElse before the unroll pass. What is it?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is loop rotation (which happens right before the unroll pass):

for i in range(n):
  statement1(i)
  statement2(i)
  statement3(i)
  statement4(i)

transform to

if 0 < n:
  for i = 0:
    statement1(i)
    statement2(i)
for i ...:
  statement3(i)
  statement4(i)
  if i + 1 < n:
    statement1(i)
    statement2(i)

I am actually not materializing these conditions (because I think existing predicates should already cover all illegal access), so I am just generating:

for i = 0:
  statement1(i)
  statement2(i)
for i ...:
  statement3(i)
  statement4(i)
  if true:
    statement1(i)
    statement2(i)

But the if true is still necessary because I use it as a special container to mark which part of the for loop is rotated from the next iteration.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Not about this PR, but another thing we should clean up is to make the pass dependencies more explicit. I suspect there's a pass that assumes there's no kir::IfThenElse in the incoming expr list, but I don't remember that's validated.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Not about this PR, but another thing we should clean up is to make the pass dependencies more explicit. I suspect there's a pass that assumes there's no kir::IfThenElse in the incoming expr list, but I don't remember that's validated.

I think you are referring to the double buffer pass? Fortunately, there is a TORCH_INTERNAL_ASSERT on handle(kir::IfThenElse* ite).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, but it's not much about a specific pass. These assumptions on the lowering pass dependencies should be explicitly represented and enforced.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sound like something like a pass manager? Passes should use a unified data structure and should have some metadata stored in that data structure so that the pass manager can parse and decide the order of passes?

kir::ExprMutator::handle(ite);
}
}

Expand All @@ -196,6 +187,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
// normal.
if (!is_unroll || !look_for_unroll_) {
for_loops_.push_back(fl);
scope_.push_back(&fl->body());
scope_exprs_.push_back(fl);

// Make copy of exprs because we replace them inplace in fl
const auto exprs_copy = fl->body().exprs();
Expand All @@ -208,6 +201,8 @@ void UnrollPass::handle(kir::ForLoop* fl) {
}

for_loops_.pop_back();
scope_.pop_back();
scope_exprs_.pop_back();
return;
}

Expand All @@ -217,38 +212,39 @@ void UnrollPass::handle(kir::ForLoop* fl) {

// Get the loop nest for the unrolled path
kir::ForLoop* unrolled_loop_nest = cloneLoopNest(fl);
unroll_ite->thenBody().push_back(unrolled_loop_nest);

// Thread predicates are not removed from the expressions. Visit
// each expression to attach kir::Predicate.
scope_.push_back(&unroll_ite->thenBody());
scope_exprs_.push_back(unroll_ite);
unswitched_loop_ = true;
look_for_unroll_ = false;
handle(unrolled_loop_nest);
unswitched_loop_ = false;
look_for_unroll_ = true;

unroll_ite->thenBody().push_back(unrolled_loop_nest);
scope_.pop_back();
scope_exprs_.pop_back();

// Loop nest for inlined path
kir::ForLoop* inlined_loop = cloneLoopNest(fl);

// Add inline predicates for inlined loop nest
scope_.push_back(&unroll_ite->elseBody());
scope_exprs_.push_back(unroll_ite);
look_for_unroll_ = false;
non_trivial_pred_found_ = false;
handle(inlined_loop);
look_for_unroll_ = true;
scope_.pop_back();
scope_exprs_.pop_back();
if (!non_trivial_pred_found_) {
kir::ExprMutator::registerReplace(
fl,
inlined_loop,
for_loops_.empty() ? nullptr : &for_loops_.back()->body());
kir::ExprMutator::registerReplace(fl, inlined_loop);
} else {
if (!canOmitElseClause(fl)) {
unroll_ite->elseBody().push_back(inlined_loop);
}
kir::ExprMutator::registerReplace(
fl,
unroll_ite,
for_loops_.empty() ? nullptr : &for_loops_.back()->body());
kir::ExprMutator::registerReplace(fl, unroll_ite);
}
}

Expand Down
4 changes: 2 additions & 2 deletions third_party/nvfuser/csrc/lower_unroll.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
static bool canOmitElseClause(kir::ForLoop* fl);

private:
void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope);
void registerReplace(Expr* reference, Expr* new_expr);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need the other version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it is needed. Running tests to verify.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are no longer needed. Removed.


// Generate the for Expr replacement map
UnrollPass(const std::vector<Expr*>& exprs);
Expand All @@ -68,7 +68,7 @@ class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
return expr_replacement_map_;
}

using OptOutDispatch::handle;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be using kir::ExprMutator::handle;, but just removing it results in a build error for me (using clang):

/raid/tmp/nmaruyama/debug1/third_party/nvfuser/csrc/lower_unroll.h:74:8: error: 'nvfuser::UnrollPass::handle' hides overloaded virtual function [-Werror,-Woverloaded-virtual]
  void handle(Expr* expr) final;
       ^
/raid/tmp/nmaruyama/debug1/third_party/nvfuser/csrc/kernel_ir_dispatch.h:36:16: note: hidden overloaded virtual function 'nvfuser::kir::IrVisitor::handle' declared here: type mismatch at 1st parameter ('nvfuser::kir::IfThenElse *' vs 'nvfuser::Expr *')
  virtual void handle(IfThenElse*) override;

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't make sense to me. But I will not argue with a compiler about which is correct, so I just added it back.

using kir::ExprMutator::handle;

void handle(kir::ForLoop* fl) final;

Expand Down