Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LowerToAIE] Make DMA creation sizes and strides static as early as possible #999

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,55 +41,26 @@ namespace mlir::iree_compiler::AMDAIE {
// AIEDeviceBuilder utilities
//===----------------------------------------------------------------------===//

FailureOr<BDDimLayoutAndLength>
AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError) {
BDDimLayoutAndLength AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
assert(sizes.size() == strides.size() &&
"expected stride and size vectors of same size");
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) {
return emitError() << "could not fold repetition count";
}
// Fold remaining dimensions, assuming zero offsets as offsets should be taken
// care of separately.
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> newOffsets;
SmallVector<OpFoldResult> newSizes;
SmallVector<OpFoldResult> newStrides;
foldDims(offsets, sizes, strides, newOffsets, newSizes, newStrides, memSpace);

SmallVector<AIE::BDDimLayoutAttr, 4> bdDimLayoutAttr;
// If the access pattern (strides/sizes) has a single dimension with unit
// stride, make it implicit with an empty `BDDimLayoutAttr` as this is what
// the AIE dialect expects.
if (newStrides.size() == 1) {
std::optional<int64_t> stride = getConstantIntValue(newStrides[0]);
if (stride && stride.value() == 1) {
std::optional<int64_t> maybeSize = getConstantIntValue(newSizes[0]);
if (!maybeSize) return emitError() << "expected a static size";
return std::make_pair(
AIE::BDDimLayoutArrayAttr::get(rewriter.getContext(),
ArrayRef(bdDimLayoutAttr)),
maybeSize.value());
}
if (strides.size() == 1 && strides[0] == 1) {
return std::make_pair(AIE::BDDimLayoutArrayAttr::get(
rewriter.getContext(), ArrayRef(bdDimLayoutAttr)),
sizes[0]);
}
bdDimLayoutAttr.reserve(newSizes.size());
bdDimLayoutAttr.reserve(sizes.size());
// Compute the length of the DMA transfer.
std::optional<SmallVector<int64_t>> maybeStaticSizes =
getConstantIntValues(newSizes);
std::optional<SmallVector<int64_t>> maybeStaticStrides =
getConstantIntValues(newStrides);
if (!maybeStaticSizes || !maybeStaticStrides) {
return emitError() << "expected static sizes and strides";
}
int64_t transferLength =
maybeStaticSizes->empty()
sizes.empty()
? 0
: std::accumulate(maybeStaticSizes->begin(), maybeStaticSizes->end(),
1, std::multiplies<>());
for (auto [size, stride] :
llvm::zip(maybeStaticSizes.value(), maybeStaticStrides.value())) {
: std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
for (auto [size, stride] : llvm::zip(sizes, strides)) {
bdDimLayoutAttr.push_back(
AIE::BDDimLayoutAttr::get(rewriter.getContext(), size, stride));
}
Expand All @@ -116,20 +87,15 @@ AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
/// aie.dma_bd(%buffer_0_1_50 : memref<2048xi32, 1 : i32>) {len = 2048 : i32}
/// aie.use_lock(%lock_0_1_52, Release, 2)
/// aie.next_bd ^bb1
LogicalResult AIEDeviceBuilder::createDMA(
LogicalResult AIEDeviceBuilder::createDMABlocks(
Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex,
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, size_t acqNum, size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, size_t acqNum,
size_t relNum, int64_t offset, const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId) {
OpBuilder::InsertionGuard g(rewriter);

FailureOr<BDDimLayoutAndLength> maybeDimsAndLength =
convertSizeStrideToBDDimLayoutArrayAttr(
sizes, strides, memSpace, [&]() { return memOp->emitOpError(); });
if (failed(maybeDimsAndLength)) return failure();
auto [dims, len] = maybeDimsAndLength.value();
auto [dims, len] = convertSizeStrideToBDDimLayoutArrayAttr(sizes, strides);

Block &endBlock = memOp->getRegion(0).getBlocks().back();
assert(!endBlock.getOps<AIE::EndOp>().empty() &&
Expand Down Expand Up @@ -249,26 +215,38 @@ void AIEDeviceBuilder::eraseOp(Operation *op) {
rewriter.eraseOp(op);
}

void AIEDeviceBuilder::foldDims(const SmallVector<OpFoldResult> &offsets,
const SmallVector<OpFoldResult> &sizes,
const SmallVector<OpFoldResult> &strides,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides,
uint8_t memSpace) {
SmallVector<OpFoldResult> tmpOffsets;
SmallVector<OpFoldResult> tmpSizes;
SmallVector<OpFoldResult> tmpStrides;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, tmpOffsets,
tmpSizes, tmpStrides);
LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError) {
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) {
return emitError() << "could not fold repetition counts";
}
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> unitOffsets, unitSizes, unitStrides, newOffsets;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides,
unitOffsets, unitSizes, unitStrides);
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(tmpOffsets.size());
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size());
SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
(void)foldLinearDims(
rewriter.getContext(), tmpOffsets, tmpSizes, tmpStrides, newOffsets,
newSizes, newStrides, [&](size_t idxFromEnd, int64_t size) {
rewriter.getContext(), unitOffsets, unitSizes, unitStrides, linearOffsets,
linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
return idxFromEnd < maxSizes.size() &&
size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
});
std::optional<SmallVector<int64_t>> maybeStaticSizes =
getConstantIntValues(linearSizes);
std::optional<SmallVector<int64_t>> maybeStaticStrides =
getConstantIntValues(linearStrides);
if (!maybeStaticSizes || !maybeStaticStrides) {
return emitError()
<< "found dynamic sizes or strides which is not supported";
}
newSizes = std::move(maybeStaticSizes.value());
newStrides = std::move(maybeStaticStrides.value());
return success();
}

void AIEDeviceBuilder::remapOperands(Operation *op) {
Expand Down Expand Up @@ -582,11 +560,18 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
std::make_pair(consumerLocks[0], producerLocks[0]);
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
if (failed(createDMA(memOp, AIE::DMAChannelDir::MM2S, channel.getValue(),
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(),
maybeSourceMemSpace.value(), acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeSourceMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
if (failed(createDMABlocks(
memOp, AIE::DMAChannelDir::MM2S, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
}
}
Expand Down Expand Up @@ -671,13 +656,20 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
}
std::pair<AIE::LockOp, AIE::LockOp> lockPair =
std::make_pair(producerLocks[0], consumerLocks[0]);
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeTargetMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
if (failed(createDMA(memOp, AIE::DMAChannelDir::S2MM, channel.getValue(),
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(),
maybeTargetMemSpace.value(), acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
if (failed(createDMABlocks(
memOp, AIE::DMAChannelDir::S2MM, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,17 @@ class AIEDeviceBuilder {

/// Utility to convert vectors of `size` and `stride` into an
/// `AIE::BDDimLayoutArrayAttr`, paired with the length of the DMA transfer.
FailureOr<BDDimLayoutAndLength> convertSizeStrideToBDDimLayoutArrayAttr(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError);
BDDimLayoutAndLength convertSizeStrideToBDDimLayoutArrayAttr(
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides);

/// Utility to create DMA blocks and add them to `memOp`.
LogicalResult createDMA(Operation *memOp, AIE::DMAChannelDir channelDir,
int channelIndex, SmallVector<OpFoldResult> sizes,
SmallVector<OpFoldResult> strides, uint8_t memSpace,
size_t acqNum, size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId);
LogicalResult createDMABlocks(
Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, size_t acqNum,
size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId);

/// Utility to create flow ops from connection ops.
SmallVector<Operation *> createFlowOps(
Expand All @@ -99,14 +98,12 @@ class AIEDeviceBuilder {
/// might be used after `op` is erased.
void eraseOp(Operation *op);

/// Utility to fold linear dims, unit dims and single dims in the provided
/// `offsets`, `sizes` and `strides` access patterns.
void foldDims(const SmallVector<OpFoldResult> &offsets,
const SmallVector<OpFoldResult> &sizes,
const SmallVector<OpFoldResult> &strides,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides, uint8_t memSpace);
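/// Utility to fold the repetition count, unit dims and linear dims of the
/// provided `sizes` and `strides`, and to return the folded result as static
/// values in `newSizes` and `newStrides`. Fails, reporting through
/// `emitError`, if the repetition count cannot be folded or if any folded
/// size or stride is dynamic.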
LogicalResult foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError);

/// Utility to remap the provided operation's operands.
void remapOperands(Operation *op);
Expand Down
Loading