From a3565ee05c963df7fc939c8df3d4f6a616970409 Mon Sep 17 00:00:00 2001 From: erweiw Date: Fri, 19 Jul 2024 16:32:53 -0600 Subject: [PATCH 1/2] Fixup a bug caused by 'hasOneUse' of iv not being enforced in affine.apply folding --- mlir/lib/Transform/AIRDependencyScheduleOpt.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 343961a00..084a98bb7 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1685,6 +1685,8 @@ struct CanonicalizeAffineApplyOnLoopInductionVar return failure(); if (!ivArg.getOwner()) return failure(); + if (!val.hasOneUse()) + return failure(); if (apply.getResult().use_empty()) return failure(); if (auto exec_apply = dyn_cast(apply->getParentOp())) From d22dce4cf7035b3501506cee4dfaf7283c14a540 Mon Sep 17 00:00:00 2001 From: erweiw Date: Fri, 19 Jul 2024 17:10:10 -0600 Subject: [PATCH 2/2] Improve dep tracing to take into consideration sync for + async affine.apply --- .../Transform/AIRDependencyScheduleOpt.cpp | 30 +++++++++++++++++-- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 084a98bb7..e27a88f2f 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1715,7 +1715,15 @@ struct CanonicalizeAffineApplyOnLoopInductionVar if (auto exec = dyn_cast(apply->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(apply); @@ -1812,7 +1820,15 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar if (auto exec = dyn_cast(op->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(op); @@ -1901,7 +1917,15 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar if (auto exec = dyn_cast(op->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(op);