From 46301eef5b8db835bb16c28f0d0956674ea958cb Mon Sep 17 00:00:00 2001 From: Jorn Tuyls Date: Fri, 20 Dec 2024 08:29:03 -0800 Subject: [PATCH] [LowerToAIE] Make DMA creation sizes and strides static as early as possible --- .../Transforms/AMDAIELowerToAIE.cpp | 138 +++++++++--------- .../Transforms/AMDAIELowerToAIE.h | 33 ++--- 2 files changed, 80 insertions(+), 91 deletions(-) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp index cff956854..b00fcc9ee 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.cpp @@ -41,55 +41,26 @@ namespace mlir::iree_compiler::AMDAIE { // AIEDeviceBuilder utilities //===----------------------------------------------------------------------===// -FailureOr -AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr( - SmallVector sizes, SmallVector strides, - uint8_t memSpace, function_ref emitError) { +BDDimLayoutAndLength AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr( + ArrayRef sizes, ArrayRef strides) { assert(sizes.size() == strides.size() && "expected stride and size vectors of same size"); - if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) { - return emitError() << "could not fold repetition count"; - } - // Fold remaining dimensions, assuming zero offsets as offsets should be taken - // care of separately. - SmallVector offsets( - strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0)); - SmallVector newOffsets; - SmallVector newSizes; - SmallVector newStrides; - foldDims(offsets, sizes, strides, newOffsets, newSizes, newStrides, memSpace); - SmallVector bdDimLayoutAttr; // If the access pattern (strides/sizes) have a single dimension, make it // implicit with an empty `BDDimLayoutAttr` as this is what the AIE dialect // expects. - if (newStrides.size() == 1) { - std::optional stride = getConstantIntValue(newStrides[0]); - if (stride && stride.value() == 1) { - std::optional maybeSize = getConstantIntValue(newSizes[0]); - if (!maybeSize) return emitError() << "expected a static size"; - return std::make_pair( - AIE::BDDimLayoutArrayAttr::get(rewriter.getContext(), - ArrayRef(bdDimLayoutAttr)), - maybeSize.value()); - } + if (strides.size() == 1 && strides[0] == 1) { + return std::make_pair(AIE::BDDimLayoutArrayAttr::get( + rewriter.getContext(), ArrayRef(bdDimLayoutAttr)), + sizes[0]); } - bdDimLayoutAttr.reserve(newSizes.size()); + bdDimLayoutAttr.reserve(sizes.size()); // Compute the length of the DMA transfer. - std::optional> maybeStaticSizes = - getConstantIntValues(newSizes); - std::optional> maybeStaticStrides = - getConstantIntValues(newStrides); - if (!maybeStaticSizes || !maybeStaticStrides) { - return emitError() << "expected static sizes and strides"; - } int64_t transferLength = - maybeStaticSizes->empty() + sizes.empty() ? 0 - : std::accumulate(maybeStaticSizes->begin(), maybeStaticSizes->end(), - 1, std::multiplies<>()); - for (auto [size, stride] : - llvm::zip(maybeStaticSizes.value(), maybeStaticStrides.value())) { + : std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); + for (auto [size, stride] : llvm::zip(sizes, strides)) { bdDimLayoutAttr.push_back( AIE::BDDimLayoutAttr::get(rewriter.getContext(), size, stride)); } @@ -116,20 +87,15 @@ AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr( /// aie.dma_bd(%buffer_0_1_50 : memref<2048xi32, 1 : i32>) {len = 2048 : i32} /// aie.use_lock(%lock_0_1_52, Release, 2) /// aie.next_bd ^bb1 -LogicalResult AIEDeviceBuilder::createDMA( +LogicalResult AIEDeviceBuilder::createDMABlocks( Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex, - SmallVector sizes, SmallVector strides, - uint8_t memSpace, size_t acqNum, size_t relNum, int64_t offset, - const SmallVector &bufferOps, + ArrayRef sizes, ArrayRef strides, size_t acqNum, + size_t relNum, int64_t offset, const SmallVector &bufferOps, const std::pair &locks, std::optional pktId) { OpBuilder::InsertionGuard g(rewriter); - FailureOr maybeDimsAndLength = - convertSizeStrideToBDDimLayoutArrayAttr( - sizes, strides, memSpace, [&]() { return memOp->emitOpError(); }); - if (failed(maybeDimsAndLength)) return failure(); - auto [dims, len] = maybeDimsAndLength.value(); + auto [dims, len] = convertSizeStrideToBDDimLayoutArrayAttr(sizes, strides); Block &endBlock = memOp->getRegion(0).getBlocks().back(); assert(!endBlock.getOps().empty() && @@ -249,26 +215,38 @@ void AIEDeviceBuilder::eraseOp(Operation *op) { rewriter.eraseOp(op); } -void AIEDeviceBuilder::foldDims(const SmallVector &offsets, - const SmallVector &sizes, - const SmallVector &strides, - SmallVector &newOffsets, - SmallVector &newSizes, - SmallVector &newStrides, - uint8_t memSpace) { - SmallVector tmpOffsets; - SmallVector tmpSizes; - SmallVector tmpStrides; - (void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, tmpOffsets, - tmpSizes, tmpStrides); +LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic( + SmallVector sizes, SmallVector strides, + SmallVector &newSizes, SmallVector &newStrides, + uint8_t memSpace, function_ref emitError) { + if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) { + return emitError() << "could not fold repetition counts"; + } + SmallVector offsets( + strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0)); + SmallVector unitOffsets, unitSizes, unitStrides, newOffsets; + (void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, + unitOffsets, unitSizes, unitStrides); DmaDimConfig dmaDimConfig(deviceModel, memSpace); - SmallVector maxSizes = dmaDimConfig.getMaxSizes(tmpOffsets.size()); + SmallVector maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size()); + SmallVector linearOffsets, linearSizes, linearStrides; (void)foldLinearDims( - rewriter.getContext(), tmpOffsets, tmpSizes, tmpStrides, newOffsets, - newSizes, newStrides, [&](size_t idxFromEnd, int64_t size) { + rewriter.getContext(), unitOffsets, unitSizes, unitStrides, linearOffsets, + linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) { return idxFromEnd < maxSizes.size() && size <= maxSizes[maxSizes.size() - idxFromEnd - 1]; }); + std::optional> maybeStaticSizes = + getConstantIntValues(linearSizes); + std::optional> maybeStaticStrides = + getConstantIntValues(linearStrides); + if (!maybeStaticSizes || !maybeStaticStrides) { + return emitError() + << "found dynamic sizes or strides which is not supported"; + } + newSizes = std::move(maybeStaticSizes.value()); + newStrides = std::move(maybeStaticStrides.value()); + return success(); } void AIEDeviceBuilder::remapOperands(Operation *op) { @@ -582,11 +560,18 @@ LogicalResult AIEDeviceBuilder::connectionToAIE( std::make_pair(consumerLocks[0], producerLocks[0]); rewriter.moveOpBefore(memOp, deviceBlock, deviceBlock->without_terminator().end()); - if (failed(createDMA(memOp, AIE::DMAChannelDir::MM2S, channel.getValue(), - maybeNpuDmaUserOp->getSourceMixedSizes(), - maybeNpuDmaUserOp->getSourceMixedStrides(), - maybeSourceMemSpace.value(), acqNum, acqNum, - maybeOffset.value(), buffers, lockPair, packetId))) { + SmallVector canonicalizedSizes, canonicalizedStrides; + if (failed(foldDimsAndReturnAsStatic( + maybeNpuDmaUserOp->getSourceMixedSizes(), + maybeNpuDmaUserOp->getSourceMixedStrides(), canonicalizedSizes, + canonicalizedStrides, maybeSourceMemSpace.value(), + [&]() { return maybeNpuDmaUserOp->emitOpError(); }))) { + return failure(); + }; + if (failed(createDMABlocks( + memOp, AIE::DMAChannelDir::MM2S, channel.getValue(), + canonicalizedSizes, canonicalizedStrides, acqNum, acqNum, + maybeOffset.value(), buffers, lockPair, packetId))) { return failure(); } } @@ -671,13 +656,20 @@ LogicalResult AIEDeviceBuilder::connectionToAIE( } std::pair lockPair = std::make_pair(producerLocks[0], consumerLocks[0]); + SmallVector canonicalizedSizes, canonicalizedStrides; + if (failed(foldDimsAndReturnAsStatic( + maybeNpuDmaUserOp->getTargetMixedSizes(), + maybeNpuDmaUserOp->getTargetMixedStrides(), canonicalizedSizes, + canonicalizedStrides, maybeTargetMemSpace.value(), + [&]() { return maybeNpuDmaUserOp->emitOpError(); }))) { + return failure(); + }; rewriter.moveOpBefore(memOp, deviceBlock, deviceBlock->without_terminator().end()); - if (failed(createDMA(memOp, AIE::DMAChannelDir::S2MM, channel.getValue(), - maybeNpuDmaUserOp->getTargetMixedSizes(), - maybeNpuDmaUserOp->getTargetMixedStrides(), - maybeTargetMemSpace.value(), acqNum, acqNum, - maybeOffset.value(), buffers, lockPair, packetId))) { + if (failed(createDMABlocks( + memOp, AIE::DMAChannelDir::S2MM, channel.getValue(), + canonicalizedSizes, canonicalizedStrides, acqNum, acqNum, + maybeOffset.value(), buffers, lockPair, packetId))) { return failure(); } } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.h index e005a8d8f..5a256b280 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToAIE.h @@ -68,18 +68,17 @@ class AIEDeviceBuilder { /// Utility to convert vectors of `size` and `stride` into an /// `AIE::BDDimLayoutArrayAttr`. - FailureOr convertSizeStrideToBDDimLayoutArrayAttr( - SmallVector sizes, SmallVector strides, - uint8_t memSpace, function_ref emitError); + BDDimLayoutAndLength convertSizeStrideToBDDimLayoutArrayAttr( + ArrayRef sizes, ArrayRef strides); /// Utility to create DMA blocks and add them to `memOp`. - LogicalResult createDMA(Operation *memOp, AIE::DMAChannelDir channelDir, - int channelIndex, SmallVector sizes, - SmallVector strides, uint8_t memSpace, - size_t acqNum, size_t relNum, int64_t offset, - const SmallVector &bufferOps, - const std::pair &locks, - std::optional pktId); + LogicalResult createDMABlocks( + Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex, + ArrayRef sizes, ArrayRef strides, size_t acqNum, + size_t relNum, int64_t offset, + const SmallVector &bufferOps, + const std::pair &locks, + std::optional pktId); /// Utility to create flow ops from connection ops. SmallVector createFlowOps( @@ -99,14 +98,12 @@ class AIEDeviceBuilder { /// might be used after `op` is erased. void eraseOp(Operation *op); - /// Utility to fold linear dims, unit dims and single dims in the provided - /// `offsets`, `sizes` and `strides` access patterns. - void foldDims(const SmallVector &offsets, - const SmallVector &sizes, - const SmallVector &strides, - SmallVector &newOffsets, - SmallVector &newSizes, - SmallVector &newStrides, uint8_t memSpace); + /// Utility to fold the provided repetition count, unit dims, linear dims and + /// to convert the sizes and strides into static versions and return them. + LogicalResult foldDimsAndReturnAsStatic( + SmallVector sizes, SmallVector strides, + SmallVector &newSizes, SmallVector &newStrides, + uint8_t memSpace, function_ref emitError); /// Utility to remap the provided operation's operands. void remapOperands(Operation *op);