diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index f9428a4ce28640..156e33257d8716 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -492,60 +492,6 @@ struct TransferReadToVectorLoadLowering
   std::optional<unsigned> maxTransferRank;
 };
 
-/// Replace a 0-d vector.load with a memref.load + vector.broadcast.
-// TODO: we shouldn't cross the vector/scalar domains just for this
-// but atm we lack the infra to avoid it. Possible solutions include:
-// - go directly to LLVM + bitcast
-// - introduce a bitcast op and likely a new pointer dialect
-// - let memref.load/store additionally support the 0-d vector case
-// There are still deeper data layout issues lingering even in this
-// trivial case (for architectures for which this matters).
-struct VectorLoadToMemrefLoadLowering
-    : public OpRewritePattern<vector::LoadOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = loadOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(loadOp, "not a single element vector");
-
-    auto memrefLoad = rewriter.create<memref::LoadOp>(
-        loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices());
-    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(loadOp, vecType,
-                                                     memrefLoad);
-    return success();
-  }
-};
-
-/// Replace a 0-d vector.store with a vector.extractelement + memref.store.
-struct VectorStoreToMemrefStoreLowering
-    : public OpRewritePattern<vector::StoreOp> {
-  using OpRewritePattern::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
-                                PatternRewriter &rewriter) const override {
-    auto vecType = storeOp.getVectorType();
-    if (vecType.getNumElements() != 1)
-      return rewriter.notifyMatchFailure(storeOp, "not single element vector");
-
-    Value extracted;
-    if (vecType.getRank() == 0) {
-      // TODO: Unifiy once ExtractOp supports 0-d vectors.
-      extracted = rewriter.create<vector::ExtractElementOp>(
-          storeOp.getLoc(), storeOp.getValueToStore());
-    } else {
-      SmallVector<int64_t> indices(vecType.getRank(), 0);
-      extracted = rewriter.create<vector::ExtractOp>(
-          storeOp.getLoc(), storeOp.getValueToStore(), indices);
-    }
-
-    rewriter.replaceOpWithNewOp<memref::StoreOp>(
-        storeOp, extracted, storeOp.getBase(), storeOp.getIndices());
-    return success();
-  }
-};
-
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -645,7 +591,4 @@ void mlir::vector::populateVectorTransferLoweringPatterns(
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank, benefit);
-  patterns
-      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
-          patterns.getContext(), benefit);
 }
diff --git a/mlir/test/Conversion/GPUCommon/transfer_write.mlir b/mlir/test/Conversion/GPUCommon/transfer_write.mlir
index 2242786fe67595..4d2ae8c39240c5 100644
--- a/mlir/test/Conversion/GPUCommon/transfer_write.mlir
+++ b/mlir/test/Conversion/GPUCommon/transfer_write.mlir
@@ -1,13 +1,15 @@
 // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s
 
-  func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: index, %arg3: vector<1xf32>) {
+// CHECK-LABEL: @warp_extract
+// CHECK-SAME: %[[VEC:[a-zA-Z0-9_]+]]: vector<1xf32>
+// CHECK:%[[BASE:[0-9]+]] = llvm.extractvalue
+// CHECK:%[[PTR:[0-9]+]] = llvm.getelementptr %[[BASE]]
+// CHECK:llvm.store %[[VEC]], %[[PTR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
+
+func.func @warp_extract(%arg0: index, %arg1: memref<1024x1024xf32>, %arg2: vector<1xf32>) {
   %c0 = arith.constant 0 : index
   gpu.warp_execute_on_lane_0(%arg0)[32] {
-    // CHECK:%[[val:[0-9]+]] = llvm.extractelement
-    // CHECK:%[[base:[0-9]+]] = llvm.extractvalue
-    // CHECK:%[[ptr:[0-9]+]] = llvm.getelementptr %[[base]]
-    // CHECK:llvm.store %[[val]], %[[ptr]]
-    vector.transfer_write %arg3, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
+    vector.transfer_write %arg2, %arg1[%c0, %c0] {in_bounds = [true]} : vector<1xf32>, memref<1024x1024xf32>
   }
   return
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f95e943250bd44..218431d76e96df 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -3282,13 +3282,17 @@ func.func @load_0d(%memref : memref<200x100xf32>, %i : index, %j : index) -> vec
 }
 
 // CHECK-LABEL: func @load_0d
-// CHECK: %[[LOAD:.*]] = memref.load %{{.*}}[%{{.*}}, %{{.*}}]
-// CHECK: %[[VEC:.*]] = llvm.mlir.undef : vector<1xf32>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
-// CHECK: %[[INSERTED:.*]] = llvm.insertelement %[[LOAD]], %[[VEC]][%[[C0]] : i32] : vector<1xf32>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[INSERTED]] : vector<1xf32> to vector<f32>
-// CHECK: return %[[CAST]] : vector<f32>
-
+// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
+// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: %[[LOAD:.*]] = llvm.load %[[ADDR]] {alignment = 4 : i64} : !llvm.ptr -> vector<1xf32>
+// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[LOAD]] : vector<1xf32> to vector<f32>
+// CHECK: return %[[RES]] : vector<f32>
 
 // -----
 
 //===----------------------------------------------------------------------===//
@@ -3382,11 +3386,18 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
 }
 
 // CHECK-LABEL: func @store_0d
-// CHECK: %[[VAL:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
-// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<f32> to vector<1xf32>
-// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK: %[[EXTRACTED:.*]] = llvm.extractelement %[[CAST]][%[[C0]] : i64] : vector<1xf32>
-// CHECK: memref.store %[[EXTRACTED]], %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %{{.*}} : index to i64
+// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
+// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
+// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
+// CHECK: %[[ADDR:.*]] = llvm.getelementptr %[[REF]][%[[ADD]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK: llvm.store %[[VAL]], %[[ADDR]] {alignment = 4 : i64} : vector<1xf32>, !llvm.ptr
+// CHECK: return
 
 // -----
 
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
index f90111b4c88618..fd50acf03e79b1 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -6,16 +6,13 @@
 func.func @vector_transfer_ops_0d_memref(%mem: memref<f32>, %vec: vector<1x1x1xf32>) {
   %f0 = arith.constant 0.0 : f32
 
-// CHECK-NEXT: %[[S:.*]] = memref.load %[[MEM]][] : memref<f32>
-// CHECK-NEXT: %[[V:.*]] = vector.broadcast %[[S]] : f32 to vector<f32>
+// CHECK-NEXT: %[[S:.*]] = vector.load %[[MEM]][] : memref<f32>, vector<f32>
   %0 = vector.transfer_read %mem[], %f0 : memref<f32>, vector<f32>
 
-// CHECK-NEXT: %[[SS:.*]] = vector.extractelement %[[V]][] : vector<f32>
-// CHECK-NEXT: memref.store %[[SS]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: vector.store %[[S]], %[[MEM]][] : memref<f32>, vector<f32>
   vector.transfer_write %0, %mem[] : vector<f32>, memref<f32>
 
-// CHECK-NEXT: %[[VV:.*]] = vector.extract %arg1[0, 0, 0] : f32 from vector<1x1x1xf32>
-// CHECK-NEXT: memref.store %[[VV]], %[[MEM]][] : memref<f32>
+// CHECK-NEXT: vector.store %[[VEC]], %[[MEM]][] : memref<f32>, vector<1x1x1xf32>
   vector.store %vec, %mem[] : memref<f32>, vector<1x1x1xf32>
 
   return
@@ -191,8 +188,8 @@ func.func @transfer_perm_map(%mem : memref<8x8xf32>, %idx : index) -> vector<4xf
 
 // CHECK-LABEL: func @transfer_broadcasting(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1xf32> to vector<4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4xf32>
 // CHECK-NEXT: }
 
@@ -208,8 +205,7 @@ func.func @transfer_broadcasting(%mem : memref<8x8xf32>, %idx : index) -> vector
 
 // CHECK-LABEL: func @transfer_scalar(
 // CHECK-SAME: %[[MEM:.*]]: memref<?x?xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<1xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<?x?xf32>, vector<1xf32>
 // CHECK-NEXT: return %[[RES]] : vector<1xf32>
 // CHECK-NEXT: }
 func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32> {
@@ -222,8 +218,8 @@ func.func @transfer_scalar(%mem : memref<?x?xf32>, %idx : index) -> vector<1xf32
 
 // CHECK-LABEL: func @transfer_broadcasting_2D(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,
 // CHECK-SAME: %[[IDX:.*]]: index) -> vector<4x4xf32> {
-// CHECK-NEXT: %[[LOAD:.*]] = memref.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>
-// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : f32 to vector<4x4xf32>
+// CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[MEM]][%[[IDX]], %[[IDX]]] : memref<8x8xf32>, vector<1x1xf32>
+// CHECK-NEXT: %[[RES:.*]] = vector.broadcast %[[LOAD]] : vector<1x1xf32> to vector<4x4xf32>
 // CHECK-NEXT: return %[[RES]] : vector<4x4xf32>
 // CHECK-NEXT: }
 
@@ -322,8 +318,8 @@ func.func @transfer_read_permutations(%mem_0 : memref<?x?xf32>, %mem_1 : memref<
 // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32>
 
   %6 = vector.transfer_read %mem_0[%c0, %c0], %cst {in_bounds = [true], permutation_map = #map6} : memref<?x?xf32>, vector<8xf32>
-// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>
-// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32>
+// CHECK: vector.load %{{.*}}[%[[C0]], %[[C0]]] : memref<?x?xf32>, vector<1xf32>
+// CHECK: vector.broadcast %{{.*}} : vector<1xf32> to vector<8xf32>
 
   return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
     vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>,
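For reference, and not part of the patch: a minimal MLIR sketch of the 0-d behavior the updated vector-transfer-to-vector-load-store.mlir checks describe. With VectorLoadToMemrefLoadLowering and VectorStoreToMemrefStoreLowering removed, populateVectorTransferLoweringPatterns keeps single-element transfers in the vector domain, producing vector.load/vector.store rather than memref.load + vector.broadcast and vector.extractelement + memref.store. The function and value names below are illustrative only.

  // Input (illustrative function and value names):
  func.func @read_write_0d(%mem: memref<f32>, %pad: f32) {
    %v = vector.transfer_read %mem[], %pad : memref<f32>, vector<f32>
    vector.transfer_write %v, %mem[] : vector<f32>, memref<f32>
    return
  }
  // Expected lowering with these patterns (previously memref.load +
  // vector.broadcast and vector.extractelement + memref.store):
  //   %v = vector.load %mem[] : memref<f32>, vector<f32>
  //   vector.store %v, %mem[] : memref<f32>, vector<f32>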