[ConvertToDma] Add options to transpose dma dimensions on target side #812

Merged 2 commits on Oct 2, 2024
@@ -17,27 +17,28 @@

#define DEBUG_TYPE "iree-amdaie-convert-to-dma"



namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Applies packing to a given input.
LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOp.getContext();
/// Applies dma transposition on the side that has the lower number of
/// dimensions, i.e. the source side for pack ops and the destination side for
/// unpack ops.
template <typename PackOrUnpackOp>
LogicalResult dmaTransposeOnLowerNumDims(PackOrUnpackOp packOrUnpackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOrUnpackOp.getContext();

llvm::ArrayRef<int64_t> permutation = packOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOp.getStaticInnerTiles();
llvm::ArrayRef<int64_t> permutation = packOrUnpackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOrUnpackOp.getStaticInnerTiles();

SmallVector<OpFoldResult> innerSizes;
SmallVector<OpFoldResult> innerStrides;
SmallVector<OpFoldResult> innerOffsets;

auto innerDimsPos = packOp.getInnerDimsPos();
ArrayRef<int64_t> innerDimsPos = packOrUnpackOp.getInnerDimsPos();

for (int i = 0; i < innerTiles.size(); i++) {
// Calculate new sizes.
@@ -52,7 +53,7 @@ LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
"in dimension {0}, the tile size {1} does not divide the tensor size "
"{2}. Imperfect/partial tiling is currently not supported.",
i, innerTiles[i], size.value());
return packOp->emitOpError(message);
return packOrUnpackOp->emitOpError(message);
}

sizes[innerDimsPos[i]] =
@@ -71,33 +72,38 @@ LogicalResult packDmaInputs(IREE::LinalgExt::PackOp packOp,
innerOffsets.push_back(offsets[innerDimsPos[i]]);
offsets[innerDimsPos[i]] = getAsIndexOpFoldResult(ctx, 0);
}

// Apply permutations to the outer dims if provided.
if (!permutation.empty()) {
applyPermutationToVector(strides, permutation);
applyPermutationToVector(sizes, permutation);
applyPermutationToVector(offsets, permutation);
}

// Merge the dims.
sizes.insert(sizes.end(), innerSizes.begin(), innerSizes.end());
strides.insert(strides.end(), innerStrides.begin(), innerStrides.end());
offsets.insert(offsets.end(), innerOffsets.begin(), innerOffsets.end());
return success();
}

/// Applies unpacking to a given input.
LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = unPackOp.getContext();
/// Applies dma transposition on the side that has the higher number of
/// dimensions, i.e. the destination side for pack ops and the source side for
/// unpack ops.
template <typename PackOrUnpackOp>
LogicalResult dmaTransposeOnHigherNumDims(PackOrUnpackOp packOrUnpackOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
MLIRContext *ctx = packOrUnpackOp.getContext();

llvm::ArrayRef<int64_t> permutation = unPackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = unPackOp.getStaticInnerTiles();
llvm::ArrayRef<int64_t> permutation = packOrUnpackOp.getOuterDimsPerm();
llvm::ArrayRef<int64_t> innerTiles = packOrUnpackOp.getStaticInnerTiles();

SmallVector<OpFoldResult> innerSizes;
SmallVector<OpFoldResult> innerStrides;
SmallVector<OpFoldResult> innerOffsets;
auto innerDimsPos = unPackOp.getInnerDimsPos();
ArrayRef<int64_t> innerDimsPos = packOrUnpackOp.getInnerDimsPos();

int numOuterDims = sizes.size() - innerTiles.size();
SmallVector<OpFoldResult> outerOffsets = SmallVector<OpFoldResult>(
@@ -116,29 +122,30 @@ LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
applyPermutationToVector(outerSizes, inversePermutation);
applyPermutationToVector(outerOffsets, inversePermutation);
}
// Do the unpacking on the Outer dims.

// Initialize the indexing of each outer dim.
llvm::SmallDenseMap<int64_t, int64_t> outerDimsIndexMap;
// Intialize the indexing of each outer dim.
for (int i = 0; i < numOuterDims; i++) {
outerDimsIndexMap[i] = i;
}

// Update outer dim sizes/strides/offsets.
for (int i = 0; i < innerTiles.size(); i++) {
// Insert inner dims adjacent to there corresponding outer dims.
outerSizes.insert(
outerSizes.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
getAsIndexOpFoldResult(ctx, innerTiles[i]));
outerStrides.insert(
outerStrides.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
strides[numOuterDims + i]);
outerOffsets.insert(
outerOffsets.begin() + outerDimsIndexMap[innerDimsPos[i]] + 1,
offsets[numOuterDims + i]);
// Insert inner dims adjacent to their corresponding outer dims.
int insertionIndex = outerDimsIndexMap[innerDimsPos[i]] + 1;
outerSizes.insert(outerSizes.begin() + insertionIndex,
getAsIndexOpFoldResult(ctx, innerTiles[i]));
outerStrides.insert(outerStrides.begin() + insertionIndex,
strides[numOuterDims + i]);
outerOffsets.insert(outerOffsets.begin() + insertionIndex,
offsets[numOuterDims + i]);
// Update the map as all the dimensions inner to the innerDimsPos[i] are now
// shifted by 1.
for (int j = innerDimsPos[i] + 1; j < numOuterDims; j++) {
outerDimsIndexMap[j]++;
}
}

// Make the outer dims as the final returned dims
offsets = outerOffsets;
strides = outerStrides;
@@ -147,7 +154,7 @@ LogicalResult unPackDmaInputs(IREE::LinalgExt::UnPackOp unPackOp,
}

/// Examines an input/output of a pack/unpack op and provides the
/// corresponding offsets, sizes and strides required by the dma op
/// corresponding offsets, sizes and strides required by the dma op.
LogicalResult setDmaInputs(Operation *&operandOp,
SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
@@ -232,25 +239,6 @@ LogicalResult setDmaInputs(Operation *&operandOp,
"and SubViewOp as inputs.");
}

/// Get the inputs from the pack/unpack op 'op'. Return failure if 'op' is not
/// a pack/unpack op, or if 'op' is determined unlowerable to a DMA operation.
LogicalResult processInputs(Operation *op, SmallVector<OpFoldResult> &offsets,
SmallVector<OpFoldResult> &sizes,
SmallVector<OpFoldResult> &strides) {
if (auto packOp = dyn_cast<IREE::LinalgExt::PackOp>(op)) {
if (failed(packDmaInputs(packOp, offsets, sizes, strides))) {
return failure();
}
} else if (auto unPackOp = dyn_cast<IREE::LinalgExt::UnPackOp>(op)) {
if (failed(unPackDmaInputs(unPackOp, offsets, sizes, strides))) {
return failure();
}
} else {
return failure();
}
return success();
}

/// Rewrite the pack/unpack op 'op' as a DMA operation. The function arguments
/// 'input', 'output', and 'innerTiles' are the input, output, and inner tile
/// of 'op'. If 'op' is not a pack/unpack op, or if it determined to not
@@ -260,8 +248,11 @@ LogicalResult processInputs(Operation *op, SmallVector<OpFoldResult> &offsets,
/// obtained from 'op' inside this function if it were templatized, but
/// I've factorized out that logic to reduce the total amount of templatized
/// code.
LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles) {
template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(IRRewriter &rewriter, PackOrUnpackOp op, Value input,
Value output, llvm::ArrayRef<int64_t> innerTiles,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
if (llvm::any_of(innerTiles,
[](int64_t size) { return ShapedType::isDynamic(size); })) {
op->emitError("has a non-static shape: not yet supported by this pass.");
@@ -283,10 +274,6 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

if (!succeeded(processInputs(op, srcOffsets, srcShape, srcBaseStrides))) {
return failure();
}

// Prepare destination DMA inputs.
SmallVector<OpFoldResult> dstOffsets;
SmallVector<OpFoldResult> dstBaseStrides;
@@ -295,6 +282,29 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
return failure();
}

// Update dma source or destination addressing based on the side for dma
// transposition and pack/unpack operations.
if (packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!packTransposeOnSource && isa<IREE::LinalgExt::PackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else if (unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnHigherNumDims(op, srcOffsets, srcShape,
srcBaseStrides)))
return failure();
} else if (!unpackTransposeOnSource && isa<IREE::LinalgExt::UnPackOp>(op)) {
if (!succeeded(dmaTransposeOnLowerNumDims(op, dstOffsets, dstShape,
dstBaseStrides)))
return failure();
} else {
op->emitError("unhandled option for dma addressing update.");
return failure();
}

// Create logical objectFifos from source and destination memrefs.
Value srcVal = sourceOp->getResult(0);
Value dstVal = dstOp->getResult(0);
@@ -317,11 +327,14 @@ LogicalResult rewriteAsDma(IRRewriter &rewriter, Operation *op, Value input,
}

template <typename PackOrUnpackOp>
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter) {
LogicalResult rewriteAsDma(PackOrUnpackOp op, IRRewriter &rewriter,
bool packTransposeOnSource,
bool unpackTransposeOnSource) {
Value input = op.getInput();
Value output = op.getOutput();
llvm::ArrayRef<int64_t> innerTiles = op.getStaticInnerTiles();
return rewriteAsDma(rewriter, op, input, output, innerTiles);
return rewriteAsDma(rewriter, op, input, output, innerTiles,
packTransposeOnSource, unpackTransposeOnSource);
}

/// Convert a linalg.copy operation on 2 memrefs to an equivalent pack/unpack
@@ -375,6 +388,9 @@ class AMDAIEConvertToDmaPass

AMDAIEConvertToDmaPass() = default;
AMDAIEConvertToDmaPass(const AMDAIEConvertToDmaPass &pass){};
AMDAIEConvertToDmaPass(const AMDAIEConvertToDmaOptions &options)
: AMDAIEConvertToDmaBase(options) {}

void runOnOperation() override;
};

@@ -387,32 +403,35 @@ void AMDAIEConvertToDmaPass::runOnOperation() {
// step. This is easy to implement, but not the most direct lowering, so
// we might want to revisit this.
WalkResult convertCopiesWalkResult =
getOperation()->walk([&rewriter](linalg::CopyOp copyOp) {
getOperation()->walk([&](linalg::CopyOp copyOp) {
Review comment from @newling (Contributor), Oct 2, 2024:
Why this change? My preference is [&rewriter] > [&] > [&, this].

Reply from the PR author (Contributor):
I don't think there's a need to pass &rewriter, and I'm trying to follow the coding style in llvm-project, which just passes [&] in such walk functions, e.g. https://github.com/llvm/llvm-project/blob/e1e788f423b5c780c40912ab102b0a3c4b92b9de/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp#L63
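
As a standalone illustration of the capture styles discussed above (plain C++, not code from this PR; std::for_each stands in for the MLIR walk and PassLike for the pass class):

// Standalone sketch of the three capture styles from the review discussion.
#include <algorithm>
#include <iostream>
#include <vector>

struct PassLike {
  bool flag = true;

  void run(std::vector<int> &vals, int &sum) {
    // Capture only what is used (the most explicit form).
    std::for_each(vals.begin(), vals.end(), [&sum](int v) { sum += v; });
    // Capture everything in scope by reference (the shorthand adopted above).
    std::for_each(vals.begin(), vals.end(), [&](int v) { sum += v; });
    // Capture by reference plus `this`, needed when the body also reads
    // members such as pass options.
    std::for_each(vals.begin(), vals.end(),
                  [&, this](int v) { sum += flag ? v : 0; });
  }
};

int main() {
  std::vector<int> vals = {1, 2, 3};
  int sum = 0;
  PassLike().run(vals, sum);
  std::cout << sum << "\n";  // 18: each of the three loops adds 1 + 2 + 3
  return 0;
}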

if (failed(copyToPack(rewriter, copyOp)))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (convertCopiesWalkResult.wasInterrupted()) return signalPassFailure();

auto walkResult = getOperation()->walk(
[&rewriter](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter))) {
auto walkResult =
getOperation()->walk([&, this](IREE::LinalgExt::PackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
walkResult = getOperation()->walk(
[&rewriter](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter))) {
[&, this](IREE::LinalgExt::UnPackOp op) -> WalkResult {
if (failed(rewriteAsDma(op, rewriter, packTransposeOnSource,
unpackTransposeOnSource))) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (walkResult.wasInterrupted()) signalPassFailure();
}

std::unique_ptr<Pass> createAMDAIEConvertToDmaPass() {
return std::make_unique<AMDAIEConvertToDmaPass>();
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass(
AMDAIEConvertToDmaOptions options) {
return std::make_unique<AMDAIEConvertToDmaPass>(options);
}
} // namespace mlir::iree_compiler::AMDAIE
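
For reference, a minimal sketch (not part of this PR) of how the new options could be set when building a pass pipeline in C++. The header path, the namespace qualification of the options struct, and the pipeline function are assumptions; only AMDAIEConvertToDmaOptions, its packTransposeOnSource/unpackTransposeOnSource fields, and createAMDAIEConvertToDmaPass(options) are taken from this diff.

// Sketch only (assumptions noted above). The defaults (both options true)
// keep the existing behaviour of transposing on the source side for both
// pack and unpack ops.
#include "iree-amd-aie/Transforms/Passes.h"  // assumed header location
#include "mlir/Pass/PassManager.h"

void addConvertToDmaToPipeline(mlir::OpPassManager &pm) {
  mlir::iree_compiler::AMDAIE::AMDAIEConvertToDmaOptions options;
  // Keep the default: pack ops transpose dma dimensions on the source side.
  options.packTransposeOnSource = true;
  // Move the transposition to the destination (target) side for unpack ops.
  options.unpackTransposeOnSource = false;
  pm.addPass(
      mlir::iree_compiler::AMDAIE::createAMDAIEConvertToDmaPass(options));
}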
@@ -235,7 +235,8 @@ std::unique_ptr<Pass> createAMDAIEPackAndTransposePass(

/// Create pass to lower copy/pack/unpack ops to AMDAIE DMA ops operating on
/// logical objectFifos.
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass();
std::unique_ptr<Pass> createAMDAIEConvertToDmaPass(
AMDAIEConvertToDmaOptions options = {});

/// Create a pass to pad MatmulOp.
std::unique_ptr<Pass> createAMDAIEPadPass(AMDAIEPadOptions options = {});
@@ -474,6 +474,12 @@ def AMDAIEConvertToDma :

}];
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertToDmaPass()";
let options = [
Option<"packTransposeOnSource", "pack-transpose-on-source", "bool", /*default=*/"true",
"Option to set transposed dma dimensions on source or target side for pack ops">,
Option<"unpackTransposeOnSource", "unpack-transpose-on-source", "bool", /*default=*/"true",
"Option to set transposed dma dimensions on source or target side for unpack ops">
];
}

def AMDAIEPad :