Skip to content

Commit

Permalink
Extend -air-specialize-channel-wrap-and-stride for tile/memtile DMA r…
Browse files Browse the repository at this point in the history
…epeat count (Xilinx#456)

* Fixup an issue with folding for loop nests into channel ops with empty stride lists

* Fixup issue with stride value calculation; test

* Add support for tile / memtile dma repeat count feature

* Clang format
  • Loading branch information
erwei-xilinx authored Feb 26, 2024
1 parent 3c1c256 commit 7b39046
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 2 deletions.
2 changes: 2 additions & 0 deletions mlir/include/air/Conversion/AIRToAIESchedulingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ bool areIdenticalVectors(std::vector<unsigned> &a, std::vector<unsigned> &b);
int64_t get1DOffset(SmallVector<Value> memcpy_offsets,
SmallVector<Value> memcpy_strides, int byte_count_per_elem);

// Returns the repeat count implied by a channel put/get op: the constant size
// of its highest dimension when that dimension's stride is a constant zero,
// and 1 otherwise (including when memcpy_op is not an air::ChannelInterface
// or its size/stride lists are empty).
int getRepeatCount(Operation *memcpy_op);

std::vector<AIE::BDDimLayoutAttr>
getWrapsAndStrides(SmallVector<Value> memcpy_sizes,
SmallVector<Value> memcpy_strides, MLIRContext *ctx);
Expand Down
19 changes: 17 additions & 2 deletions mlir/lib/Conversion/AIRToAIEPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2358,19 +2358,24 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
generateDmaBd<bufferOpTy>(loc, dir, locks, x, y, arch, bd, memcpyOp,
bufferOp);
}

int repeat_count = 1;
if (p.second.size() == 1)
repeat_count = air::getRepeatCount(p.second[0]);

if (!channel_head) {
channel_head = start_bb;
end_bb = new Block();
mem.getBody().push_back(end_bb);
auto b = OpBuilder::atBlockBegin(channel_head);
b.create<AIE::DMAStartOp>(loc, dir, chan, /*repeat*/ 1, first_bd,
b.create<AIE::DMAStartOp>(loc, dir, chan, repeat_count, first_bd,
end_bb);
b.setInsertionPointToEnd(end_bb);
b.create<AIE::EndOp>(loc);
} else {
auto b = OpBuilder::atBlockBegin(start_bb);
b.create<AIE::DMAStartOp>(
loc, dir, chan, /*repeat*/ 1, first_bd,
loc, dir, chan, repeat_count, first_bd,
channel_head->getTerminator()->getSuccessor(1));
channel_head->getTerminator()->setSuccessor(start_bb, 1);
}
Expand Down Expand Up @@ -2414,6 +2419,16 @@ class AIRToAIEPass : public air::impl::AIRToAIEBase<AIRToAIEPass> {
? ndcpy.getDstStrides()
: ndcpy.getSrcStrides();

// Skip over repeat pattern at highest dimension; repeat pattern handled at
// AIE::DMAStartOp.
if (!strides.empty() && !sizes.empty() && !offsets.empty())
if (auto const_highest_stride = getConstantIntValue(strides[0]))
if (*const_highest_stride == 0) {
strides.erase(strides.begin());
sizes.erase(sizes.begin());
offsets.erase(offsets.begin());
}

int64_t len = getMemcpySizesAsInt(memref, sizes);
int64_t offset =
get1DOffset(offsets, strides, getElementSizeInBytes(memref.getType()));
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Conversion/AIRToAIESchedulingUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,25 @@ int64_t air::get1DOffset(SmallVector<Value> memcpy_offsets,
return one_d_offset * byte_count_per_elem;
}

// Derive the DMA repeat count from an air::ChannelPut/GetOp.
//
// A constant zero stride in the highest (first) dimension means that dimension
// only repeats the inner access pattern; its constant size is then the repeat
// count. Returns 1 in every other case: memcpy_op is not a channel op, its
// size or stride list is empty, either leading value is not a compile-time
// constant, or the leading stride is non-zero.
int air::getRepeatCount(Operation *memcpy_op) {
  auto channel = dyn_cast<air::ChannelInterface>(memcpy_op);
  if (!channel)
    return 1;
  if (channel.getStrides().empty() || channel.getSizes().empty())
    return 1;

  // Fold the leading stride/size to constants once each.
  auto outer_stride = getConstantIntValue(channel.getStrides()[0]);
  auto outer_size = getConstantIntValue(channel.getSizes()[0]);
  if (!outer_stride || !outer_size)
    return 1;

  // Only a zero stride marks the highest dimension as a repeat pattern.
  if (*outer_stride != 0)
    return 1;
  return static_cast<int>(*outer_size);
}

std::vector<AIE::BDDimLayoutAttr>
air::getWrapsAndStrides(SmallVector<Value> memcpy_sizes,
SmallVector<Value> memcpy_strides, MLIRContext *ctx) {
Expand Down
15 changes: 15 additions & 0 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1698,6 +1698,21 @@ struct AIRSpecializeChannelWrapAndStrideInScfFor
SmallVector<Value> offsets = channel_ops[0].getOffsets();
SmallVector<Value> wraps = channel_ops[0].getSizes();
SmallVector<Value> strides = channel_ops[0].getStrides();
// If empty offsets/sizes/strides, then populate the lists with default
// values.
if (offsets.empty() && wraps.empty() && strides.empty()) {
auto memref_shape = getTensorShape(channel_ops[0].getMemref().getType());
int current_stride =
getTensorVolume(channel_ops[0].getMemref().getType());
for (unsigned i = 0; i < memref_shape.size(); i++) {
offsets.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
wraps.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, memref_shape[i]));
current_stride /= memref_shape[i];
strides.push_back(
rewriter.create<arith::ConstantIndexOp>(loc, current_stride));
}
}
for (auto o : for_loops) {
// Check for perfect loop nest containing only air.channel ops
if (!hasNElements(o.getBody(), 1))
Expand Down
65 changes: 65 additions & 0 deletions mlir/test/Conversion/AIRToAIE/air_shimcpy_to_aie2.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -635,3 +635,68 @@ func.func @func9(%arg0: memref<128xf32>, %arg1: memref<128xf32>) {
return
}

// -----

// Tile / memtile DMA repeat count support.
// CHECK: aie.device(xcve2802)
// CHECK: %[[tileDMA_0_4:.*]] = aie.mem
// CHECK: aie.dma_start(S2MM, 0, ^bb1, ^bb2, repeat_count = 32)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xi32, 2>, 0, 8192)
// CHECK: %[[tileDMA_0_3:.*]] = aie.mem
// CHECK: aie.dma_start(S2MM, 0, ^bb1, ^bb2, repeat_count = 32)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xi32, 2>, 0, 8192)
// CHECK: %[[memTileDMA_2_1:.*]] = aie.memtile_dma
// CHECK: aie.dma_start(MM2S, 0, ^bb1, ^bb3, repeat_count = 32)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xi32, 1>, 0, 8192)
// CHECK: aie.dma_start(MM2S, 1, ^bb4, ^bb2, repeat_count = 32)
// CHECK: aie.dma_bd({{.*}} : memref<32x256xi32, 1>, 0, 8192)

#map = affine_map<()[s0] -> (s0 * 32)>
air.channel @channel_1 [2, 1]
// Both the channel put (in the segment) and the channel get (in the herd) use
// sizes [32, 32, 256] with strides [0, 256, 1]: the highest dimension has a
// zero stride, i.e. the same 32x256 access is repeated 32 times. The expected
// output above shows this as repeat_count = 32 on each aie.dma_start, with the
// buffer descriptor length covering only the remaining 32 * 256 = 8192
// elements.
func.func @func10(%arg0: memref<128xf32>, %arg1: memref<128xf32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg2) in (%arg3=%c2) attributes {id = 1 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 3 : i64, y_size = 2 : i64} {
%c256 = arith.constant 256 : index
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c2_0 = arith.constant 2 : index
%async_token, %results = air.execute -> (memref<32x256xi32, 1>) {
%alloc = memref.alloc() : memref<32x256xi32, 1>
air.execute_terminator %alloc : memref<32x256xi32, 1>
}
%2 = scf.parallel (%arg4) = (%c0) to (%c2_0) step (%c1) init (%async_token) -> !air.async.token {
%4 = air.channel.put async [%async_token] @channel_1[%arg4, %c0] (%results[%c0, %c0, %c0] [%c32, %c32, %c256] [%c0, %c256, %c1]) {id = 4 : i32} : (memref<32x256xi32, 1>)
scf.reduce(%4 : !air.async.token) {
^bb0(%arg5: !air.async.token, %arg6: !air.async.token):
%5 = air.wait_all async [%arg5, %arg6]
scf.reduce.return %5 : !air.async.token
}
}
%3 = air.herd @herd_0 async [%async_token] tile (%arg4, %arg5) in (%arg6=%c1, %arg7=%c2_0) attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 3 : i64} {
%c0_2 = arith.constant 0 : index
%c1_4 = arith.constant 1 : index
%c32_3 = arith.constant 32 : index
%c256_5 = arith.constant 256 : index
%4 = air.wait_all async
%async_token_3, %results_4 = air.execute -> (memref<32x256xi32, 2>) {
%alloc = memref.alloc() : memref<32x256xi32, 2>
air.execute_terminator %alloc : memref<32x256xi32, 2>
}
%5 = air.channel.get async [%4, %async_token_3] @channel_1[%arg5, %c0_2] (%results_4[%c0_2, %c0_2, %c0_2] [%c32_3, %c32_3, %c256_5] [%c0_2, %c256_5, %c1_4]) {id = 6 : i32} : (memref<32x256xi32, 2>)
%async_token_5 = air.execute [%5] {
memref.dealloc %results_4 : memref<32x256xi32, 2>
}
air.herd_terminator
}
%async_token_1 = air.execute [%3] {
memref.dealloc %results : memref<32x256xi32, 1>
}
air.segment_terminator
}
air.launch_terminator
}
return
}

Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,19 @@ module {
%2 = air.wait_all async [%0, %1]
return %alloc : memref<128xf32>
}

// CHECK-LABEL: test5
// CHECK: put async @channel_17[] (%arg0[%c0, %c0, %c0] [%c8, %c32, %c32] [%c0, %c32, %c1]) : (memref<32x32xf32>)

// A channel put with empty offset/size/stride lists, nested in an scf.for that
// runs eight iterations (0 to 8, step 1). The expected output above has the
// loop folded into the put as a leading dimension of size 8 with stride 0,
// followed by sizes [32, 32] and strides [32, 1] for the 32x32 memref.
// Note: the affine.apply result (%0) is not used by the put.
func.func @test5(%arg0: memref<32x32xf32>) -> memref<32x32xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
%alloc = memref.alloc() : memref<32x32xf32>
scf.for %arg2 = %c0 to %c8 step %c1 {
%0 = affine.apply #map()[%arg2]
%1 = air.channel.put async @channel_17[] (%arg0[] [] []) : (memref<32x32xf32>)
}
return %alloc : memref<32x32xf32>
}
}

0 comments on commit 7b39046

Please sign in to comment.