From 3c1c2564360dad715b6809eb7e4bf3d22eb34dd9 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Fri, 23 Feb 2024 17:21:11 -0800 Subject: [PATCH] Memref shrinkage by analyzing the air.channel data access pattern (#451) * After -air-loop-fusion, check for redundant memref allocation by analyzing the new data access pattern * Update test after wrap-and-stride canonicalizer becomes more conservative * Memref shrinkage unit test --- mlir/lib/Conversion/ConvertToAIRPass.cpp | 20 +- .../Transform/AIRDependencyScheduleOpt.cpp | 302 +++++++++++++++++- mlir/lib/Util/Util.cpp | 17 +- .../condense_memref_ops_to_air_memcpy.mlir | 5 +- .../segment_loop_fusion.mlir | 78 ++++- 5 files changed, 390 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/ConvertToAIRPass.cpp b/mlir/lib/Conversion/ConvertToAIRPass.cpp index d6174870a..9673c63ee 100644 --- a/mlir/lib/Conversion/ConvertToAIRPass.cpp +++ b/mlir/lib/Conversion/ConvertToAIRPass.cpp @@ -2433,14 +2433,15 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, // default order. int max_dim_size = std::max(std::max(offsets.size(), sizes.size()), strides.size()); - if (max_dim_size && offsets.size() < (unsigned)memref.getRank()) { - for (unsigned i = offsets.size(); i < memref.getRank(); i++) { + int target_dim_size = std::max(max_dim_size, (int)memref.getRank()); + if (max_dim_size && offsets.size() < target_dim_size) { + for (unsigned i = offsets.size(); i < target_dim_size; i++) { offsets.insert(offsets.begin(), builder.create( builder.getUnknownLoc(), 0)); } } - if (max_dim_size && sizes.size() < (unsigned)memref.getRank()) { - for (unsigned i = sizes.size(); i < memref.getRank(); i++) { + if (max_dim_size && sizes.size() < target_dim_size) { + for (unsigned i = sizes.size(); i < target_dim_size; i++) { sizes.insert(sizes.begin(), builder.create( builder.getUnknownLoc(), 1)); } @@ -2448,8 +2449,8 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, int memref_size = 1; for (auto size : memref.getShape()) memref_size *= size; - if (max_dim_size && strides.size() < (unsigned)memref.getRank()) { - for (unsigned i = strides.size(); i < memref.getRank(); i++) { + if (max_dim_size && strides.size() < target_dim_size) { + for (unsigned i = strides.size(); i < target_dim_size; i++) { strides.insert(strides.begin(), builder.create( builder.getUnknownLoc(), memref_size)); @@ -2457,13 +2458,12 @@ static LogicalResult canonicalizeAIRDmaOperands(OpBuilder builder, } // Reduce highest dimensions if more than memref size - while (strides.size() > (unsigned)memref.getRank() && - getConstantIntValue(strides[0]) && + while (strides.size() > target_dim_size && getConstantIntValue(strides[0]) && *getConstantIntValue(strides[0]) == memref_size) { strides.erase(strides.begin()); } - while (sizes.size() > (unsigned)memref.getRank() && - getConstantIntValue(sizes[0]) && *getConstantIntValue(sizes[0]) == 1) { + while (sizes.size() > target_dim_size && getConstantIntValue(sizes[0]) && + *getConstantIntValue(sizes[0]) == 1) { sizes.erase(sizes.begin()); } while (offsets.size() > std::min(sizes.size(), strides.size()) && diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index 4d2d2ba70..c8540bc14 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -3348,6 +3348,184 @@ class AIRIsolateAsyncDmaLoopNests } }; +// Get the memref size along a given dimension, that the access pattern actually +// covers. 
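+// For example, with memref_shape = [64, 2048], sizes = [64, 256] and
+// strides = [2048, 1] (the channel.get access pattern in the
+// segment_loop_fusion unit test below), the returned bounds are [64, 256]:
+// only 256 of the 2048 columns are ever touched, so the memref may be
+// shrunk along that dimension.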
+SmallVector +getEffectiveMemrefSizeFromAccessPattern(SmallVector memref_shape, + SmallVector sizes, + SmallVector strides) { + SmallVector access_bounds(memref_shape.size(), -1); + for (int i = sizes.size() - 1; i >= 0; i--) { + int current_memref_volumn = 1; + for (int j = memref_shape.size() - 1; j >= 0; j--) { + current_memref_volumn *= memref_shape[j]; + if (mlir::floorDiv(*getConstantIntValue(strides[i]), + current_memref_volumn)) + continue; + int64_t bound = mlir::floorDiv(*getConstantIntValue(strides[i]), + current_memref_volumn / memref_shape[j]) * + *getConstantIntValue(sizes[i]); + access_bounds[j] = std::max(access_bounds[j], bound); + } + } + return access_bounds; +} + +// A pattern which attempts to shrink the memref sizes, based on the access +// patterns of all its uses. +struct ShrinkMemrefSizesByAccessPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocOp alloc, + PatternRewriter &rewriter) const override { + + // Get memref. + Value memref = alloc.getMemref(); + if (auto exec = dyn_cast(alloc->getParentOp())) + memref = exec->getResult(1); + + if (alloc->hasAttr("test")) + return failure(); + + // Get dealloc. + memref::DeallocOp dealloc; + // Get channel op users. + SmallVector gets; + SmallVector puts; + SmallVector chanOps; + for (auto user : memref.getUsers()) { + if (auto da = dyn_cast(user)) + dealloc = da; + else if (auto chanOp = dyn_cast(user)) + chanOps.push_back(chanOp); + else + return failure(); // NYI. + } + + // Analyze data access pattern. + auto memref_shape = getTensorShape(memref.getType()); + SmallVector overall_access_bounds(memref_shape.size(), -1); + for (auto chanOp : chanOps) { + SmallVector access_bounds(memref_shape.size(), -1); + if (chanOp.getOffsets().empty()) + for (unsigned i = 0; i < memref_shape.size(); i++) + access_bounds[i] = memref_shape[i]; + // for (auto oper : chanOp->getOperands()){ + bool forIterationAccess = false; + for (unsigned i = 0; i < chanOp.getOffsets().size(); i++) { + if (auto forOp = scf::getForInductionVarOwner(chanOp.getOffsets()[i])) { + if (forOp == chanOp->getParentOp() && + getStaticScfForTripCountAsInt(forOp)) { + access_bounds = getEffectiveMemrefSizeFromAccessPattern( + memref_shape, chanOp.getSizes(), chanOp.getStrides()); + forIterationAccess = true; + } + } + } + if (!forIterationAccess && + memref_shape.size() == chanOp.getSizes().size()) { + for (unsigned i = 0; i < memref_shape.size(); i++) { + access_bounds[i] = *getConstantIntValue(chanOp.getSizes()[i]); + } + } + // Update overall access bounds. + for (unsigned i = 0; i < memref_shape.size(); i++) + overall_access_bounds[i] = + std::max(overall_access_bounds[i], access_bounds[i]); + } + + bool shrinkMemref = false; + for (unsigned i = 0; i < memref_shape.size(); i++) { + if (overall_access_bounds[i] < 0) + return failure(); + if (overall_access_bounds[i] < memref_shape[i]) { + shrinkMemref = true; + } + } + if (shrinkMemref) { + // Start shrinking memref. 
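+      // For example, shrinking memref<64x2048xi32, 1> to
+      // memref<64x256xi32, 1> rewrites each user's row stride from 2048 to
+      // 256 (see getUpdatedStridesAfterShrinkage below); the offset and
+      // size operands are left unchanged.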
+ for (auto chanOp : chanOps) { + rewriter.setInsertionPoint(chanOp); + auto new_strides = getUpdatedStridesAfterShrinkage( + memref_shape, overall_access_bounds, chanOp.getStrides()); + int strideListIdxOffset = + dyn_cast(chanOp.getOperation()) + .getAsyncDependencies() + .size() + + 1 + chanOp.getOffsets().size() + chanOp.getSizes().size(); + for (unsigned i = strideListIdxOffset; + i < strideListIdxOffset + chanOp.getStrides().size(); i++) { + chanOp->getOpOperand(i).assign( + rewriter.create( + chanOp->getLoc(), new_strides[i - strideListIdxOffset])); + } + } + + // Replace memref alloc op; + Type elemType = memref.getType().cast().getElementType(); + Attribute memorySpace = + memref.getType().cast().getMemorySpace(); + auto newMemrefType = MemRefType::get(overall_access_bounds, elemType, + nullptr, memorySpace); + if (auto execOp = dyn_cast(alloc->getParentOp())) { + rewriter.setInsertionPoint(execOp); + auto newExecOp = rewriter.create( + execOp->getLoc(), air::AsyncTokenType::get(rewriter.getContext()), + newMemrefType, execOp.getAsyncDependencies()); + Block *async_exec_bb = rewriter.createBlock(&newExecOp.getBody()); + rewriter.setInsertionPointToStart(async_exec_bb); + auto newAlloc = + rewriter.create(alloc->getLoc(), newMemrefType); + rewriter.create(rewriter.getUnknownLoc(), + newAlloc.getResult()); + for (unsigned i = 0; i < execOp->getNumResults(); i++) + execOp->getResult(i).replaceAllUsesWith(newExecOp->getResult(i)); + rewriter.eraseOp(execOp); + + } else { + rewriter.setInsertionPoint(alloc); + auto newAlloc = + rewriter.create(alloc->getLoc(), newMemrefType); + newAlloc->setAttr("test", rewriter.getBoolAttr(true)); + alloc.getResult().replaceAllUsesWith(newAlloc.getResult()); + rewriter.eraseOp(alloc); + } + return success(); + } + + return failure(); + } + +private: + // Update strides after memref shrinkage. Assuming there is only dimension + // being shrunk. + SmallVector + getUpdatedStridesAfterShrinkage(SmallVector old_memref_shape, + SmallVector new_memref_shape, + SmallVector strides) const { + SmallVector new_strides(strides.size(), -1); + int shrinkage_volumn = 1; + int shrinkage_factor = 1; + for (int j = old_memref_shape.size() - 1; j >= 0; j--) { + shrinkage_volumn *= old_memref_shape[j]; + if (old_memref_shape[j] != new_memref_shape[j]) { + shrinkage_factor = + mlir::ceilDiv(old_memref_shape[j], new_memref_shape[j]); + break; + } + } + for (int i = strides.size() - 1; i >= 0; i--) { + if (mlir::floorDiv(*getConstantIntValue(strides[i]), shrinkage_volumn)) + new_strides[i] = + mlir::ceilDiv(*getConstantIntValue(strides[i]), shrinkage_factor); + else + new_strides[i] = *getConstantIntValue(strides[i]); + } + return new_strides; + } +}; + // A pass which performs loop fusion within air.segment op's region. 
class AIRSegmentLoopFusion : public xilinx::air::impl::AIRSegmentLoopFusionBase { @@ -3396,7 +3574,8 @@ class AIRSegmentLoopFusion return op_ptr != block->end() && &*op_ptr == &block->back(); }; for (auto forOp : op.getOps()) { - if (hasNElements(forOp.getBody(), 1)) + if (hasNElements(forOp.getBody(), 1) && + air::getStaticScfForTripCountAsInt(forOp)) perfectlyNestedForBands.push_back(forOp); } if (perfectlyNestedForBands.empty()) @@ -3411,10 +3590,22 @@ class AIRSegmentLoopFusion auto ub = perfectlyNestedForBands[0].getUpperBound(); auto step = perfectlyNestedForBands[0].getStep(); for (unsigned i = 1; i < perfectlyNestedForBands.size(); i++) { + int band_step_as_int = + *mlir::getConstantIntValue(perfectlyNestedForBands[i].getStep()); + int step_as_int = *mlir::getConstantIntValue(step); if (perfectlyNestedForBands[i].getLowerBound() == lb && perfectlyNestedForBands[i].getUpperBound() == ub && perfectlyNestedForBands[i].getStep() == step) { equalIterationForOps.push_back(perfectlyNestedForBands[i]); + } else if (perfectlyNestedForBands[i].getLowerBound() == lb && + perfectlyNestedForBands[i].getUpperBound() == ub && + mlir::mod(std::max(band_step_as_int, step_as_int), + std::min(band_step_as_int, step_as_int)) == 0) { + // If scf.for loops are not identical, but tilable to having identical + // roots. + if (simpleScfForLoopTiling(perfectlyNestedForBands[i], step_as_int, + band_step_as_int)) + equalIterationForOps.push_back(perfectlyNestedForBands[i]); } } if (equalIterationForOps.empty()) @@ -3502,26 +3693,133 @@ class AIRSegmentLoopFusion forOp->erase(); } + std::vector put_parents; + // Map from channel.put's scf.for op parents to dependent channel.get. + std::map put_get_mapping; new_loop_op.walk([&](air::ChannelPutOp putOp) { air::ChannelGetOp getOp = nullptr; for (auto user : putOp.getMemref().getUsers()) if (auto get_user = dyn_cast(user)) getOp = get_user; scf::ForOp put_parent = putOp->getParentOfType(); - put_parent->setOperand(3, getOp.getAsyncToken()); + put_get_mapping[put_parent] = getOp; + put_parents.push_back(put_parent); }); + for (auto put_parent : put_parents) { + auto getOp = put_get_mapping[put_parent]; + put_parent->moveAfter(getOp); + put_parent->setOperand(put_parent.getNumControlOperands(), + getOp.getAsyncToken()); + } + } + + void runPreProcPatterns(func::FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.insert(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); + } + + void runPostProcPatterns(func::FuncOp funcOp) { + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.insert(ctx); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } void runOnOperation() override { auto func = getOperation(); + runPreProcPatterns(func); SmallVector segs; func.walk([&](air::SegmentOp op) { segs.push_back(op); }); for (auto seg : segs) { runOnSegment(seg); } + runPostProcPatterns(func); } private: + // Scf.for loop tiling. This simple tiling implementation generates a new + // inner scf.for loop which starts from the original loop's lower bound. It + // may change the meaning of the original scf.for loop, therefore it requires + // a separate check to make sure that it is legal to tile this way. + scf::ForOp simpleScfForLoopTiling(scf::ForOp forOp, int original_step, + int tiled_step) const { + // Check if it is legal to tile the for loop this way, by checking the + // access pattern of all memrefs within the loop. 
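+    // For instance (illustrative numbers): if the reference band steps by
+    // 256, this loop steps by 32, and its channel op covers an effective
+    // size of 32 along the induction-variable dimension, then
+    // 32 * ceilDiv(256, 32) = 256 <= 256, so the loop can be retiled into
+    // an outer step-256 loop wrapping an inner step-32 loop.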
+ SmallVector channel_ops; + forOp.walk([&](air::ChannelInterface op) { channel_ops.push_back(op); }); + if (channel_ops.size() != 1) + return scf::ForOp(); // Expected to only have one channel op in loop body. + auto offsets = channel_ops[0].getOffsets(); + auto sizes = channel_ops[0].getSizes(); + auto strides = channel_ops[0].getStrides(); + int induction_var_dim = -1; + // Find memref type dimension which the for loop iterates on. + auto memref_shape = getTensorShape(channel_ops[0].getMemref().getType()); + for (unsigned i = 0; i < offsets.size(); i++) { + if (scf::getForInductionVarOwner(offsets[i]) == forOp) + induction_var_dim = i; + } + if (induction_var_dim == -1 || + induction_var_dim < offsets.size() - memref_shape.size()) + return scf::ForOp(); // NYI. + if (offsets.size() > memref_shape.size()) + induction_var_dim -= offsets.size() - memref_shape.size(); + int access_volumn = 1; + for (auto v : sizes) + access_volumn *= *getConstantIntValue(v); + if (offsets.empty() || + access_volumn == getTensorVolume(channel_ops[0].getMemref().getType())) + return scf::ForOp(); // May access the whole memref. + + int effective_access_size = getEffectiveMemrefSizeFromAccessPattern( + memref_shape, sizes, strides)[induction_var_dim]; + effective_access_size *= mlir::ceilDiv(original_step, tiled_step); + if (effective_access_size > original_step) + return scf::ForOp(); // Loop iteration access out of bound. + + // Tiling. + auto loc = forOp->getLoc(); + OpBuilder builder(forOp); + forOp.getStepMutable().assign( + builder.create(loc, original_step)); + builder.setInsertionPointToStart(forOp.getBody()); + auto new_for_op = builder.create( + loc, builder.create(loc, 0), + builder.create(loc, original_step), + builder.create(loc, tiled_step), + forOp.getRegionIterArgs()); + builder.setInsertionPointToStart(new_for_op.getBody()); + IRMapping remap; + for (unsigned j = 0; j < forOp.getNumRegionIterArgs(); j++) + remap.map(forOp.getRegionIterArgs()[j], + new_for_op.getRegionIterArgs()[j]); + remap.map(forOp.getInductionVar(), new_for_op.getInductionVar()); + SmallVector erased; + Value yielded_token = nullptr; + for (auto &o : forOp.getOps()) { + if (&o != new_for_op && &o != forOp.getBody()->getTerminator()) { + auto new_o = builder.clone(o, remap); + if (isAsyncOp(new_o)) { + yielded_token = new_o->getResult(0); + erased.push_back(&o); + } + } + } + if (!new_for_op.getBody()->mightHaveTerminator()) { + if (yielded_token) + builder.create(loc, yielded_token); + else + builder.create(loc); + } + for (auto o : erased) { + o->getResult(0).replaceAllUsesWith(new_for_op->getResult(0)); + o->erase(); + } + + return new_for_op; + } }; } // namespace diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 4d65ce551..feb2bc41f 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -910,12 +910,17 @@ void air::foldForLoopNestAsExtendedSizesAndStrides( } // Index offset taking into account mismatch between memref rank and // offset list size difference. - ind_var_factor *= - getTensorShape(memref.getType()).size() < offsets.size() - ? 
getTensorVolume(memref.getType()) - : getTensorShape(memref.getType()) - [i + memref.getType().cast().getRank() - - offsets.size()]; + int memref_rank = getTensorShape(memref.getType()).size(); + if (memref_rank < offsets.size()) { + if (i < offsets.size() - memref_rank) + ind_var_factor *= getTensorVolume(memref.getType()); + else + ind_var_factor *= getTensorShape( + memref.getType())[i + memref_rank - offsets.size()]; + } else { + ind_var_factor *= + getTensorShape(memref.getType())[i + memref_rank - offsets.size()]; + } } int trip_count = -1; if (auto afo = dyn_cast(o)) diff --git a/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir b/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir index da50694d6..f0697960d 100644 --- a/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir +++ b/mlir/test/Conversion/ConvertToAIR/condense_memref_ops_to_air_memcpy.mlir @@ -17,6 +17,7 @@ // CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0]], %{{.*}}] [%[[CST16]], %[[CST16]]] [%[[CST32]], %[[CST1]]]) : (memref<1x1x16x16xi32, 1>, memref<16x32xi32>) // CHECK: air.herd @herd_0 // CHECK: %[[CST32_0:.*]] = arith.constant 32 : index +// CHECK: %[[CST256_0:.*]] = arith.constant 256 : index // CHECK: %[[CST4_0:.*]] = arith.constant 4 : index // CHECK: %[[CST2_0:.*]] = arith.constant 2 : index // CHECK: %[[CST1_0:.*]] = arith.constant 1 : index @@ -25,8 +26,8 @@ // CHECK: %[[CST8_0:.*]] = arith.constant 8 : index // CHECK: %[[CST128_0:.*]] = arith.constant 128 : index // CHECK: %[[CST0_0:.*]] = arith.constant 0 : index -// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST2_0]], %[[CST2_0]], %[[CST4_0]], %[[CST8_0]]] [%[[CST8_0]], %[[CST64_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x8x16xi32, 1>) -// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST2_0]], %[[CST2_0]], %[[CST8_0]], %[[CST8_0]]] [%[[CST8_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x16x16xi32, 1>) +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST4_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST8_0]], %[[CST64_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x4x8xi32, 2>, memref<1x1x8x16xi32, 1>) +// CHECK: air.dma_memcpy_nd (%{{.*}}[] [] [], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST2_0]], %[[CST8_0]], %[[CST8_0]]] [%[[CST256_0]], %[[CST256_0]], %[[CST8_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]]) : (memref<1x1x2x2x8x8xi32, 2>, memref<1x1x16x16xi32, 1>) // CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}, %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST8_0]], %[[CST16_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST16_0]], %[[CST1_0]]], %{{.*}}[%[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]], %[[CST0_0]]] [%[[CST1_0]], %[[CST1_0]], %[[CST2_0]], %[[CST4_0]], %[[CST2_0]], %[[CST8_0]]] [%[[CST128_0]], %[[CST128_0]], %[[CST32_0]], %[[CST8_0]], %[[CST64_0]], %[[CST1_0]]]) : (memref<1x1x8x16xi32, 1>, memref<1x1x2x2x4x8xi32, 2>) // CHECK: air.herd_terminator // CHECK: air.dma_memcpy_nd (%{{.*}}[%{{.*}}, %{{.*}}] [%[[CST8]], %[[CST16]]] [%[[CST32]], %[[CST1]]], %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] [%[[CST1]], %[[CST1]], 
%[[CST8]], %[[CST16]]] [%[[CST128]], %[[CST128]], %[[CST16]], %[[CST1]]]) : (memref<8x32xi32>, memref<1x1x8x16xi32, 1>) diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir index c59f579f6..7d9754efc 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir @@ -9,29 +9,25 @@ // Fuse scf.for loops in air.segment. -// CHECK-LABEL: func.func @test +// CHECK-LABEL: func.func @func0 // CHECK: air.launch // CHECK: air.segment // CHECK: memref.alloc() // CHECK: scf.for // CHECK: memref.alloc() // CHECK: memref.alloc() -// CHECK: %[[EVENT0:.*]] = air.channel.get{{.*}}@channel_4 +// CHECK: %[[EVENT0:.*]] = air.channel.get{{.*}}@channel_ // CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) -// CHECK: %[[EVENT1:.*]] = air.channel.put{{.*}}@channel_1 +// CHECK: %[[EVENT1:.*]] = air.channel.put{{.*}}@channel_ // CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) -// CHECK: %[[EVENT2:.*]] = air.channel.put{{.*}}@channel_0 -// CHECK: %[[EVENT3:.*]] = air.channel.get{{.*}}@channel_5 +// CHECK: %[[EVENT2:.*]] = air.channel.put{{.*}}@channel_ +// CHECK: %[[EVENT3:.*]] = air.channel.get{{.*}}@channel_ // CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT3]]) -// CHECK: %[[EVENT4:.*]] = air.channel.put{{.*}}@channel_3 +// CHECK: %[[EVENT4:.*]] = air.channel.put{{.*}}@channel_ // CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT3]]) -// CHECK: %[[EVENT5:.*]] = air.channel.put{{.*}}@channel_2 +// CHECK: %[[EVENT5:.*]] = air.channel.put{{.*}}@channel_ -#map = affine_map<()[s0] -> (s0 * 64)> -#map1 = affine_map<()[s0] -> (s0 * 32)> -#set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 1 >= 0)> -#set1 = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)> -func.func @test() { +func.func @func0() { %c32 = arith.constant 32 : index %0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) { %6 = air.segment @segment_0 async { @@ -107,3 +103,61 @@ func.func @test() { } return } + +// Memref shrinkage via data access pattern analysis. 
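+// The memref<64x2048xi32, 1> allocated in the segment below is only ever
+// accessed through a 64x256 window (channel.put of 32x32 tiles within a
+// 256-wide slab, channel.get of 64x256), so it is expected to be shrunk to
+// memref<64x256xi32, 1> with its users' access patterns updated to match.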
+ +// CHECK-LABEL: func.func @func1 +// CHECK: air.launch +// CHECK: air.segment +// CHECK: scf.for +// CHECK: memref.alloc() : memref<64x256xi32, 1> +// CHECK: %[[EVENT0:.*]] = air.channel.get {{.*}} : (memref<64x256xi32, 1>) +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) +// CHECK: %[[EVENT1:.*]] = air.channel.put {{.*}} : (memref<64x256xi32, 1>) +// CHECK: scf.for{{.*}}iter_args({{.*}} = %[[EVENT0]]) +// CHECK: %[[EVENT2:.*]] = air.channel.put {{.*}} : (memref<64x256xi32, 1>) + +func.func @func1() { + %c32 = arith.constant 32 : index + %0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) { + %6 = air.segment @segment_0 async { + %c32_9 = arith.constant 32 : index + %c256_10 = arith.constant 256 : index + %c0_11 = arith.constant 0 : index + %c1_12 = arith.constant 1 : index + %c2048_13 = arith.constant 2048 : index + %c64_14 = arith.constant 64 : index + %c2 = arith.constant 2 : index + %7 = air.wait_all async + %8 = air.wait_all async + %async_token_17, %results_18 = air.execute -> (memref<64x2048xi32, 1>) { + %alloc = memref.alloc() : memref<64x2048xi32, 1> + air.execute_terminator %alloc : memref<64x2048xi32, 1> + } + %9 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_0[] (%results_18[%c0_11, %arg12] [%c32_9, %c32_9] [%c2048_13, %c1_12]) : (memref<64x2048xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %10 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = scf.for %arg12 = %c0_11 to %c256_10 step %c32_9 iter_args(%arg13 = %arg11) -> (!air.async.token) { + %19 = air.channel.put async [%arg13] @channel_1[] (%results_18[%c32_9, %arg12] [%c32_9, %c32_9] [%c2048_13, %c1_12]) : (memref<64x2048xi32, 1>) + scf.yield %19 : !air.async.token + } + scf.yield %18 : !air.async.token + } + %13 = scf.for %arg10 = %c0_11 to %c2048_13 step %c256_10 iter_args(%arg11 = %async_token_17) -> (!air.async.token) { + %18 = air.channel.get async [%arg11, %7] @channel_4[] (%results_18[%c0_11, %c0_11] [%c64_14, %c256_10] [%c2048_13, %c1_12]) : (memref<64x2048xi32, 1>) + scf.yield %18 : !air.async.token + } + %async_token_22 = air.execute [%async_token_17] { + memref.dealloc %results_18 : memref<64x2048xi32, 1> + } + air.segment_terminator + } + air.launch_terminator + } + return +}