[MLIR] Legalize certain vector.transfer_read ops of scalable vectors #143146
base: users/momchil-velikov/memref-contig-slice
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

This patch adds a transform of `transfer_read` operations to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.

Full diff: https://github.com/llvm/llvm-project/pull/143146.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..f16d33c004fec 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,6 +298,113 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
}
};
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+/// {in_bounds = [false, true, true, true]}
+/// : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+/// : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+/// {in_bounds = [false, true]}
+/// : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) const override {
+
+ if (!readOp.getPermutationMap().isMinorIdentity())
+ return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+ // We handle transfers of vectors with rank >= 2 and a single scalable
+ // dimension.
+ VectorType origVT = readOp.getVectorType();
+ ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+ const int64_t origVRank = origVT.getRank();
+ if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+ return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+ // Number of trailing dimensions to collapse, including the scalable
+ // dimension. Nothing to do if the single scalable dimension is already the
+ // last one.
+ const int64_t numCollapseDims = std::distance(
+ llvm::find(origScalableDims, true), origScalableDims.end());
+ if (numCollapseDims < 2)
+ return rewriter.notifyMatchFailure(readOp,
+ "scalable dimension is trailing");
+
+ // We want a simple memref (not a tensor) with contiguous elements for at
+ // least all the trailing dimensions up to and including the scalable one.
+ auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+ if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+ return rewriter.notifyMatchFailure(
+ readOp, "non-contiguous memref dimensions to collapse");
+
+ // The collapsed dimensions (excluding the scalable one) of the vector and
+ // the memref must match and the corresponding indices must be in-bounds (it
+ // follows these indices would be zero). This guarantees that the operation
+ // transfers a contiguous block.
+ if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+ origVT.getShape().take_back(numCollapseDims - 1)))
+ return rewriter.notifyMatchFailure(
+ readOp, "memref and vector dimensions do not match");
+
+ SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+ if (!llvm::all_of(
+ ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+ [](bool v) { return v; }))
+ return rewriter.notifyMatchFailure(readOp,
+ "out-if-bounds index to collapse");
+
+ // Collapse the trailing dimensions of the memref.
+ SmallVector<ReassociationIndices> reassoc;
+ for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+ reassoc.push_back({i});
+ for (int64_t i = memTy.getRank() - numCollapseDims + 1; i < memTy.getRank();
+ ++i)
+ reassoc.back().push_back(i);
+ if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+ return failure();
+ Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+ readOp.getLoc(), readOp.getBase(), reassoc);
+
+ // Get a vector type with collapsed trailing dimensions.
+ SmallVector<int64_t> shape(origVT.getShape());
+ for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+ shape[origVRank - numCollapseDims] *= shape[i];
+ shape.pop_back_n(numCollapseDims - 1);
+ auto collapsedVT =
+ VectorType::get(shape, origVT.getElementType(),
+ origScalableDims.drop_back(numCollapseDims - 1));
+
+ // Drop the extra (zero) indices.
+ auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+ // Create the new `transfer_read`.
+ auto newReadOp = rewriter.create<vector::TransferReadOp>(
+ readOp.getLoc(), collapsedVT, collapsedMem, indices,
+ ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+ // Cast back to the original vector type.
+ auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+ origVT, newReadOp);
+
+ rewriter.replaceOp(readOp, toOrigShape);
+ return success();
+ }
+};
+
} // namespace
void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
@@ -306,7 +413,8 @@ void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
LegalizeSVEMaskAllocation<memref::AllocaOp>,
LegalizeSVEMaskAllocation<memref::AllocOp>,
LegalizeSVEMaskTypeCastConversion,
- LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
+ LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion,
+ LegalizeTransferRead>(
patterns.getContext());
}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
new file mode 100644
index 0000000000000..d12a2c11bbdba
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -0,0 +1,226 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @test_base_case
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]:
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
+
+func.func @test_base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_using_strided_layout
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME: : memref<?x?x?x8xi8, strided<[?, ?, 8, 1]>> into
+// CHECK-SAME: memref<?x?x?xi8, strided<[?, ?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %c0], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?x?xi8, strided<[?, ?, 1]>>, vector<[32]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x8xi8>
+
+#s0 = strided<[?, ?, 8, 1]>
+
+func.func @test_using_strided_layout(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s0>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s0>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_3d_vector
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [true]}
+// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<[64]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<[4]x2x8xi8>
+
+#s1 = strided<[?, 16, 8, 1]>
+
+func.func @test_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s1>) -> vector<[4]x2x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x2x8xi8, #s1>, vector<[4]x2x8xi8>
+
+ return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @test_4d_vector
+// CHECK-SAME: %[[I:arg0]]: index, %[[J:arg1]]: index, %[[M:arg2]]
+// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME: : memref<?x?x2x8xi8, strided<[?, 16, 8, 1]>> into
+// CHECK-SAME: memref<?x?xi8, strided<[?, 1]>>
+// CHECK-NEXT: %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %c0_i8 {in_bounds = [false, true]}
+// CHECK-SAME: : memref<?x?xi8, strided<[?, 1]>>, vector<2x[64]xi8>
+// CHECK-NEXT: %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+// CHECK-NEXT: return %[[T1]] : vector<2x[4]x2x8xi8>
+
+#s2 = strided<[?, 16, 8, 1]>
+
+func.func @test_4d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8, #s2>) -> vector<2x[4]x2x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [false, true, true, true]} : memref<?x?x2x8xi8, #s2>, vector<2x[4]x2x8xi8>
+
+ return %A : vector<2x[4]x2x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_non_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_non_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+ return %A : vector<8x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_0
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_0(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+ return %A : vector<[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_legal_scalable_1
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_legal_scalable_1(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>
+
+ return %A : vector<8x[8]xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_vector_type_not_supported
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_vector_type_not_supported(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
+
+ return %A : vector<[8]x[8]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_non_mem
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_non_mem(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_0
+// CHECK-NOT: memref.collapse
+
+#s3 = strided<[?, ?, 16, 1]>
+
+func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #s3>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_mem_1
+// CHECK-NOT: memref.collapse
+
+#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>
+
+func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8, #layout>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_discontig_read_strided_vec
+// CHECK-NOT: memref.collapse
+
+func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
+
+ return %A : vector<[4]x4xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_test_bcast_transp
+// CHECK-NOT: memref.collapse
+
+#perm = affine_map<(i, j, k, p) -> (k, 0)>
+
+func.func @negative_test_bcast_transp(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+
+ %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {permutation_map = #perm, in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ return %A : vector<[4]x8xi8>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
new file mode 100644
index 0000000000000..7f68d8f7ab848
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-not-rightmost.mlir
@@ -0,0 +1,72 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE: --expand-strided-metadata --lower-affine --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE: -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func private @printVec(%v : vector<[32]xi8>) {
+ %v0 = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
+ %v1 = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
+ vector.print %v0 : vector<[16]xi8>
+ vector.print %v1 : vector<[16]xi8>
+ return
+}
+
+func.func @transfer_read_scalable_not_rightmost(%vs : i32, %M : memref<?x?x?x8xi8>) {
+ func.call @setArmVLBits(%vs) : (i32) -> ()
+
+ %c0 = arith.constant 0 : index
+ %c0_i8 = arith.constant 0 : i8
+ %A = vector.transfer_read %M[%c0, %c0, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+ %B = vector.shape_cast %A : vector<[4]x8xi8> to vector<[32]xi8>
+ func.call @printVec(%B) : (vector<[32]xi8>) -> ()
+
+ return
+}
+
+func.func @main() {
+
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c0_i32 = arith.constant 0 : i32
+ %c0_i8 = arith.constant 0 : i8
+
+ %A0_cst = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
+ [21, 22, 23, 24, 25, 26, 27, 28],
+ [31, 32, 33, 34, 35, 36, 37, 38],
+ [41, 42, 43, 44, 45, 46, 47, 48]]> : vector<4x8xi8>
+
+ %A1_cst = arith.constant dense<[[51, 52, 53, 54, 55, 56, 57, 58],
+ [61, 62, 63, 64, 65, 66, 67, 68],
+ [71, 72, 73, 74, 75, 76, 77, 78],
+ [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<4x8xi8>
+
+ %M = memref.alloca() : memref<1x2x4x8xi8>
+ vector.transfer_write %A0_cst, %M[%c0, %c0, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+ vector.transfer_write %A1_cst, %M[%c0, %c1, %c0, %c0] : vector<4x8xi8>, memref<1x2x4x8xi8>
+
+ %MM = memref.cast %M : memref<1x2x4x8xi8> to memref<?x?x?x8xi8>
+
+// CHECK:( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
+// CHECK:( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+ %c128 = arith.constant 128 : i32
+ func.call @transfer_read_scalable_not_rightmost(%c128, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
+ %c256 = arith.constant 256 : i32
+ func.call @transfer_read_scalable_not_rightmost(%c256, %MM) : (i32, memref<?x?x?x8xi8>) -> ()
+
+ return
+}
✅ With the latest revision this PR passed the C/C++ code formatter.
Great work, Momchil - thank you!
I've left a number of comments, but nothing major. My main high-level suggestion is to follow the guidance in MLIR's Testing Guide a bit more closely. It’s a relatively new (and long!) document, so I’ve included specific in-line suggestions to make it easier to see where things could align better.
For additional context, this RFC provides some of the rationale behind that approach.
Also - what about memrefs with dynamic dimensions?
    VectorType origVT = readOp.getVectorType();
    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
    const int64_t origVRank = origVT.getRank();
    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
[nit] getNumScalableDims would be more canonical than llvm::count
Done.
    if (!readOp.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
Would supporting non-identity be a problem? It would be good to add a comment, either:
- "TODO: We haven't required this, so leaving for later.", or
- "Too complex because of ..., disabling".
Any hint for future developers would be helpful.
Done.
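For reference, the reads this bail-out covers are those whose permutation map is not a minor identity, for example the broadcast/transpose map exercised by `@negative_test_bcast_transp` in this patch. A minimal sketch reusing that test's map (the SSA values `%M`, `%i`, `%j`, `%c0`, `%c0_i8` are assumed to be defined elsewhere):

```
// Not a minor identity map (broadcast + transpose), so LegalizeTransferRead
// bails out and leaves the read unchanged.
#perm = affine_map<(i, j, k, p) -> (k, 0)>

%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
       {permutation_map = #perm, in_bounds = [true, true]}
       : memref<?x?x?x8xi8>, vector<[4]x8xi8>
```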
    // We handle transfers of vectors with rank >= 2 and a single scalable
    // dimension.
[nit] It would be helpful to add why:
- Don't need to worry about 1D, that's supported by default.
- More than 1 scalable dims are tricky (how to collapse e.g. `vscale * vscale`?)
Comment added.
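To illustrate the second point, a vector type with two scalable dimensions, such as the one in the patch's `@negative_test_vector_type_not_supported` test, cannot be collapsed: the combined trailing extent would scale with `vscale * vscale`, which no single scalable dimension can express. A sketch (operands assumed to be defined elsewhere):

```
// Two scalable dimensions ([8] and [8]): collapsing the trailing dims would
// need a length that grows with vscale * vscale, so the pattern rejects it.
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
       {in_bounds = [true, true, true]}
       : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
```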
    // The collapsed dimensions (excluding the scalable one) of the vector and
    // the memref must match and the corresponding indices must be in-bounds (it
    // follows these indices would be zero). This guarantees that the operation
    // transfers a contiguous block.
// The collapsed dimensions (excluding the scalable one) of the vector and
// the memref must match
What about dynamic dim sizes in the memref? If that's not supported, is there a test?
This part wasn't tested at all. Test cases added.
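A sketch of the dynamic-size concern, hypothetical but in the spirit of the tests that were added: when a trailing memref dimension that would be collapsed is dynamic, it cannot be proven equal to the vector's static trailing size, so the pattern has to bail out (operands assumed to be defined elsewhere).

```
// Hypothetical negative case: the trailing memref dimension is dynamic, so
// it cannot be shown to match the vector's static trailing size (8) and the
// read is left unchanged.
%A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
       {in_bounds = [true, true]}
       : memref<?x?x?x?xi8>, vector<[4]x8xi8>
```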
            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
            [](bool v) { return v; }))
      return rewriter.notifyMatchFailure(readOp,
                                         "out-if-bounds index to collapse");
Note, it's not really index that's out-of-bounds, but the corresponding memory access. So, index could be in-bounds, but we might be reading "more" than there's available to read (starting at that index). For example:
vector.transfer_read %mem[5] : memref<7xi8>, vector<7xi8>
"out-if-bounds index to collapse"); | |
"out-of-bounds index to collapse"); |
Fixed.
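Applied to this pattern, a hypothetical sketch (not one of the tests above; `%k` is an arbitrary index and the other operands are assumed to be defined elsewhere): the trailing dimension to be collapsed is not marked in-bounds, so the access starting at `%k` may run past the 8-element row and the collapse is rejected.

```
// The trailing (to-be-collapsed) dimension is not in-bounds: reading 8
// elements starting at %k may go past the end of the row, so the pattern
// does not rewrite this transfer_read.
%A = vector.transfer_read %M[%i, %j, %c0, %k], %c0_i8
       {in_bounds = [true, false]}
       : memref<?x?x?x8xi8>, vector<[4]x8xi8>
```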
#s3 = strided<[?, ?, 16, 1]>

func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
[nit] Avoid "magic" suffixes like `_0`.

Suggested change:
-func.func @negative_test_discontig_mem_0(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
+func.func @negative_test_discont_mem_due_to_strides(%i : index, %j : index, %M : memref<?x?x?x8xi8, #s3>) -> vector<[4]x8xi8> {
Done.
#layout = affine_map<(i, j, k, p) -> (j, i, k, p)>

func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
[nit] Same as above.

Suggested change:
-func.func @negative_test_discontig_mem_1(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
+func.func @negative_test_discontig_mem_due_to_maps(%i : index, %j : index, %M : memref<?x?x?x8xi8, #layout>) -> vector<[4]x8xi8> {
Test removed, no need to test here all the possible ways a memref could be discontinuous.
func.func @negative_test_discontig_read_strided_vec(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>

  return %A : vector<[4]x4xi8>
}
What makes this a negative test? It says "strided vec", but I'm not sure what you mean?
That's garbage, deleted.
func.func @negative_test_vector_mask(
    %i : index, %j : index,
    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.mask %mask {
    vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>
  } : vector<[4]x8xi1> -> vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}

// -----

// CHECK-LABEL: @negative_test_mask_operand
// CHECK-NOT: memref.collapse

func.func @negative_test_mask_operand(
    %i : index, %j : index,
    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {

  %c0 = arith.constant 0 : index
  %c0_i8 = arith.constant 0 : i8

  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask {in_bounds = [true, true] } : memref<?x?x?x8xi8>, vector<[4]x8xi8>

  return %A : vector<[4]x8xi8>
}
In the past, I would differentiate these as:
- "masked" (`vector.mask { vector.transfer_read }`), vs
- "with_mask" (`vector.transfer_read %mask`)
Would you mind following similar convention?
Done.
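For clarity, the two forms being distinguished, condensed from the tests quoted above:

```
// "masked": the transfer_read is wrapped in a vector.mask region.
%a = vector.mask %mask {
  vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8 {in_bounds = [true, true]}
    : memref<?x?x?x8xi8>, vector<[4]x8xi8>
} : vector<[4]x8xi1> -> vector<[4]x8xi8>

// "with_mask": the mask is passed directly as an operand of transfer_read.
%b = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8, %mask
       {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
```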
This is mixing fixed-width and scalable vectors. Let's avoid that until we understand better how to mix VLA + VLS programming.
This patch adds a transform of `transfer_read` operations to change the vector type to one that can be mapped to an LLVM type. This is done by collapsing trailing dimensions so we obtain a vector type with a single scalable dimension in the rightmost position.