diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index a009aa03aaf64..e4d88de2cf4ae 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -26,7 +26,12 @@ using namespace mlir; +constexpr unsigned defaultTargetVectorBitWidth = + std::numeric_limits::max(); + static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) { + if (targetBitWidth == 0) + return false; auto resultTypes = op->getResultTypes(); for (auto resType : resultTypes) { VectorType vecType = dyn_cast(resType); @@ -82,7 +87,7 @@ struct LinearizeConstantLike final LinearizeConstantLike( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -136,7 +141,7 @@ struct LinearizeVectorizable final public: LinearizeVectorizable( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpTraitConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -175,7 +180,7 @@ struct LinearizeVectorExtractStridedSlice final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtractStridedSlice( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -289,7 +294,7 @@ struct LinearizeVectorShuffle final using OpConversionPattern::OpConversionPattern; LinearizeVectorShuffle( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -362,13 +367,17 @@ struct LinearizeVectorExtract final using OpConversionPattern::OpConversionPattern; LinearizeVectorExtract( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} LogicalResult matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + // Skip if result is not a vector type + if (!isa(extractOp.getType())) + return rewriter.notifyMatchFailure(extractOp, + "scalar extract is not supported."); Type dstTy = getTypeConverter()->convertType(extractOp.getType()); if (!dstTy) return rewriter.notifyMatchFailure(extractOp, @@ -425,7 +434,7 @@ struct LinearizeVectorInsert final using OpConversionPattern::OpConversionPattern; LinearizeVectorInsert( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -506,7 +515,7 @@ struct LinearizeVectorBitCast final using OpConversionPattern::OpConversionPattern; LinearizeVectorBitCast( const TypeConverter &typeConverter, MLIRContext *context, - unsigned targetVectBitWidth = std::numeric_limits::max(), + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit), targetVectorBitWidth(targetVectBitWidth) {} @@ -531,12 +540,139 @@ struct LinearizeVectorBitCast final unsigned targetVectorBitWidth; }; +// clang-format off +/// This pattern converts the LoadOp to a series of LoadOp & InsertOp +/// that works on a linearized vector. +/// Following, +/// vector.load %base[%indices] : vector<4x4xf32> +/// is converted to : +/// %result = arith.constant dense<0.0> : vector<4x4xf32> +/// %slice_0 = vector.load %base[%indices] : vector<4xf32> +/// %result_0 = vector.insert %slice_0, %result[0] : vector<4xf32> into vector<4x4xf32> +/// %slice_1 = vector.load %base[%indices + 1] : vector<4xf32> +/// %result_1 = vector.insert %slice_1, %result_0[1] : vector<4xf32> into vector<4x4xf32> +/// ... +/// This unrolls the 2D vector load into multiple 1D vector loads and inserts +/// them into the result vector. The pattern currently supports only 2D vectors +// clang-format on +struct LinearizeVectorLoad final : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + + LogicalResult + matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = loadOp->getLoc(); + VectorType vecType = loadOp.getVectorType(); + auto shape = vecType.getShape(); + + if (shape.size() != 2) + return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors."); + + auto unrollCount = shape[0]; + auto vecSize = shape[1]; + VectorType newVecType = + VectorType::get({vecSize}, vecType.getElementType()); + + llvm::SmallVector indices = adaptor.getIndices(); + Value xBaseIndex = indices[0]; + + // Construct the 2D vector. + Value resultVec = + rewriter.create(loc, rewriter.getZeroAttr(vecType)); + // Emit unrolled loads for each 1D vector slice. + for (auto i = 0; i < unrollCount; i++) { + Value xIndex = xBaseIndex; + if (i) { + auto increment = rewriter.create(loc, i); + xIndex = rewriter.create(loc, xBaseIndex, increment); + } + indices[0] = xIndex; + auto vec = rewriter.create(loc, newVecType, + adaptor.getBase(), indices); + resultVec = rewriter.create(loc, vec, resultVec, i); + } + + rewriter.replaceOp(loadOp, resultVec); + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; + +/// This pattern converts the StoreOp to a series of StoreOp & ExtractOp +/// that works on a linearized vector. +/// Following, +/// vector.store %source, %base[%indices] : vector<4x4xf32> +/// is converted to : +/// %slice_0 = vector.extract %source[0] : vector<4xf32> +/// vector.store %slice_0, %base[%indices] : vector<4xf32> +/// %slice_1 = vector.extract %source[1] : vector<4xf32> +/// vector.store %slice_1, %base[%indices + 1] : vector<4xf32> +/// ... +/// This unrolls the 2D vector store into multiple 1D vector stores by +/// extracting slices from the source vector and storing them into the +/// destination. The pattern currently supports only 2D vectors +struct LinearizeVectorStore final + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LinearizeVectorStore( + const TypeConverter &typeConverter, MLIRContext *context, + unsigned targetVectBitWidth = defaultTargetVectorBitWidth, + PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit), + targetVectorBitWidth(targetVectBitWidth) {} + + LogicalResult + matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = storeOp->getLoc(); + VectorType vecType = storeOp.getVectorType(); + auto shape = vecType.getShape(); + + if (shape.size() != 2) + return rewriter.notifyMatchFailure(loc, "Can only linearize 2D vectors."); + + auto unrollCount = shape[0]; + llvm::SmallVector indices = adaptor.getIndices(); + Value xBaseIndex = indices[0]; + + auto vec = rewriter.create(loc, vecType, + adaptor.getValueToStore()); + + for (auto i = 0; i < unrollCount; i++) { + auto vecSlice = rewriter.create(loc, vec, i); + Value xIndex = xBaseIndex; + if (i) { + auto increment = rewriter.create(loc, i); + xIndex = rewriter.create(loc, xBaseIndex, increment); + } + indices[0] = xIndex; + rewriter.create(loc, vecSlice, adaptor.getBase(), + indices); + } + rewriter.eraseOp(storeOp); + return success(); + } + +private: + unsigned targetVectorBitWidth; +}; + } // namespace void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, unsigned targetBitWidth) { + typeConverter.addConversion([](Type type) -> Type { return type; }); typeConverter.addConversion([](VectorType type) -> std::optional { if (!isLinearizableVector(type)) return type; @@ -555,9 +691,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( }; typeConverter.addSourceMaterialization(materializeCast); typeConverter.addTargetMaterialization(materializeCast); + target.addLegalOp(); target.markUnknownOpDynamicallyLegal( [=](Operation *op) -> std::optional { - if ((isa(op) || + if ((isa(op) || op->hasTrait() || op->hasTrait())) { return (isLessThanTargetBitWidth(op, targetBitWidth) @@ -567,9 +704,10 @@ void mlir::vector::populateVectorLinearizeTypeConversionsAndLegality( return std::nullopt; }); - patterns.add(typeConverter, patterns.getContext(), - targetBitWidth); + patterns + .add( + typeConverter, patterns.getContext(), targetBitWidth); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir index 9052c6440e6ac..9e793c5dc8233 100644 --- a/mlir/test/Dialect/Vector/linearize.mlir +++ b/mlir/test/Dialect/Vector/linearize.mlir @@ -399,3 +399,113 @@ func.func @test_vector_bitcast(%arg0: vector<[4]x2xf32>) -> vector<[4]x4xf16> { %1 = vector.bitcast %arg0 : vector<[4]x2xf32> to vector<[4]x4xf16> return %1 : vector<[4]x4xf16> } + +// ----- +// ALL-LABEL: linearize_vector_load +// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>) +func.func @linearize_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> { + // DEFAULT: %[[C1:.*]] = arith.constant 1 : index + // DEFAULT: %[[C2:.*]] = arith.constant 2 : index + // DEFAULT: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16> + // DEFAULT: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index + // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index + // DEFAULT: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index + // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index + // DEFAULT: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // DEFAULT: %[[C3:.*]] = arith.constant 3 : index + // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index + // DEFAULT: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16> + // DEFAULT: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16> + // DEFAULT: return %[[CAST]] : vector<4x4xf16> + + // BW-128: %[[C1:.*]] = arith.constant 1 : index + // BW-128: %[[C2:.*]] = arith.constant 2 : index + // BW-128: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<16xf16> + // BW-128: %[[LOAD0:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CST]], %[[LOAD0]] [16, 17, 18, 19, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // BW-128: %[[C1_0:.*]] = arith.constant 1 : index + // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index + // BW-128: %[[LOAD1:.*]] = vector.load %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[SHUFFLE0]], %[[LOAD1]] [0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // BW-128: %[[C2_1:.*]] = arith.constant 2 : index + // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index + // BW-128: %[[LOAD2:.*]] = vector.load %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[SHUFFLE1]], %[[LOAD2]] [0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15] : vector<16xf16>, vector<4xf16> + // BW-128: %[[C3:.*]] = arith.constant 3 : index + // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index + // BW-128: %[[LOAD3:.*]] = vector.load %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[SHUFFLE2]], %[[LOAD3]] [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 16, 17, 18, 19] : vector<16xf16>, vector<4xf16> + // BW-128: %[[CAST:.*]] = vector.shape_cast %[[SHUFFLE3]] : vector<16xf16> to vector<4x4xf16> + // BW-128: return %[[CAST]] : vector<4x4xf16> + + // BW-0: %[[C1:.*]] = arith.constant 1 : index + // BW-0: %[[C2:.*]] = arith.constant 2 : index + // BW-0: %[[LOAD:.*]] = vector.load %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16> + // BW-0: return %[[LOAD]] : vector<4x4xf16> + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %0 = vector.load %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16> + return %0 : vector<4x4xf16> +} + +// ----- +// ALL-LABEL: linearize_vector_store +// ALL-SAME: (%[[ARG_0:.*]]: memref<4x4xf16>, %[[ARG_1:.*]]: vector<4x4xf16>) { +func.func @linearize_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) { + // DEFAULT: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16> + // DEFAULT: %[[C1:.*]] = arith.constant 1 : index + // DEFAULT: %[[C2:.*]] = arith.constant 2 : index + // DEFAULT: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16> + // DEFAULT: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16> + // DEFAULT: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16> + // DEFAULT: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16> + // DEFAULT: %[[C1_0:.*]] = arith.constant 1 : index + // DEFAULT: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index + // DEFAULT: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16> + // DEFAULT: %[[C2_1:.*]] = arith.constant 2 : index + // DEFAULT: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index + // DEFAULT: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16> + // DEFAULT: %[[C3:.*]] = arith.constant 3 : index + // DEFAULT: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index + // DEFAULT: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // DEFAULT: return + + // BW-128: %[[CAST0:.*]] = vector.shape_cast %[[ARG_1]] : vector<4x4xf16> to vector<16xf16> + // BW-128: %[[C1:.*]] = arith.constant 1 : index + // BW-128: %[[C2:.*]] = arith.constant 2 : index + // BW-128: %[[CAST1:.*]] = vector.shape_cast %[[CAST0]] : vector<16xf16> to vector<4x4xf16> + // BW-128: %[[CAST2:.*]] = vector.shape_cast %[[CAST1]] : vector<4x4xf16> to vector<16xf16> + // BW-128: %[[SHUFFLE0:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [0, 1, 2, 3] : vector<16xf16>, vector<16xf16> + // BW-128: vector.store %[[SHUFFLE0]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE1:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [4, 5, 6, 7] : vector<16xf16>, vector<16xf16> + // BW-128: %[[C1_0:.*]] = arith.constant 1 : index + // BW-128: %[[ADD0:.*]] = arith.addi %[[C1]], %[[C1_0]] : index + // BW-128: vector.store %[[SHUFFLE1]], %[[ARG_0]][%[[ADD0]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE2:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [8, 9, 10, 11] : vector<16xf16>, vector<16xf16> + // BW-128: %[[C2_1:.*]] = arith.constant 2 : index + // BW-128: %[[ADD1:.*]] = arith.addi %[[C1]], %[[C2_1]] : index + // BW-128: vector.store %[[SHUFFLE2]], %[[ARG_0]][%[[ADD1]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: %[[SHUFFLE3:.*]] = vector.shuffle %[[CAST2]], %[[CAST2]] [12, 13, 14, 15] : vector<16xf16>, vector<16xf16> + // BW-128: %[[C3:.*]] = arith.constant 3 : index + // BW-128: %[[ADD2:.*]] = arith.addi %[[C1]], %[[C3]] : index + // BW-128: vector.store %[[SHUFFLE3]], %[[ARG_0]][%[[ADD2]], %[[C2]]] : memref<4x4xf16>, vector<4xf16> + // BW-128: return + + // BW-0: %[[C1:.*]] = arith.constant 1 : index + // BW-0: %[[C2:.*]] = arith.constant 2 : index + // BW-0: vector.store %[[ARG_1]], %[[ARG_0]][%[[C1]], %[[C2]]] : memref<4x4xf16>, vector<4x4xf16> + // BW-0: return + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + vector.store %arg1, %arg0[%c1, %c2] : memref<4x4xf16>, vector<4x4xf16> + return +} diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 03f907e46c2c6..14c7e9d554cd9 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -852,7 +852,8 @@ struct TestVectorLinearize final return "Linearizes ND vectors for N >= 2 into 1D vectors"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } Option targetVectorBitwidth{