Skip to content

Commit

Permalink
AIRRtToNpuPass SHIM DMA BD optimization (Xilinx#550)
Browse files Browse the repository at this point in the history
* When tiling wrap>1023, swap inner and outer wrap so that it is less likely to get stride>1M

* Code quality; improve performance by avoiding use of 'getSymbolUses' method

* faster

---------

Co-authored-by: James Newling <james.newling@gmail.com>
  • Loading branch information
erwei-xilinx and newling authored Apr 25, 2024
1 parent 9445e5c commit 4faaa09
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 85 deletions.
191 changes: 119 additions & 72 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,15 +578,15 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) {
auto const_stride = *getConstantIntValue(strides[i]);
if (const_wrap >= AIE2_WRAP_UPPER_BOUND) {
// Found dimension with illegal wrap. Tiling.
int inner_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int new_wrap = mlir::ceilDiv(const_wrap, inner_wrap);
int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1);
int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap);
wraps[i] = builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), inner_wrap));
wraps.insert(wraps.begin() + i,
builder.create<arith::ConstantOp>(
loc, builder.getI64Type(),
IntegerAttr::get(builder.getI64Type(), new_wrap)));
IntegerAttr::get(builder.getI64Type(), outer_wrap)));
auto new_const_stride =
(const_stride * inner_wrap) %
air::getTensorVolume(
Expand Down Expand Up @@ -1130,56 +1130,71 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
}

std::optional<AIE::ShimDMAAllocationOp>
getAllocOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) {
auto sym = dev.lookupSymbol(sym_name);
if (!sym)
return std::nullopt;

auto uses = SymbolTable::getSymbolUses(sym, dev);
for (auto use : *uses)
if (auto infoOp = dyn_cast<AIE::ShimDMAAllocationOp>(use.getUser()))
return infoOp;

getAllocOpForSymbol(SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps,
StringRef sym_name) {
for (auto shimDmaAllocOp : shimDmaAllocOps)
if (shimDmaAllocOp.getSymName() == sym_name)
return shimDmaAllocOp;
return std::nullopt;
}

std::optional<AIE::ObjectFifoCreateOp>
getObjectFifoCreateOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) {
auto sym = dev.lookupSymbol(sym_name);
if (!sym)
return std::nullopt;

for (auto objFifoCreateOp : dev.getOps<AIE::ObjectFifoCreateOp>()) {
if (objFifoCreateOp.getSymName().str() == sym_name.str())
return objFifoCreateOp;
}

std::optional<AIE::ObjectFifoCreateOp> getObjectFifoCreateOpForSymbol(
SmallVector<AIE::ObjectFifoCreateOp> objectFifoCreateOps,
StringRef sym_name) {
for (auto objectFifoCreateOp : objectFifoCreateOps)
if (objectFifoCreateOp.getSymName().str() == sym_name.str())
return objectFifoCreateOp;
return std::nullopt;
}

void insertNpuSyncOpForResults(ModuleOp module) {
module.walk([&](mlir::func::FuncOp f) {
SmallVector<mlir::func::FuncOp> funcOps;
module.walk([&](mlir::func::FuncOp f) { funcOps.push_back(f); });
for (auto f : funcOps) {
SmallVector<AIEX::NpuDmaMemcpyNdOp> dmas;
f.walk([&](AIEX::NpuDmaMemcpyNdOp dma) { dmas.push_back(dma); });
auto d = f->getParentOfType<AIE::DeviceOp>();

SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps;
if (d)
d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) {
shimDmaAllocOps.push_back(shimDmaAllocOp);
});
// Performance optimization: instead of repeating calls to
// getAllocOpForSymbol with the same symbol name, cache the result of the
// first call and use the cache for subsequent calls. This dramatically
// improves compile time for some designs.
llvm::DenseMap<StringRef, std::optional<AIE::ShimDMAAllocationOp>>
allocationCache;
auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) {
auto iter = allocationCache.find(sym_name);
if (iter != allocationCache.end()) {
return iter->second;
}
auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name);
allocationCache.insert({sym_name, infaOp});
return infaOp;
};

if (!d)
return;
continue;
OpBuilder builder(f);
for (auto dma : dmas) {
if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) {
if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) {
// Found dma op copying results to host
OpBuilder builder(dma);
auto col = builder.getI32IntegerAttr(infoOp->getCol());
auto row = builder.getI32IntegerAttr(0);
auto dir = builder.getI32IntegerAttr(0);
auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex());
auto col_num = builder.getI32IntegerAttr(1);
auto row_num = builder.getI32IntegerAttr(1);
builder.setInsertionPointAfter(dma);
builder.create<AIEX::NpuSyncOp>(dma->getLoc(), col, row, dir, chan,
col_num, row_num);
}
}
auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata());
if (!infoOp)
continue;
if (infoOp->getChannelDir() != AIE::DMAChannelDir::S2MM)
continue;
// Found dma op copying results to host
auto col = builder.getI32IntegerAttr(infoOp->getCol());
auto row = builder.getI32IntegerAttr(0);
auto dir = builder.getI32IntegerAttr(0);
auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex());
auto col_num = builder.getI32IntegerAttr(1);
auto row_num = builder.getI32IntegerAttr(1);
builder.setInsertionPointAfter(dma);
builder.create<AIEX::NpuSyncOp>(dma->getLoc(), col, row, dir, chan,
col_num, row_num);
}

// Attempt to make npu.sync ops contiguous if they are not operating on
Expand All @@ -1189,54 +1204,86 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase<AIRRtToNpuPass> {
if (auto sync = dyn_cast<AIEX::NpuSyncOp>(op))
previsouSyncs.push_back(sync);
else if (auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op)) {
auto infoOp = getAllocOpForSymbol(d, dma.getMetadata());
if (infoOp && infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM &&
!previsouSyncs.empty()) {
auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata());
if (!infoOp)
return;
if (previsouSyncs.empty())
return;
if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) {
for (auto prevSync : previsouSyncs)
prevSync->moveAfter(op);
} else if (infoOp &&
infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S &&
!previsouSyncs.empty()) {
} else if (infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S) {
previsouSyncs.clear();
}
}
});
});
}
}

// Renumber aiex.npu.dma_memcpy_nd ops per column of AIEs.
void renumberNpuDmaOps(Block *blk) {
std::map<int, int> chanToIdMap;
AIE::DeviceOp d = nullptr;
blk->walk([&](AIE::DeviceOp op) { d = op; });
SmallVector<AIE::ShimDMAAllocationOp> shimDmaAllocOps;
if (d)
d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) {
shimDmaAllocOps.push_back(shimDmaAllocOp);
});
// Performance optimization: instead of repeating calls to
// getAllocOpForSymbol with the same symbol name, cache the result of the
// first call and use the cache for subsequent calls. This dramatically
// improves compile time for some designs.
llvm::DenseMap<StringRef, std::optional<AIE::ShimDMAAllocationOp>>
allocationCache;
auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) {
auto iter = allocationCache.find(sym_name);
if (iter != allocationCache.end()) {
return iter->second;
}
auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name);
allocationCache.insert({sym_name, infaOp});
return infaOp;
};
SmallVector<AIE::ObjectFifoCreateOp> objectFifoCreateOps;
if (d)
d.walk([&](AIE::ObjectFifoCreateOp objectFifoCreateOp) {
objectFifoCreateOps.push_back(objectFifoCreateOp);
});
OpBuilder builder(blk->getParentOp());
blk->walk([&](Operation *op) {
if (auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op)) {
OpBuilder builder(dma);
int col = -1;
if (d) {
if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) {
col = infoOp->getCol();
} else if (auto objFifoCreateOp =
getObjectFifoCreateOpForSymbol(d, dma.getMetadata())) {
auto prodTileOp =
objFifoCreateOp->getProducerTile().getDefiningOp<AIE::TileOp>();
if (prodTileOp.isShimTile())
col = prodTileOp.colIndex();
for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) {
auto consTileOp = consumerTileOp.getDefiningOp<AIE::TileOp>();
if (consTileOp.isShimTile()) {
col = consTileOp.colIndex();
}
auto dma = dyn_cast<AIEX::NpuDmaMemcpyNdOp>(op);
auto sync = dyn_cast<AIEX::NpuSyncOp>(op);
if (sync) {
chanToIdMap.clear();
return;
}
if (!dma)
return;
builder.setInsertionPoint(dma);
int col = -1;
if (d) {
if (auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata())) {
col = infoOp->getCol();
} else if (auto objFifoCreateOp = getObjectFifoCreateOpForSymbol(
objectFifoCreateOps, dma.getMetadata())) {
auto prodTileOp =
objFifoCreateOp->getProducerTile().getDefiningOp<AIE::TileOp>();
if (prodTileOp.isShimTile())
col = prodTileOp.colIndex();
for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) {
auto consTileOp = consumerTileOp.getDefiningOp<AIE::TileOp>();
if (consTileOp.isShimTile()) {
col = consTileOp.colIndex();
}
}
}
if (!chanToIdMap.count(col))
chanToIdMap[col] = 0;
dma->setAttr("id", mlir::IntegerAttr::get(
mlir::IntegerType::get(dma->getContext(), 64),
chanToIdMap[col]++));
} else if (isa<AIEX::NpuSyncOp>(op))
chanToIdMap.clear();
}
if (!chanToIdMap.count(col))
chanToIdMap[col] = 0;
dma->setAttr("id", mlir::IntegerAttr::get(
mlir::IntegerType::get(dma->getContext(), 64),
chanToIdMap[col]++));
});
}

Expand Down
18 changes: 9 additions & 9 deletions mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -455,10 +455,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32>

#map = affine_map<()[s0] -> (s0 * 64)>
Expand Down Expand Up @@ -521,9 +521,9 @@ module {

// CHECK-LABEL: aie.device(npu)
// CHECK: func.func @func10(%[[ARG0:.*]]: memref<2654208xi32>)
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32>

#map = affine_map<()[s0] -> (s0 * 64)>
module {
Expand Down Expand Up @@ -701,8 +701,8 @@ module {
// CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ module {
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32>
// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32>

module {
Expand Down

0 comments on commit 4faaa09

Please sign in to comment.