[MLIR] Legalize certain vector.transfer_read ops of scalable vectors #143146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 4 commits into base: users/momchil-velikov/memref-contig-slice
151 changes: 145 additions & 6 deletions mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,16 +298,155 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};

/// Transforms a `transfer_read` operation so it reads a vector of a type that
/// can be mapped to an LLVM type ("LLVM-legal" type). This is done by
/// collapsing trailing dimensions so we obtain a vector type with a single
/// scalable dimension in the rightmost position.
///
/// Example:
/// ```
/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
/// {in_bounds = [false, true, true, true]}
/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
/// ```
/// is rewritten to
/// ```
/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
/// : memref<?x?x2x8xi8> into memref<?x?xi8>
/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
/// {in_bounds = [false, true]}
/// : memref<?x?xi8>, vector<2x[64]xi8>
/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
/// ```
struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {

// Do not try to transform masked reads. For example, if we have a transfer
// to a `vector<[4]x4xi8>` we could have a mask like
// 1 1 1 0
// 1 1 1 0
// 1 1 1 0
// 0 0 0 0
// Flattening this mask would look like
// 1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
// and we have not yet figured out an efficient way to build such a mask,
// either from the mask operand or from the original `vector.create_mask`
// operation (if visible at all).
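// For illustration, the mask drawn above could be built by (hypothetical
// IR, assuming `vscale == 1`):
// ```
// %m = vector.create_mask %c3, %c3 : vector<[4]x4xi1>
// %v = vector.transfer_read %M[%i, %j], %c0_i8, %m
//        {in_bounds = [true, true]} : memref<?x?xi8>, vector<[4]x4xi8>
// ```
// Any such masked read is rejected by the check below.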
if (readOp.isMasked() || readOp.getMask())
return rewriter.notifyMatchFailure(readOp,
"masked transfers not-supported");

// General permutation maps are not supported. The issue is with transpose,
// broadcast, and other forms of non-identity mapping in the minor
// dimensions, which are impossible to represent after collapsing (at least
// because the resulting "collapsed" maps would have a smaller number of
// dimension indices).
// TODO: We have not yet had the need for it, but some forms of permutation
// maps with identity in the minor dimensions could be supported, for
// example `(i, j, k, p) -> (j, i, k, p)`, where we need to collapse only
// `k` and `p`.
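// For instance, the minor identity for a rank-4 transfer is
// `(i, j, k, p) -> (i, j, k, p)` and is accepted, whereas a map that
// transposes the trailing dimensions, e.g. `(i, j, k, p) -> (i, j, p, k)`,
// is rejected below.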
if (!readOp.getPermutationMap().isMinorIdentity())
return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
Comment on lines +351 to +352

Contributor: Would supporting non-identity be a problem? It would be good to add a comment, either:

  • TODO: We haven't required this, so leaving it for later. or
  • "Too complex because of <reason>, disabling".

Any hint for future developers would be helpful.

Collaborator (Author): Done.

// We handle transfers of vectors with rank >= 2 and a single scalable
// dimension. This transformation aims to turn an LLVM-illegal type into an
// LLVM-legal one, and one-dimensional vectors are already LLVM-legal, even
// if scalable. A value of a vector type with more than one scalable
// dimension is impossible to represent using a vector type with no
// scalable dimensions or a single one. For example, a `vector<[4]x[4]xi8>`
// would have `4 * 4 * vscale * vscale` elements, and this quantity is
// impossible to represent as `N` or `N * vscale` (where `N` is a constant).
VectorType origVT = readOp.getVectorType();
ArrayRef<bool> origScalableDims = origVT.getScalableDims();
const int64_t origVRank = origVT.getRank();
if (origVRank < 2 || origVT.getNumScalableDims() != 1)
return rewriter.notifyMatchFailure(readOp, "wrong dimensions");

// Number of trailing dimensions to collapse, including the scalable
// dimension. Nothing to do if the single scalable dimension is already the
// last one.
const int64_t numCollapseDims = std::distance(
llvm::find(origScalableDims, true), origScalableDims.end());
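// E.g. for the `vector<2x[4]x2x8xi8>` example above the scalable flags are
// `[false, true, false, false]`, hence `numCollapseDims == 3` (the scalable
// dimension plus the two fixed dimensions to its right).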
if (numCollapseDims < 2)
return rewriter.notifyMatchFailure(readOp,
"scalable dimension is trailing");

// We want a simple memref (not a tensor) with contiguous elements for at
// least all the trailing dimensions up to and including the scalable one.
auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
return rewriter.notifyMatchFailure(
readOp, "non-contiguous memref dimensions to collapse");

// The dimensions to collapse (excluding the scalable one) of the vector and
// the memref must match. A dynamic memref dimension is considered
// non-matching. The transfers from the dimensions to collapse must be
in-bounds (it follows that the corresponding indices must be zero). This
// guarantees that the operation transfers a contiguous block.
if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
origVT.getShape().take_back(numCollapseDims - 1)))
return rewriter.notifyMatchFailure(
readOp, "memref and vector dimensions do not match");

SmallVector<bool> origInBounds = readOp.getInBoundsValues();
if (!llvm::all_of(
ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
[](bool v) { return v; }))
return rewriter.notifyMatchFailure(
readOp, "out-of-bounds transfer from a dimension to collapse");

// Collapse the trailing dimensions of the memref.
SmallVector<ReassociationIndices> reassoc;
for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
reassoc.push_back({i});
for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
++i)
reassoc.back().push_back(i);
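// E.g. for `memref<?x?x2x8xi8>` and `numCollapseDims == 3` this builds the
// reassociation `[[0], [1, 2, 3]]`, folding the last three dimensions into
// one.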
if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
return failure();
Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
readOp.getLoc(), readOp.getBase(), reassoc);

// Get a vector type with collapsed trailing dimensions.
SmallVector<int64_t> shape(origVT.getShape());
for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
shape[origVRank - numCollapseDims] *= shape[i];
shape.pop_back_n(numCollapseDims - 1);
auto collapsedVT =
VectorType::get(shape, origVT.getElementType(),
origScalableDims.drop_back(numCollapseDims - 1));
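// E.g. `vector<2x[4]x2x8xi8>` becomes `vector<2x[64]xi8>`: the trailing
// `[4]x2x8` (with the scalable flag on the first of these) folds into a
// single scalable dimension of size `4 * 2 * 8 = 64`.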

// Drop the extra (zero) indices.
auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);

// Create the new `transfer_read`.
auto newReadOp = rewriter.create<vector::TransferReadOp>(
readOp.getLoc(), collapsedVT, collapsedMem, indices,
ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));

// Cast back to the original vector type.
auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
origVT, newReadOp);

rewriter.replaceOp(readOp, toOrigShape);
return success();
}
};

} // namespace

void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
RewritePatternSet &patterns) {
-  patterns.add<RelaxScalableVectorAllocaAlignment,
-               LegalizeSVEMaskAllocation<memref::AllocaOp>,
-               LegalizeSVEMaskAllocation<memref::AllocOp>,
-               LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
-      patterns.getContext());
+  patterns
+      .add<RelaxScalableVectorAllocaAlignment,
+           LegalizeSVEMaskAllocation<memref::AllocaOp>,
+           LegalizeSVEMaskAllocation<memref::AllocOp>,
+           LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+           LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+      patterns.getContext());
}
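For context, this is roughly how the storage-legalization patterns are driven; a minimal sketch assuming the usual greedy-rewrite pass boilerplate (the `runOnOperation` wrapper below is illustrative, not part of this patch):

```cpp
// Sketch only: drive the legalization patterns to a fixed point with the
// greedy pattern rewriter. Assumes the standard MLIR pass scaffolding.
void runOnOperation() override {
  RewritePatternSet patterns(&getContext());
  arm_sve::populateLegalizeVectorStoragePatterns(patterns);
  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
    return signalPassFailure();
}
```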

namespace {