diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 678a88627ca82..99d18fec18120 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -623,6 +623,66 @@ struct LinearizeVectorCreateMask final
   }
 };
 
+/// This pattern linearizes vector.load from vector<1x1x...xN> to vector<N>.
+/// It currently supports linearization where all but the last dimension are 1.
+/// The following,
+///   vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<1x4xf32>
+/// is converted to:
+///   %load_result = vector.load %arg0[%c0, %c0] : memref<1x4xf32>, vector<4xf32>
+///   vector.shape_cast %load_result : vector<4xf32> to vector<1x4xf32>
+struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
+                      PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = loadOp.getType();
+    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+                                [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "only vector<1x1x...xN> supported");
+    auto linearTy = getTypeConverter()->convertType(loadOp.getType());
+    auto newLoad = rewriter.create<vector::LoadOp>(
+        loadOp.getLoc(), linearTy, adaptor.getBase(), adaptor.getIndices());
+    rewriter.replaceOp(loadOp, newLoad.getResult());
+    return success();
+  }
+};
+
+/// This pattern linearizes vector.store from vector<1x1x...xN> to vector<N>.
+/// It currently supports linearization where all but the last dimension are 1.
+/// The following,
+///   vector.store %arg0, %arg1[%c0, %c0]
+///     : vector<1x4xf32>, memref<1x4xf32>
+/// is converted to:
+///   %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4xf32>
+///   vector.store %0, %arg1[%c0, %c0]
+///     : vector<4xf32>, memref<1x4xf32>
+struct LinearizeVectorStore final
+    : public OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
+                       PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType vecTy = storeOp.getValueToStore().getType();
+    if (!vecTy || !llvm::all_of(vecTy.getShape().drop_back(1),
+                                [](auto d) { return d == 1; }))
+      return rewriter.notifyMatchFailure(storeOp,
+                                         "only vector<1x1x...xN> supported");
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        storeOp, adaptor.getValueToStore(), adaptor.getBase(),
+        adaptor.getIndices());
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -714,8 +774,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
     RewritePatternSet &patterns) {
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
-           LinearizeVectorSplat, LinearizeVectorCreateMask>(
-          typeConverter, patterns.getContext());
+           LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
+           LinearizeVectorStore>(typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
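Note on how these patterns compose: LinearizeVectorLoad asks the pass's type converter for the flattened result type, and LinearizeVectorStore receives the already-converted (1-D) value through its adaptor. The vector.shape_cast ops that the tests below check for are therefore produced by the converter's materializations, not by the patterns themselves. The following is a minimal sketch of such a converter, assuming the includes already present in VectorLinearize.cpp; buildLinearizeTypeConverter is a hypothetical name, and upstream registers an equivalent rule elsewhere in this file:

// Sketch only: flatten any fixed-size n-D vector (n >= 2) to a 1-D vector
// with the same element count, and bridge old/new types with shape_cast.
static void buildLinearizeTypeConverter(TypeConverter &typeConverter) {
  // Fallback rule: leave every other type unchanged.
  typeConverter.addConversion([](Type type) { return type; });
  typeConverter.addConversion([](VectorType type) -> std::optional<Type> {
    if (type.getRank() <= 1 || type.isScalable())
      return std::nullopt; // Defer to the fallback rule.
    return VectorType::get({type.getNumElements()}, type.getElementType());
  });
  // vector.shape_cast materializations produce the casts the tests check for.
  auto materializeCast = [](OpBuilder &builder, Type type, ValueRange inputs,
                            Location loc) -> Value {
    if (inputs.size() != 1)
      return nullptr;
    return builder.create<vector::ShapeCastOp>(loc, type, inputs.front());
  };
  typeConverter.addSourceMaterialization(materializeCast);
  typeConverter.addTargetMaterialization(materializeCast);
}

The scalable-dimension handling needed by the existing create_mask support is omitted here for brevity.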
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 9cbf319ffddb2..9a017ceedcebe 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -464,3 +464,26 @@ func.func @linearize_scalable_create_mask(%arg0 : index, %arg1 : index) -> vecto
   %0 = vector.create_mask %arg0, %arg1 : vector<1x[16]xi1>
   return %0 : vector<1x[16]xi1>
 }
+
+// CHECK-LABEL: linearize_vector_load
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>) -> vector<1x4xf32>
+func.func @linearize_vector_load(%arg0: memref<2x8xf32>) -> vector<1x4xf32> {
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: %[[LOAD:.*]] = vector.load %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[LOAD]] : vector<4xf32> to vector<1x4xf32>
+  // CHECK: return %[[CAST]] : vector<1x4xf32>
+  %c0 = arith.constant 0 : index
+  %0 = vector.load %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
+  return %0 : vector<1x4xf32>
+}
+
+// CHECK-LABEL: linearize_vector_store
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2x8xf32>, %[[ARG1:.*]]: vector<1x4xf32>)
+func.func @linearize_vector_store(%arg0: memref<2x8xf32>, %arg1: vector<1x4xf32>) {
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[ARG1]] : vector<1x4xf32> to vector<4xf32>
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : index
+  // CHECK: vector.store %[[CAST]], %[[ARG0]][%[[CST0]], %[[CST0]]] : memref<2x8xf32>, vector<4xf32>
+  %c0 = arith.constant 0 : index
+  vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x4xf32>
+  return
+}
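For completeness, a sketch of how a conversion pass might drive the new patterns end to end. buildLinearizeTypeConverter is the hypothetical helper sketched above, and the leading parameters of populateVectorLinearizeBasePatterns (only partially visible in the hunk above) are assumed to match the upstream declaration:

static LogicalResult linearizeVectorLoadStore(Operation *root) {
  MLIRContext *context = root->getContext();
  TypeConverter typeConverter;
  buildLinearizeTypeConverter(typeConverter); // Hypothetical helper (above).

  ConversionTarget target(*context);
  // vector.load/store become legal once all their vector types are 1-D.
  target.addDynamicallyLegalOp<vector::LoadOp, vector::StoreOp>(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(context);
  vector::populateVectorLinearizeBasePatterns(typeConverter, target, patterns);
  return applyPartialConversion(root, target, std::move(patterns));
}

Because both patterns bail out with a match failure on anything other than vector<1x1x...xN>, partial conversion leaves unsupported shapes untouched; the memref<2x8xf32> cases above exercise the supported path.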