From edb8e02ddcb003349cce8cac33303d01b79907a7 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 14 Feb 2024 18:51:05 -0800 Subject: [PATCH] Add `air-loop-fusion` pass to fuse scf.for loops in air.segment (#426) * Fixup a missing condition when erasing async events * Perform memalloc hoisting before converting air.dma to air.channel ops; simplify logic for air.dma memory space demotion * Silence warnings * Update unit test to reflect a real gemm scenario * Add a pass which fuses scf.for loops in air.segment, and generates the candidate for loop structure for pingpong buffering --- .../air/Transform/AIRDependencyScheduleOpt.h | 2 + mlir/include/air/Transform/PassDetail.h | 1 + mlir/include/air/Transform/Passes.td | 8 + mlir/lib/Conversion/ConvertToAIRPass.cpp | 130 +++++-------- .../Transform/AIRDependencyScheduleOpt.cpp | 180 ++++++++++++++++++ mlir/lib/Util/Dependency.cpp | 13 +- .../dma_to_channel_nested_for_in_herd.mlir | 164 ++++------------ .../segment_loop_fusion.mlir | 109 +++++++++++ 8 files changed, 395 insertions(+), 212 deletions(-) create mode 100644 mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir diff --git a/mlir/include/air/Transform/AIRDependencyScheduleOpt.h b/mlir/include/air/Transform/AIRDependencyScheduleOpt.h index 2f9162c08..2ab79ffe5 100644 --- a/mlir/include/air/Transform/AIRDependencyScheduleOpt.h +++ b/mlir/include/air/Transform/AIRDependencyScheduleOpt.h @@ -55,6 +55,8 @@ std::unique_ptr createAIRFuseChannels(); std::unique_ptr createAIRIsolateAsyncDmaLoopNests(); +std::unique_ptr createAIRSegmentLoopFusion(); + // Populate patterns for canonicalizing index operations on loop index // variables. At the moment, only affine.apply computations on induction // variables are canonicalized diff --git a/mlir/include/air/Transform/PassDetail.h b/mlir/include/air/Transform/PassDetail.h index 8473a0a06..5f62a7157 100644 --- a/mlir/include/air/Transform/PassDetail.h +++ b/mlir/include/air/Transform/PassDetail.h @@ -66,6 +66,7 @@ namespace air { #define GEN_PASS_DEF_AIRUNROLLCHANNELBYFACTORPATTERN #define GEN_PASS_DEF_AIRUNROLLLOOPFORPIPELININGPATTERN #define GEN_PASS_DEF_AFFINELOOPOPTPASS +#define GEN_PASS_DEF_AIRSEGMENTLOOPFUSION #include "air/Transform/Passes.h.inc" } // namespace air diff --git a/mlir/include/air/Transform/Passes.td b/mlir/include/air/Transform/Passes.td index 7aa6ec070..c726c4cf7 100644 --- a/mlir/include/air/Transform/Passes.td +++ b/mlir/include/air/Transform/Passes.td @@ -1034,6 +1034,14 @@ def AIRIsolateAsyncDmaLoopNests: Pass<"air-isolate-async-dma-loop-nests", "Modul }]; } +def AIRSegmentLoopFusion: Pass<"air-loop-fusion", "func::FuncOp"> { + let summary = "Hoist dma ops into perfectly nested loop"; + let constructor = "xilinx::air::createAIRSegmentLoopFusion()"; + let description = [{ + This pass performs loop fusion within air.segment op's region. + }]; +} + def AIRDependencyScheduleOpt: Pass<"air-dependency-schedule-opt", "ModuleOp"> { let summary = "Optimize scheduling based on air async dependency"; let constructor = "xilinx::air::createAIRDependencyScheduleOptPass()"; diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index 00eb06b9d..90ab0b6e8 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -520,6 +520,11 @@ T cloneScfLoopUsingRemap(OpBuilder builder, IRMapping &remap, T loop_op, } else if (externalGetPut && dyn_cast(child_op)) { // If externalGetPut is not nullptr, then broadcast lowering mode is on replaceAffineIfOpWithChannelOpAndClone(builder, remap, externalGetPut); + } else if (auto dma_op = dyn_cast(child_op)) { + if (child_op.hasAttr("loop-carried-dep")) + builder.clone(child_op, remap); + else + replaceAsyncOpWithWaitAllAndClone(builder, remap, &child_op, false); } else if (getLinalgOpFromExecuteOp(&child_op)) { replaceAsyncOpWithWaitAllAndClone(builder, remap, &child_op, false); } else { @@ -1066,16 +1071,16 @@ class AIRDmaToAIRChannelConversion SmallVector dst_strides = op.getDstStrides(); if (src_offsets.size()) { - if (src_sizes.size() != src_rank) + if (src_sizes.size() != (unsigned)src_rank) return failure(); - if (src_strides.size() != src_rank) + if (src_strides.size() != (unsigned)src_rank) return failure(); } if (dst_offsets.size()) { - if (dst_sizes.size() != dst_rank) + if (dst_sizes.size() != (unsigned)dst_rank) return failure(); - if (dst_strides.size() != dst_rank) + if (dst_strides.size() != (unsigned)dst_rank) return failure(); } @@ -1478,16 +1483,16 @@ class AIRDemoteDmaToAIRHierarchyConversion SmallVector dst_strides = op.getDstStrides(); if (src_offsets.size()) { - if (src_sizes.size() != src_rank) + if (src_sizes.size() != (unsigned)src_rank) return failure(); - if (src_strides.size() != src_rank) + if (src_strides.size() != (unsigned)src_rank) return failure(); } if (dst_offsets.size()) { - if (dst_sizes.size() != dst_rank) + if (dst_sizes.size() != (unsigned)dst_rank) return failure(); - if (dst_strides.size() != dst_rank) + if (dst_strides.size() != (unsigned)dst_rank) return failure(); } @@ -1496,71 +1501,40 @@ class AIRDemoteDmaToAIRHierarchyConversion { OpBuilder::InsertionGuard guard(rewriter); - SetVector backwardSlice; - BackwardSliceOptions bsOptions{ - [&](Operation *o) { return o != hier_op; }}; - getBackwardSlice(op.getOperation(), &backwardSlice, bsOptions); - - for (auto parent = op->getParentOp(); - !isa(parent); - parent = parent->getParentOp()) { - getBackwardSlice(parent, &backwardSlice, bsOptions); - backwardSlice.insert(parent); - } - - for (auto b : backwardSlice) { - if (dyn_cast(b)) { - for (auto &exec_child_op : b->getRegions().front().getOps()) { - getBackwardSlice(&exec_child_op, &backwardSlice, bsOptions); - backwardSlice.insert(&exec_child_op); - } - } - } + SmallVector backwardSlice; + backwardSlice.push_back(op); + if (isa(op->getParentOp())) + backwardSlice.push_back(op->getParentOp()); + for (auto o : backwardSlice) + for (auto oper : o->getOperands()) + if (getConstantIntValue(oper)) + backwardSlice.push_back(oper.getDefiningOp()); for (auto b : backwardSlice) { b->setAttr("hoist", StringAttr::get(ctx, "dep")); } - op->setAttr("hoist", StringAttr::get(op->getContext(), "dep")); op->setAttr("loop-carried-dep", StringAttr::get(op->getContext(), "external")); // Hoist hierarchy op into scf op - Operation *scf_loop = nullptr; - mlir::OpBuilder::InsertPoint - insertionPointAtHierOp; // To keep a record of the insertion point as - // destination for hoisting rewriter.setInsertionPoint(hier_op); - if (herd) { - SmallVector lbs; - SmallVector ubs; - auto size = herd.getSizeOperands(); - for (auto s : size) { - lbs.push_back(0); - ubs.push_back(*mlir::getConstantIntValue(s)); - } - scf::ParallelOp scf_par = - hoistHerdToAsyncParallel(rewriter, loc, ctx, herd, lbs, ubs); - scf_loop = scf_par.getOperation(); - } else if (segment) { - // Since segment doesn't have iteration space, it doesn't hoist a loop - insertionPointAtHierOp = rewriter.saveInsertionPoint(); - } if (herd) { - auto scf_par = dyn_cast(scf_loop); // Get mapping for remapped ssa values entering the hoisted scf.parallel IRMapping remap; - auto herd_size = herd.getSizeOperands(); - remap.map(herd.getSize()[0], herd_size[0]); - remap.map(herd.getSize()[1], herd_size[1]); - remap.map(herd.getIds()[0], scf_par.getInductionVars()[0]); - remap.map(herd.getIds()[1], scf_par.getInductionVars()[1]); + if (auto for_op = dyn_cast(op->getParentOp())) + for (auto init_arg : for_op.getInitArgs()) + remap.map(init_arg, + rewriter + .create( + loc, air::AsyncTokenType::get(op->getContext()), + SmallVector{}) + .getAsyncToken()); int arg_idx = 0; for (auto arg : herd.getKernelArguments()) remap.map(arg, herd.getKernelOperand(arg_idx++)); // Clone ops into hoisted scf.parallel - rewriter.setInsertionPointToStart(scf_par.getBody()); for (Operation &o : herd->getRegions().front().getBlocks().front().getOperations()) { if (isa(o)) @@ -1584,17 +1558,6 @@ class AIRDemoteDmaToAIRHierarchyConversion } else return failure(); - if (scf_loop) { - scf_loop->walk([&](mlir::Operation *o) { - if (o == o->getBlock()->getTerminator()) { - return; - } - if (!o->hasAttr("hoist")) - erased.insert(o); - else - o->removeAttr("hoist"); - }); - } hier_op.walk([&](mlir::Operation *o) { if (o->hasAttr("hoist")) o->removeAttr("hoist"); @@ -2396,13 +2359,13 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, // default order. int max_dim_size = std::max(std::max(offsets.size(), sizes.size()), strides.size()); - if (max_dim_size && offsets.size() < memref.getRank()) { + if (max_dim_size && offsets.size() < (unsigned)memref.getRank()) { for (unsigned i = offsets.size(); i < memref.getRank(); i++) { offsets.insert(offsets.begin(), builder.create( builder.getUnknownLoc(), 0)); } } - if (max_dim_size && sizes.size() < memref.getRank()) { + if (max_dim_size && sizes.size() < (unsigned)memref.getRank()) { for (unsigned i = sizes.size(); i < memref.getRank(); i++) { sizes.insert(sizes.begin(), builder.create( builder.getUnknownLoc(), 1)); @@ -2411,7 +2374,7 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, int memref_size = 1; for (auto size : memref.getShape()) memref_size *= size; - if (max_dim_size && strides.size() < memref.getRank()) { + if (max_dim_size && strides.size() < (unsigned)memref.getRank()) { for (unsigned i = strides.size(); i < memref.getRank(); i++) { strides.insert(strides.begin(), builder.create( @@ -2420,12 +2383,13 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, } // Reduce highest dimensions if more than memref size - while (strides.size() > memref.getRank() && getConstantIntValue(strides[0]) && + while (strides.size() > (unsigned)memref.getRank() && + getConstantIntValue(strides[0]) && *getConstantIntValue(strides[0]) == memref_size) { strides.erase(strides.begin()); } - while (sizes.size() > memref.getRank() && getConstantIntValue(sizes[0]) && - *getConstantIntValue(sizes[0]) == 1) { + while (sizes.size() > (unsigned)memref.getRank() && + getConstantIntValue(sizes[0]) && *getConstantIntValue(sizes[0]) == 1) { sizes.erase(sizes.begin()); } while (offsets.size() > std::min(sizes.size(), strides.size()) && @@ -2755,16 +2719,6 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase { SmallVector funcOps; module.walk([&](func::FuncOp op) { funcOps.push_back(op); }); - // Hoist broadcast pattern - for (auto f : funcOps) { - f.walk([&](affine::AffineIfOp op) { - if (!op->getParentOfType()) { - // Only hoist top-level affine if op with a nest of if ops - HoistingAffineIf(op); - } - }); - } - // Demote memref alloc pattern std::map> hier_to_allocs; for (auto f : funcOps) { @@ -2784,7 +2738,7 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase { alloc->getParentOfType() ? alloc->getParentOfType().getOperation() : alloc.getOperation(); - if (memref_type.getMemorySpaceAsInt() < hierMemorySpace) { + if (memref_type.getMemorySpaceAsInt() < (unsigned)hierMemorySpace) { hier_to_allocs[hier_op].push_back(alloc_op); } }); @@ -2794,6 +2748,16 @@ struct DmaToChannelPass : public air::impl::DmaToChannelBase { (void)AIRDemoteMemrefToAIRHierarchy(pair, builder); } + // Hoist broadcast pattern + for (auto f : funcOps) { + f.walk([&](affine::AffineIfOp op) { + if (!op->getParentOfType()) { + // Only hoist top-level affine if op with a nest of if ops + HoistingAffineIf(op); + } + }); + } + // First pattern to demote dma ops to corresponding air hierarchy ConversionTarget target_0(*context); diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index b8c4b1431..c3a15f5de 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -3338,6 +3338,182 @@ class AIRIsolateAsyncDmaLoopNests } }; +// A pass which performs loop fusion within air.segment op's region. +class AIRSegmentLoopFusion + : public xilinx::air::impl::AIRSegmentLoopFusionBase { + +public: + AIRSegmentLoopFusion() = default; + AIRSegmentLoopFusion(const AIRSegmentLoopFusion &pass) {} + + void getDependentDialects(::mlir::DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnSegment(air::SegmentOp op) { + auto loc = op->getLoc(); + // Get memref.alloc ops. + SmallVector memalloc_execs; + SmallVector memdealloc_execs; + // Map from air.execute op containing alloc to air.execute op containing + // dealloc. + std::map alloc_dealloc_execs; + for (auto execOp : op.getOps()) { + if (auto child_op = execOp.getChildOp()) { + if (isa(child_op)) + alloc_dealloc_execs[execOp] = nullptr; + } + } + for (auto execOp : op.getOps()) { + if (auto child_op = execOp.getChildOp()) { + if (auto dealloc = dyn_cast(child_op)) + if (llvm::any_of(alloc_dealloc_execs, + [&](std::pair pair) { + return dealloc.getMemref() == + pair.first.getResult(1); + })) { + alloc_dealloc_execs[dyn_cast( + dealloc.getMemref().getDefiningOp())] = execOp; + } + } + } + // Get roots to perfectly nested scf.for loops. + SmallVector perfectlyNestedForBands; + auto hasNElements = [](Block *block, unsigned N) { + auto op_ptr = block->begin(); + for (unsigned i = 0; i < N; i++) + op_ptr = std::next(op_ptr); + return op_ptr != block->end() && &*op_ptr == &block->back(); + }; + for (auto forOp : op.getOps()) { + if (hasNElements(forOp.getBody(), 1)) + perfectlyNestedForBands.push_back(forOp); + } + if (perfectlyNestedForBands.empty()) + return; + if (alloc_dealloc_execs.empty()) + return; + + // From the loop bands, get fusable scf.for for loop bands. + SmallVector equalIterationForOps; + equalIterationForOps.push_back(perfectlyNestedForBands[0]); + auto lb = perfectlyNestedForBands[0].getLowerBound(); + auto ub = perfectlyNestedForBands[0].getUpperBound(); + auto step = perfectlyNestedForBands[0].getStep(); + for (unsigned i = 1; i < perfectlyNestedForBands.size(); i++) { + if (perfectlyNestedForBands[i].getLowerBound() == lb && + perfectlyNestedForBands[i].getUpperBound() == ub && + perfectlyNestedForBands[i].getStep() == step) { + equalIterationForOps.push_back(perfectlyNestedForBands[i]); + } + } + if (equalIterationForOps.empty()) + return; + + // Folding memref.alloc / dealloc ops into fused loop. + SmallVector fusableForOps; + OpBuilder builder(equalIterationForOps[0]); + auto new_loop_op_init_arg = + builder + .create( + loc, air::AsyncTokenType::get(builder.getContext()), + SmallVector{}) + .getAsyncToken(); + scf::ForOp new_loop_op = + builder.create(builder.getUnknownLoc(), lb, ub, step, + SmallVector{new_loop_op_init_arg}); + SmallVector erase_keys; + for (auto execOpPair : alloc_dealloc_execs) { + bool canMove = false; + air::ExecuteOp alloc_exec = execOpPair.first; + for (auto token_user : alloc_exec.getAsyncToken().getUsers()) { + if (llvm::any_of(equalIterationForOps, [&](scf::ForOp fusableForOp) { + return fusableForOp == token_user; + })) { + fusableForOps.push_back(dyn_cast(token_user)); + canMove = true; + } + } + if (canMove) { + for (auto user : alloc_exec.getAsyncToken().getUsers()) { + if (auto async_user = dyn_cast(user)) + eraseAsyncDependencyFromAsyncOp(async_user, + alloc_exec.getAsyncToken()); + } + alloc_exec->moveBefore(new_loop_op.getBody(), + new_loop_op.getBody()->getOperations().end()); + } else + erase_keys.push_back(alloc_exec); + } + for (auto e : erase_keys) + alloc_dealloc_execs.erase(e); + + // Loop fusion. + IRMapping remap; + for (auto forOp : fusableForOps) { + remap.map(forOp.getInductionVar(), new_loop_op.getInductionVar()); + for (unsigned i = 0; i < forOp.getRegionIterArgs().size(); i++) + remap.map(forOp.getRegionIterArgs(), new_loop_op.getRegionIterArgs()); + builder.setInsertionPointToEnd(new_loop_op.getBody()); + for (auto &child_op : forOp.getBody()->getOperations()) + if (!child_op.mightHaveTrait()) + builder.clone(child_op, remap); + } + + // Fuse dealloc ops. + for (auto execOpPair : alloc_dealloc_execs) { + air::ExecuteOp dealloc_exec = execOpPair.second; + dealloc_exec->moveBefore(new_loop_op.getBody(), + new_loop_op.getBody()->getOperations().end()); + } + + // Scf.yield op. + builder.setInsertionPointToEnd(new_loop_op.getBody()); + SmallVector yield_dep_list; + for (auto &child_op : new_loop_op.getBody()->getOperations()) { + if (!child_op.getResults().empty()) { + if (isa(child_op.getResult(0).getType()) && + child_op.getResult(0).getUsers().empty()) { + yield_dep_list.push_back(child_op.getResult(0)); + } + } + } + auto wa_op = builder.create( + loc, air::AsyncTokenType::get(builder.getContext()), yield_dep_list); + builder.create(loc, wa_op.getAsyncToken()); + + // Erase original scf.for ops. + for (auto forOp : fusableForOps) { + assert(forOp.getNumResults() == new_loop_op.getNumResults() && + "Fused loop has different number of results as original"); + for (unsigned i = 0; i < forOp.getNumResults(); i++) { + forOp.getResult(i).replaceAllUsesWith(new_loop_op.getResult(i)); + } + forOp->erase(); + } + + new_loop_op.walk([&](air::ChannelPutOp putOp) { + air::ChannelGetOp getOp = nullptr; + for (auto user : putOp.getMemref().getUsers()) + if (auto get_user = dyn_cast(user)) + getOp = get_user; + scf::ForOp put_parent = putOp->getParentOfType(); + put_parent->setOperand(3, getOp.getAsyncToken()); + }); + } + + void runOnOperation() override { + auto func = getOperation(); + SmallVector segs; + func.walk([&](air::SegmentOp op) { segs.push_back(op); }); + for (auto seg : segs) { + runOnSegment(seg); + } + } + +private: +}; + } // namespace namespace xilinx { @@ -3419,6 +3595,10 @@ std::unique_ptr createAIRIsolateAsyncDmaLoopNests() { return std::make_unique(); } +std::unique_ptr createAIRSegmentLoopFusion() { + return std::make_unique(); +} + void populateAIRLoopIndexCanonicalizationPatterns(RewritePatternSet &patterns) { MLIRContext *ctx = patterns.getContext(); patterns.insert(ctx); diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 34d10db2b..cddc198cb 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -1573,11 +1573,14 @@ void dependencyCanonicalizer::removeUnusedExecuteOp(func::FuncOp func) { }); for (auto op : erased_ops) { - for (auto user : op.getAsyncToken().getUsers()) { - if (auto async_user = dyn_cast(user)) { - eraseAsyncDependencyFromAsyncOp(async_user, op.getAsyncToken()); - } - } + OpBuilder builder(op); + auto new_token = + builder + .create( + op->getLoc(), air::AsyncTokenType::get(builder.getContext()), + op.getAsyncDependencies()) + .getAsyncToken(); + op.getAsyncToken().replaceAllUsesWith(new_token); if (!op.getAsyncToken().use_empty()) op->emitOpError("returned async token still has uses"); op->erase(); diff --git a/mlir/test/Conversion/ConvertToAIR/dma_to_channel_nested_for_in_herd.mlir b/mlir/test/Conversion/ConvertToAIR/dma_to_channel_nested_for_in_herd.mlir index 94968e477..9b73c8fe6 100644 --- a/mlir/test/Conversion/ConvertToAIR/dma_to_channel_nested_for_in_herd.mlir +++ b/mlir/test/Conversion/ConvertToAIR/dma_to_channel_nested_for_in_herd.mlir @@ -152,134 +152,51 @@ module { return %results_2 : memref<64x64xi32> } -// CHECK-LABEL: func.func @legalize_memspace_sync -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0 to %c512 step %c64 -// CHECK: scf.for{{.*}}%c0 to %c512 step %c64 -// CHECK: scf.for{{.*}}%c0 to %c64 step %c32 -// CHECK: scf.for{{.*}}%c0 to %c64 step %c32 -// CHECK: scf.for{{.*}}%c0 to %c1024 step %c128 -// CHECK: air.channel.put @channel_8 -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK-NEXT: } -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0 to %c512 step %c64 -// CHECK: scf.for{{.*}}%c0 to %c512 step %c64 -// CHECK: air.channel.get @channel_9 -// CHECK: } -// CHECK-NEXT: } +// CHECK-LABEL: func.func @gemm_k_dim_tiling +// CHECK: air.launch +// CHECK: scf.for{{.*}} +// CHECK: air.channel.put{{.*}}@channel_8 // CHECK: air.segment @segment_0 -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: scf.for{{.*}}%c0_6 to %c1024_5 step %c128_4 -// CHECK: air.channel.get @channel_8 -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK-NEXT: } -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: air.channel.put @channel_9 -// CHECK: } -// CHECK-NEXT: } -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: scf.for{{.*}}%c0_6 to %c1024_5 step %c128_4 -// CHECK: scf.for{{.*}}%c0_6 to %c128_4 step %c32_3 -// CHECK: air.channel.put @channel_10 -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK-NEXT: } -// CHECK: scf.parallel -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c512_2 step %c64_1 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: scf.for{{.*}}%c0_6 to %c64_1 step %c32_3 -// CHECK: air.channel.get @channel_11 -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK-NEXT: } +// CHECK: scf.for{{.*}} +// CHECK: air.channel.get{{.*}}@channel_8 // CHECK: air.herd @herd_0 -// CHECK: scf.for{{.*}}%c0_15 to %c512_14 step %c64_16 -// CHECK: scf.for{{.*}}%c0_15 to %c512_14 step %c64_16 -// CHECK: scf.for{{.*}}%c0_15 to %c64_16 step %c32_13 -// CHECK: scf.for{{.*}}%c0_15 to %c64_16 step %c32_13 -// CHECK: scf.for{{.*}}%c0_15 to %c1024_11 step %c128_12 -// CHECK: scf.for{{.*}}%c0_15 to %c128_12 step %c32_13 -// CHECK: air.channel.get @channel_10 -// CHECK: } -// CHECK: } -// CHECK: air.channel.put @channel_11 -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: air.herd_terminator -// CHECK: air.segment_terminator -// CHECK: air.launch_terminator - func.func @legalize_memspace_sync(%arg0: memref<1024x1024xi32>, %arg1: memref<1024x1024xi32>, %arg2: memref<1024x1024xi32>) { - %c2 = arith.constant 2 : index - %alloc = memref.alloc() : memref<1024x1024xi32> - %c1 = arith.constant 1 : index - %c1_0 = arith.constant 1 : index - air.launch (%arg3, %arg4) in (%arg5=%c1, %arg6=%c1_0) args(%arg7=%arg0, %arg8=%arg1, %arg9=%alloc) : memref<1024x1024xi32>, memref<1024x1024xi32>, memref<1024x1024xi32> { - air.segment @segment_0 args(%arg10=%arg3, %arg11=%arg4, %arg12=%arg5, %arg13=%arg6, %arg14=%arg7, %arg15=%arg8, %arg16=%arg9) : index, index, index, index, memref<1024x1024xi32>, memref<1024x1024xi32>, memref<1024x1024xi32> { - %c2_1 = arith.constant 2 : index - %c2_2 = arith.constant 2 : index - air.herd @herd_0 tile (%arg17, %arg18) in (%arg19=%c2_1, %arg20=%c2_2) args(%arg21=%arg10, %arg22=%arg11, %arg23=%arg12, %arg24=%arg13, %arg25=%arg14, %arg26=%arg15, %arg27=%arg16) : index, index, index, index, memref<1024x1024xi32>, memref<1024x1024xi32>, memref<1024x1024xi32> { - %c1_3 = arith.constant 1 : index - %c1024 = arith.constant 1024 : index - %c128 = arith.constant 128 : index - %c32 = arith.constant 32 : index - %c512 = arith.constant 512 : index - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index + func.func @gemm_k_dim_tiling(%arg0: memref<2048x2048xi32>) { + %c32 = arith.constant 32 : index + %0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) args(%arg8=%arg0) : memref<2048x2048xi32> { + %1 = air.segment @segment_0 async args(%arg10=%arg3, %arg11=%arg4, %arg13=%arg8) : index, index, memref<2048x2048xi32> { + %c1 = arith.constant 1 : index + %c2048 = arith.constant 2048 : index + %c64 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %async_token_1, %results_2 = air.execute -> (index) { + %4 = affine.apply #map()[%arg10] + air.execute_terminator %4 : index + } + %async_token_3, %results_4 = air.execute -> (index) { + %4 = affine.apply #map()[%arg11] + air.execute_terminator %4 : index + } + %2 = air.herd @herd_0 async tile (%arg15, %arg16) in (%arg17=%c2, %arg18=%c2) args(%arg20=%arg13, %arg21=%results_2, %arg23=%results_4) : memref<2048x2048xi32>, index, index attributes {id = 1 : i32} { + %c1_8 = arith.constant 1 : index + %c64_9 = arith.constant 64 : index %c0_i32 = arith.constant 0 : i32 - %0 = affine.apply #map()[%arg17] - %1 = affine.apply #map()[%arg18] - scf.for %arg28 = %c0 to %c512 step %c64 { - scf.for %arg29 = %c0 to %c512 step %c64 { - %2 = arith.addi %0, %arg28 : index - %3 = arith.addi %1, %arg29 : index - %alloc_4 = memref.alloc() : memref<64x64xi32, 1> - scf.for %arg30 = %c0 to %c64 step %c32 { - scf.for %arg31 = %c0 to %c64 step %c32 { - %4 = arith.addi %2, %arg30 : index - %5 = arith.addi %3, %arg31 : index - %alloc_5 = memref.alloc() : memref<32x32xi32, 2> - scf.for %arg32 = %c0 to %c1024 step %c128 { - %alloc_6 = memref.alloc() : memref<32x128xi32, 1> - air.dma_memcpy_nd (%alloc_6[] [] [], %arg25[%4, %arg32] [%c32, %c128] [%c1024, %c1_3]) {id = 1 : i32} : (memref<32x128xi32, 1>, memref<1024x1024xi32>) - scf.for %arg33 = %c0 to %c128 step %c32 { - %alloc_8 = memref.alloc() : memref<32x32xi32, 2> - air.dma_memcpy_nd (%alloc_8[] [] [], %alloc_6[%c0, %arg33] [%c32, %c32] [%c128, %c1_3]) {id = 3 : i32} : (memref<32x32xi32, 2>, memref<32x128xi32, 1>) - memref.dealloc %alloc_8 : memref<32x32xi32, 2> - } - memref.dealloc %alloc_6 : memref<32x128xi32, 1> - } - air.dma_memcpy_nd (%alloc_4[%arg30, %arg31] [%c32, %c32] [%c64, %c1_3], %alloc_5[] [] []) {id = 5 : i32} : (memref<64x64xi32, 1>, memref<32x32xi32, 2>) - memref.dealloc %alloc_5 : memref<32x32xi32, 2> - } - } - air.dma_memcpy_nd (%arg27[%2, %3] [%c64, %c64] [%c1024, %c1_3], %alloc_4[] [] []) {id = 6 : i32} : (memref<1024x1024xi32>, memref<64x64xi32, 1>) - memref.dealloc %alloc_4 : memref<64x64xi32, 1> + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c32_10 = arith.constant 32 : index + %c2048_11 = arith.constant 2048 : index + %4 = air.wait_all async + %5 = scf.for %arg24 = %c0 to %c2048_11 step %c256 iter_args(%arg25 = %4) -> (!air.async.token) { + %async_token_20, %results_21 = air.execute -> (memref<64x256xi32, 1>) { + %alloc = memref.alloc() : memref<64x256xi32, 1> + air.execute_terminator %alloc : memref<64x256xi32, 1> + } + %7 = air.dma_memcpy_nd async [%arg25, %async_token_20] (%results_21[] [] [], %arg20[%arg21, %arg24] [%c64_9, %c256] [%c2048_11, %c1_8]) {id = 1 : i32} : (memref<64x256xi32, 1>, memref<2048x2048xi32>) + %async_token_24 = air.execute [%7] { + memref.dealloc %results_21 : memref<64x256xi32, 1> } + %11 = air.wait_all async [%7] + scf.yield %11 : !air.async.token } air.herd_terminator } @@ -287,7 +204,6 @@ module { } air.launch_terminator } - memref.copy %alloc, %arg2 : memref<1024x1024xi32> to memref<1024x1024xi32> return } } diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir new file mode 100644 index 000000000..c59f579f6 --- /dev/null +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir @@ -0,0 +1,109 @@ +//===- segment_loop_fusion.mlir --------------------------------*- MLIR -*-===// +// +// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT +// +//===----------------------------------------------------------------------===// + +// RUN: air-opt -air-loop-fusion %s | FileCheck %s + +// Fuse scf.for loops in air.segment. + +// CHECK-LABEL: func.func @test +// CHECK: air.launch +// CHECK: air.segment +// CHECK: memref.alloc() +// CHECK: scf.for +// CHECK: memref.alloc() +// CHECK: memref.alloc() +// CHECK: %[[EVENT0:.*]] = air.channel.get{{.*}}@channel_4 +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) +// CHECK: %[[EVENT1:.*]] = air.channel.put{{.*}}@channel_1 +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) +// CHECK: %[[EVENT2:.*]] = air.channel.put{{.*}}@channel_0 +// CHECK: %[[EVENT3:.*]] = air.channel.get{{.*}}@channel_5 +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT3]]) +// CHECK: %[[EVENT4:.*]] = air.channel.put{{.*}}@channel_3 +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT3]]) +// CHECK: %[[EVENT5:.*]] = air.channel.put{{.*}}@channel_2 + +#map = affine_map<()[s0] -> (s0 * 64)> +#map1 = affine_map<()[s0] -> (s0 * 32)> +#set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)> +#set1 = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)> +func.func @test() { + %c32 = arith.constant 32 : index + %0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) { + %6 = air.segment @segment_0 async { + %c32_9 = arith.constant 32 : index + %c256_10 = arith.constant 256 : index + %c0_11 = arith.constant 0 : index + %c1_12 = arith.constant 1 : index + %c2048_13 = arith.constant 2048 : index + %c64_14 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %7 = air.wait_all async + %8 = air.wait_all async + %async_token_15, %results_16 = air.execute -> (memref<64x64xi32, 1>) { + %alloc = memref.alloc() : memref<64x64xi32, 1> + air.execute_terminator %alloc : memref<64x64xi32, 1> + } + %async_token_17, %results_18 = air.execute [%async_token_15] -> (memref<64x256xi32, 1>) { + %alloc = memref.alloc() : memref<64x256xi32, 1> + air.execute_terminator %alloc : memref<64x256xi32, 1> + } + %async_token_19, %results_20 = air.execute [%async_token_17] -> (memref<256x64xi32, 1>) { + %alloc = memref.alloc() : memref<256x64xi32, 1> + air.execute_terminator %alloc : memref<256x64xi32, 1> + } + %9 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_0[] (%results_18[%c0_11, %arg12] [%c32_9, %c32_9] [%c256_10, %c1_12]) : (memref<64x256xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %10 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_1[] (%results_18[%c32_9, %arg12] [%c32_9, %c32_9] [%c256_10, %c1_12]) : (memref<64x256xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %11 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_19) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_2[] (%results_20[%arg12, %c0_11] [%c32_9, %c32_9] [%c64_14, %c1_12]) : (memref<256x64xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %12 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_19) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_3[] (%results_20[%arg12, %c32_9] [%c32_9, %c32_9] [%c64_14, %c1_12]) : (memref<256x64xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %13 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = air.channel.get async [%arg11, %7] @channel_4[] (%results_18[] [] []) : (memref<64x256xi32, 1>) + scf.yield %18 : !air.async.token + } + %14 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_19) -> (!air.async.token) { + %18 = air.channel.get async [%arg11, %8] @channel_5[] (%results_20[] [] []) : (memref<256x64xi32, 1>) + scf.yield %18 : !air.async.token + } + %async_token_21 = air.execute [%async_token_19] { + memref.dealloc %results_20 : memref<256x64xi32, 1> + } + %async_token_22 = air.execute [%async_token_17] { + memref.dealloc %results_18 : memref<64x256xi32, 1> + } + %async_token_23 = air.execute [%7, %8] { + memref.dealloc %results_16 : memref<64x64xi32, 1> + } + air.segment_terminator + } + air.launch_terminator + } + return +}