[mlir][Vector] Remove Vector{Load|Store}ToMemrefLoadLowering #121454

Merged · 1 commit · Jan 22, 2025
57 changes: 0 additions & 57 deletions mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -492,60 +492,6 @@ struct TransferReadToVectorLoadLowering
std::optional<unsigned> maxTransferRank;
};

/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
// TODO: we shouldn't cross the vector/scalar domains just for this
// but atm we lack the infra to avoid it. Possible solutions include:
// - go directly to LLVM + bitcast
// - introduce a bitcast op and likely a new pointer dialect
// - let memref.load/store additionally support the 0-d vector case
// There are still deeper data layout issues lingering even in this
// trivial case (for architectures for which this matters).
struct VectorLoadToMemrefLoadLowering
: public OpRewritePattern<vector::LoadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::LoadOp loadOp,
PatternRewriter &rewriter) const override {
auto vecType = loadOp.getVectorType();
if (vecType.getNumElements() != 1)
return rewriter.notifyMatchFailure(loadOp, "not a single element vector");

auto memrefLoad = rewriter.create<memref::LoadOp>(
loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
memrefLoad);
return success();
}
};

/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
struct VectorStoreToMemrefStoreLowering
: public OpRewritePattern<vector::StoreOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::StoreOp storeOp,
PatternRewriter &rewriter) const override {
auto vecType = storeOp.getVectorType();
if (vecType.getNumElements() != 1)
return rewriter.notifyMatchFailure(storeOp, "not single element vector");

Value extracted;
if (vecType.getRank() == 0) {
// TODO: Unifiy once ExtractOp supports 0-d vectors.
extracted = rewriter.create<vector::ExtractElementOp>(
storeOp.getLoc(), storeOp.getValueToStore());
} else {
SmallVector<int64_t> indices(vecType.getRank(), 0);
extracted = rewriter.create<vector::ExtractOp>(
storeOp.getLoc(), storeOp.getValueToStore(), indices);
}

rewriter.replaceOpWithNewOp<memref::StoreOp>(
storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
return success();
}
};

/// Progressive lowering of transfer_write. This pattern supports lowering of
/// `vector.transfer_write` to `vector.store` if all of the following hold:
/// - Stride of most minor memref dimension must be 1.
@@ -645,7 +591,4 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
patterns.add<TransferReadToVectorLoadLowering,
TransferWriteToVectorStoreLowering>(patterns.getContext(),
maxTransferRank, benefit);
patterns
.add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
patterns.getContext(), benefit);
}
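
For context, both removed patterns fired only on single-element vectors and bounced the value through the scalar domain. Below is a minimal before/after sketch of the 0-d case, with placeholder SSA names, pieced together from the updated test expectations in vector-transfer-to-vector-load-store.mlir further down; it is illustrative rather than exact pass output.

    // Before this change: the single-element vector.load/store produced by the
    // transfer lowering was further rewritten into scalar memref ops.
    %s = memref.load %mem[] : memref<f32>
    %v = vector.broadcast %s : f32 to vector<f32>
    %e = vector.extractelement %v[] : vector<f32>
    memref.store %e, %mem[] : memref<f32>

    // After this change: the IR stays in the vector domain.
    %v = vector.load %mem[] : memref<f32>, vector<f32>
    vector.store %v, %mem[] : memref<f32>, vector<f32>

The LLVM conversion then handles the vector<1xf32> case directly, emitting a plain llvm.load/llvm.store of the vector value, as the updated vector-to-llvm.mlir checks below show.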
14 changes: 8 additions & 6 deletions mlir/test/Conversion/GPUCommon/transfer_write.mlir
@@ -1,13 +1,15 @@
// RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s

func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) {
// CHECK-LABEL: @warp_extract
// CHECK-SAME: %[[VEC:[a-zA-Z0-9_]+]]: vector<1xf32>
// CHECK:%[[BASE:[0-9]+]] = llvm.extractvalue
// CHECK:%[[PTR:[0-9]+]] = llvm.getelementptr %[[BASE]]
// CHECK:llvm.store %[[VEC]], %[[PTR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr

func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: vector<1xf32>) {
%c0 = arith.constant 0 : index
gpu.warp_execute_on_lane_0(%arg0)[32] {
// CHECK:%[[val:[0-9]+]] = llvm.extractelement
// CHECK:%[[base:[0-9]+]] = llvm.extractvalue
// CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]]
// CHECK:llvm.store %[[val]], %[[ptr]]
vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
vector.transfer_write %arg2, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
}
return
}
35 changes: 23 additions & 12 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3282,13 +3282,17 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
}

// CHECK-LABEL: func @load_0d
// CHECK: %[[LOAD:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: %[[VEC:.*]] = llvm.mlir.undef : vector<1xf32>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[INSERTED:.*]] = llvm.insertelement %[[LOAD]], %[[VEC]][%[[C0]] : i32] : vector<1xf32>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[INSERTED]] : vector<1xf32> to vector<f32>
// CHECK: return %[[CAST]] : vector<f32>

// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: %[[LOAD:.*]] = llvm.load %[[ADDR]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[LOAD]] : vector<1xf32> to vector<f32>
// CHECK: return %[[RES]] : vector<f32>
// -----

//===----------------------------------------------------------------------===//
@@ -3382,11 +3386,18 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
}

// CHECK-LABEL: func @store_0d
// CHECK: %[[VAL:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<f32> to vector<1xf32>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
// CHECK: %[[EXTRACTED:.*]] = llvm.extractelement %[[CAST]][%[[C0]] : i64] : vector<1xf32>
// CHECK: memref.store %[[EXTRACTED]], %{{.*}}[%{{.*}}, %{{.*}}]
// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
// CHECK: llvm.store %[[VAL]], %[[ADDR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
// CHECK: return

// -----

24 changes: 10 additions & 14 deletions mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -6,16 +6,13 @@
func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
%f0 = arith.constant 0.0 : f32

// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
%0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>

// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>

// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>

return
@@ -191,8 +188,8 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
// CHECK-LABEL: func @transfer_broadcasting(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4xf32>
// CHECK-NEXT: }

@@ -208,8 +205,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector
// CHECK-LABEL: func @transfer_scalar(
// CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>, vector<1xf32>
// CHECK-NEXT: return %[[RES]] : vector<1xf32>
// CHECK-NEXT: }
func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
@@ -222,8 +218,8 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
// CHECK-LABEL: func @transfer_broadcasting_2D(
// CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1x1xf32>
// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1x1xf32> to vector<4x4xf32>
// CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
// CHECK-NEXT: }

@@ -322,8 +318,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
// CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>

%6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
// CHECK: vector.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<1xf32>
// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<8xf32>

return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,