diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 343961a00..e27a88f2f 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -1685,6 +1685,8 @@ struct CanonicalizeAffineApplyOnLoopInductionVar return failure(); if (!ivArg.getOwner()) return failure(); + if (!val.hasOneUse()) + return failure(); if (apply.getResult().use_empty()) return failure(); if (auto exec_apply = dyn_cast(apply->getParentOp())) @@ -1713,7 +1715,15 @@ struct CanonicalizeAffineApplyOnLoopInductionVar if (auto exec = dyn_cast(apply->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(apply); @@ -1810,7 +1820,15 @@ struct CanonicalizeArithMuliOpOnLoopInductionVar if (auto exec = dyn_cast(op->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(op); @@ -1899,7 +1917,15 @@ struct CanonicalizeArithAddiOpOnLoopInductionVar if (auto exec = dyn_cast(op->getParentOp())) { rewriter.setInsertionPoint(exec); exec.getResult(1).replaceAllUsesWith(sfo.getInductionVar()); - exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + if (sfo.getNumRegionIterArgs()) + exec.getAsyncToken().replaceAllUsesWith(sfo.getRegionIterArgs()[0]); + else if (exec.getAsyncDependencies().size() == 1) + exec.getAsyncToken().replaceAllUsesWith( + exec.getAsyncDependencies()[0]); + else { + exec->emitOpError("failed to reconstruct dependency after its erase"); + return failure(); + } rewriter.eraseOp(exec); } else { rewriter.setInsertionPoint(op);