diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index bc08f8d07fb0d..6d50b0654bc57 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -478,6 +478,69 @@ def Vector_ShuffleOp :
   let hasCanonicalizer = 1;
 }
 
+def Vector_InterleaveOp :
+  Vector_Op<"interleave", [Pure,
+    AllTypesMatch<["lhs", "rhs"]>,
+    TypesMatchWith<
+    "type of 'result' is double the width of the inputs",
+    "lhs", "result",
+    [{
+      [&]() -> ::mlir::VectorType {
+        auto vectorType = ::llvm::cast<::mlir::VectorType>($_self);
+        ::mlir::VectorType::Builder builder(vectorType);
+        if (vectorType.getRank() == 0) {
+          static constexpr int64_t v2xty_shape[] = { 2 };
+          return builder.setShape(v2xty_shape);
+        }
+        auto lastDim = vectorType.getRank() - 1;
+        return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2);
+      }()
+    }]>]> {
+  let summary = "constructs a vector by interleaving two input vectors";
+  let description = [{
+    The interleave operation constructs a new vector by interleaving the
+    elements from the trailing (or final) dimension of two input vectors,
+    returning a new vector where the trailing dimension is twice the size.
+
+    Note that for the n-D case this differs from the interleaving possible with
+    `vector.shuffle`, which would only operate on the leading dimension.
+
+    Another key difference is that this operation supports scalable vectors,
+    though currently a general LLVM lowering is limited to the case where only
+    the trailing dimension is scalable.
+
+    Example:
+    ```mlir
+    %0 = vector.interleave %a, %b
+               : vector<[4]xi32>     ; yields vector<[8]xi32>
+    %1 = vector.interleave %c, %d
+               : vector<8xi8>        ; yields vector<16xi8>
+    %2 = vector.interleave %e, %f
+               : vector<f16>         ; yields vector<2xf16>
+    %3 = vector.interleave %g, %h
+               : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64>
+    %4 = vector.interleave %i, %j
+               : vector<6x3xf32>     ; yields vector<6x6xf32>
+    ```
+  }];
+
+  let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs);
+  let results = (outs AnyVector:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs attr-dict `:` type($lhs)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getSourceVectorType() {
+      return ::llvm::cast<VectorType>(getLhs().getType());
+    }
+    VectorType getResultVectorType() {
+      return ::llvm::cast<VectorType>(getResult().getType());
+    }
+  }];
+}
+
 def Vector_ExtractElementOp :
   Vector_Op<"extractelement", [Pure,
      TypesMatchWith<"result type matches element type of vector operand",
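To make the contrast drawn in the description concrete, here is a short illustrative sketch (not part of the patch; `%a` and `%b` are hypothetical `vector<2x4xf32>` values): a 2-D `vector.shuffle` can only pick whole rows, so it interleaves along the leading dimension, whereas `vector.interleave` zips the elements within each row.

```mlir
// Leading-dimension interleave, already expressible with vector.shuffle:
// the result rows are [a0, b0, a1, b1].
%rows = vector.shuffle %a, %b [0, 2, 1, 3]
    : vector<2x4xf32>, vector<2x4xf32>  // yields vector<4x4xf32>

// Trailing-dimension interleave, which requires vector.interleave:
// each result row is [a[i][0], b[i][0], a[i][1], b[i][1], ...].
%elems = vector.interleave %a, %b : vector<2x4xf32>  // yields vector<2x8xf32>
```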
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 57b39f5f52c6d..1cd3bab46396e 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -264,6 +264,14 @@ void populateVectorMaskLoweringPatternsForSideEffectingOps(
 void populateVectorMaskedLoadStoreEmulationPatterns(RewritePatternSet &patterns,
                                                     PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [InterleaveOpLowering]
+/// Progressive lowering of InterleaveOp to ExtractOp + InsertOp + lower-D
+/// InterleaveOp until dim 1.
+void populateVectorInterleaveLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 } // namespace vector
 } // namespace mlir
 #endif // MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index b66b55ae8d57f..0d9a451d11ca8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1734,6 +1734,44 @@ struct VectorSplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
   }
 };
 
+/// Conversion pattern for a `vector.interleave`.
+/// This supports fixed-size vectors and scalable vectors.
+struct VectorInterleaveOpLowering
+    : public ConvertOpToLLVMPattern<vector::InterleaveOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = interleaveOp.getResultVectorType();
+    // n-D interleaves should have been lowered already.
+    if (resultType.getRank() != 1)
+      return failure();
+    // If the result is rank 1, then this directly maps to LLVM.
+    if (resultType.isScalable()) {
+      rewriter.replaceOpWithNewOp<LLVM::experimental_vector_interleave2>(
+          interleaveOp, typeConverter->convertType(resultType),
+          adaptor.getLhs(), adaptor.getRhs());
+      return success();
+    }
+    // Lower fixed-size interleaves to a shufflevector. While the
+    // vector.interleave2 intrinsic supports fixed and scalable vectors, the
+    // LangRef still recommends that fixed vectors use shufflevector, see:
+    // https://llvm.org/docs/LangRef.html#id876.
+    int64_t resultVectorSize = resultType.getNumElements();
+    SmallVector<int32_t> interleaveShuffleMask;
+    interleaveShuffleMask.reserve(resultVectorSize);
+    for (int i = 0, end = resultVectorSize / 2; i < end; ++i) {
+      interleaveShuffleMask.push_back(i);
+      interleaveShuffleMask.push_back((resultVectorSize / 2) + i);
+    }
+    rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(
+        interleaveOp, adaptor.getLhs(), adaptor.getRhs(),
+        interleaveShuffleMask);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from Vector to LLVM.
@@ -1758,7 +1796,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
                VectorSplatOpLowering, VectorSplatNdOpLowering,
                VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
-               MaskedReductionOpConversion>(converter);
+               MaskedReductionOpConversion, VectorInterleaveOpLowering>(
+      converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index ff8e78a668e0f..e3a436c4a9400 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -68,6 +68,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
     populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions());
     populateVectorMaskOpLoweringPatterns(patterns);
     populateVectorShapeCastLoweringPatterns(patterns);
+    populateVectorInterleaveLoweringPatterns(patterns);
     populateVectorTransposeLoweringPatterns(patterns,
                                             VectorTransformsOptions());
     // Vector transfer ops with rank > 1 should be lowered with VectorToSCF.
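As a sketch of the two rank-1 lowering paths implemented above (illustrative only; the shapes and the `%lhs`/`%rhs`/`%slhs`/`%srhs` operands are assumptions, not taken from the patch's tests): for a fixed `vector<4xf32>` the mask loop produces `[0, 4, 1, 5, 2, 6, 3, 7]`, while the scalable case must use the intrinsic because `llvm.shufflevector` masks are compile-time constants.

```mlir
// Fixed-size: resultVectorSize = 8, so the mask zips element i with 4 + i.
%f = llvm.shufflevector %lhs, %rhs [0, 4, 1, 5, 2, 6, 3, 7] : vector<4xf32>

// Scalable: the runtime vector length rules out a constant mask, so the
// interleave2 intrinsic is emitted instead.
%s = "llvm.intr.experimental.vector.interleave2"(%slhs, %srhs)
    : (vector<[4]xf32>, vector<[4]xf32>) -> vector<[8]xf32>
```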
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 452354413e883..084348e68270c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2478,11 +2478,52 @@ class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
   }
 };
 
+/// Pattern to rewrite a fixed-size interleave via vector.shuffle to
+/// vector.interleave.
+class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ShuffleOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    if (resultType.isScalable())
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent a scalable interleave");
+
+    if (resultType.getRank() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp can't represent an n-D interleave");
+
+    VectorType sourceType = op.getV1VectorType();
+    if (sourceType != op.getV2VectorType() ||
+        ArrayRef<int64_t>{sourceType.getNumElements() * 2} !=
+            resultType.getShape()) {
+      return rewriter.notifyMatchFailure(
+          op, "ShuffleOp types don't match an interleave");
+    }
+
+    ArrayAttr shuffleMask = op.getMask();
+    int64_t resultVectorSize = resultType.getNumElements();
+    for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
+      int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
+      int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
+      if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
+        return rewriter.notifyMatchFailure(op,
+                                           "ShuffleOp mask not interleaving");
+    }
+
+    rewriter.replaceOpWithNewOp<InterleaveOp>(op, op.getV1(), op.getV2());
+    return success();
+  }
+};
+
 } // namespace
 
 void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ShuffleSplat>(context);
+  results.add<ShuffleSplat, ShuffleInterleave>(
+      context);
 }
 
 //===----------------------------------------------------------------------===//
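For reference, a sketch of what the mask check above accepts and rejects (illustrative values, not from the patch's tests; `%a` and `%b` are hypothetical `vector<4xi32>` values):

```mlir
// Rewritten to vector.interleave: the mask is the perfect interleave
// [0, n, 1, n+1, ...] with n = 4.
%zip = vector.shuffle %a, %b [0, 4, 1, 5, 2, 6, 3, 7]
    : vector<4xi32>, vector<4xi32>

// Left as-is: this mask is a concatenation, not an interleave, so the
// pattern fails with "ShuffleOp mask not interleaving".
%cat = vector.shuffle %a, %b [0, 1, 2, 3, 4, 5, 6, 7]
    : vector<4xi32>, vector<4xi32>
```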
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index daf28882976ef..f221b7462dfd7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorBroadcast.cpp
   LowerVectorContract.cpp
   LowerVectorGather.cpp
+  LowerVectorInterleave.cpp
   LowerVectorMask.cpp
   LowerVectorMultiReduction.cpp
   LowerVectorScan.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
new file mode 100644
index 0000000000000..0ca38eba942a5
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -0,0 +1,64 @@
+//===- LowerVectorInterleave.cpp - Lower 'vector.interleave' operation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.interleave' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/PatternMatch.h"
+
+#define DEBUG_TYPE "vector-interleave-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+namespace {
+/// Progressive lowering of InterleaveOp.
+class InterleaveOpLowering : public OpRewritePattern<vector::InterleaveOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InterleaveOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResultVectorType();
+    // 1-D vector.interleave ops can be directly lowered to LLVM (later).
+    if (resultType.getRank() == 1)
+      return failure();
+
+    // Below we unroll the leading (or front) dimension. If that dimension is
+    // scalable, we can't unroll it.
+    if (resultType.getScalableDims().front())
+      return failure();
+
+    // n-D case: Unroll the leading dimension.
+    auto loc = op.getLoc();
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resultType, rewriter.getZeroAttr(resultType));
+    for (int idx = 0, end = resultType.getDimSize(0); idx < end; ++idx) {
+      Value extractLhs = rewriter.create<ExtractOp>(loc, op.getLhs(), idx);
+      Value extractRhs = rewriter.create<ExtractOp>(loc, op.getRhs(), idx);
+      Value interleave =
+          rewriter.create<InterleaveOp>(loc, extractLhs, extractRhs);
+      result = rewriter.create<InsertOp>(loc, interleave, result, idx);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorInterleaveLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<InterleaveOpLowering>(patterns.getContext(), benefit);
+}
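To visualize the progressive lowering, here is a sketch of roughly what `InterleaveOpLowering` produces for a 2-D input, before the resulting 1-D interleaves are themselves converted to LLVM (illustrative output, not a checked-in test):

```mlir
// Before:
%0 = vector.interleave %lhs, %rhs : vector<2x4xf32>

// After unrolling the leading dimension:
%cst = arith.constant dense<0.0> : vector<2x8xf32>
%l0 = vector.extract %lhs[0] : vector<4xf32> from vector<2x4xf32>
%r0 = vector.extract %rhs[0] : vector<4xf32> from vector<2x4xf32>
%i0 = vector.interleave %l0, %r0 : vector<4xf32>
%t0 = vector.insert %i0, %cst [0] : vector<8xf32> into vector<2x8xf32>
%l1 = vector.extract %lhs[1] : vector<4xf32> from vector<2x4xf32>
%r1 = vector.extract %rhs[1] : vector<4xf32> from vector<2x4xf32>
%i1 = vector.interleave %l1, %r1 : vector<4xf32>
%1  = vector.insert %i1, %t0 [1] : vector<8xf32> into vector<2x8xf32>
```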
"llvm.intr.experimental.vector.interleave2"(%[[LHS]], %[[RHS]]) : (vector<[4]xi32>, vector<[4]xi32>) -> vector<[8]xi32> + // CHECK: return %[[ZIP]] + %0 = vector.interleave %a, %b : vector<[4]xi32> + return %0 : vector<[8]xi32> +} + +// ----- + +// CHECK-LABEL: @vector_interleave_2d +// CHECK-SAME: %[[LHS:.*]]: vector<2x3xi8>, %[[RHS:.*]]: vector<2x3xi8>) +func.func @vector_interleave_2d(%a: vector<2x3xi8>, %b: vector<2x3xi8>) -> vector<2x6xi8> +{ + // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>> + // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : vector<2x3xi8> to !llvm.array<2 x vector<3xi8>> + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x6xi8> + // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x6xi8> to !llvm.array<2 x vector<6xi8>> + // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>> + // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<3xi8>> + // CHECK: %[[ZIM_DIM_0:.*]] = llvm.shufflevector %[[LHS_DIM_0]], %[[RHS_DIM_0]] [0, 3, 1, 4, 2, 5] : vector<3xi8> + // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<6xi8>> + // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %[[LHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>> + // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %[[RHS_LLVM]][1] : !llvm.array<2 x vector<3xi8>> + // CHECK: %[[ZIM_DIM_1:.*]] = llvm.shufflevector %[[LHS_DIM_1]], %[[RHS_DIM_1]] [0, 3, 1, 4, 2, 5] : vector<3xi8> + // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIM_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<6xi8>> + // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<6xi8>> to vector<2x6xi8> + // CHECK: return %[[RES]] + %0 = vector.interleave %a, %b : vector<2x3xi8> + return %0 : vector<2x6xi8> +} + +// ----- + +// CHECK-LABEL: @vector_interleave_2d_scalable +// CHECK-SAME: %[[LHS:.*]]: vector<2x[8]xi16>, %[[RHS:.*]]: vector<2x[8]xi16>) +func.func @vector_interleave_2d_scalable(%a: vector<2x[8]xi16>, %b: vector<2x[8]xi16>) -> vector<2x[16]xi16> +{ + // CHECK: %[[LHS_LLVM:.*]] = builtin.unrealized_conversion_cast %arg0 : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[RHS_LLVM:.*]] = builtin.unrealized_conversion_cast %arg1 : vector<2x[8]xi16> to !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<2x[16]xi16> + // CHECK: %[[CST_LLVM:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<2x[16]xi16> to !llvm.array<2 x vector<[16]xi16>> + // CHECK: %[[LHS_DIM_0:.*]] = llvm.extractvalue %[[LHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[RHS_DIM_0:.*]] = llvm.extractvalue %[[RHS_LLVM]][0] : !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[ZIM_DIM_0:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_0]], %[[RHS_DIM_0]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16> + // CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[ZIM_DIM_0]], %[[CST_LLVM]][0] : !llvm.array<2 x vector<[16]xi16>> + // CHECK: %[[LHS_DIM_1:.*]] = llvm.extractvalue %0[1] : !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[RHS_DIM_1:.*]] = llvm.extractvalue %1[1] : !llvm.array<2 x vector<[8]xi16>> + // CHECK: %[[ZIP_DIM_1:.*]] = "llvm.intr.experimental.vector.interleave2"(%[[LHS_DIM_1]], %[[RHS_DIM_1]]) : (vector<[8]xi16>, vector<[8]xi16>) -> vector<[16]xi16> + // CHECK: %[[RES_1:.*]] = 
+  // CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[ZIP_DIM_1]], %[[RES_0]][1] : !llvm.array<2 x vector<[16]xi16>>
+  // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x vector<[16]xi16>> to vector<2x[16]xi16>
+  // CHECK: return %[[RES]]
+  %0 = vector.interleave %a, %b : vector<2x[8]xi16>
+  return %0 : vector<2x[16]xi16>
+}
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index e6f045e12e519..4c73a6271786e 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2567,3 +2567,26 @@ func.func @load_store_forwarding_rank_mismatch(%v0: vector<4x1x1xf32>, %arg0: te
       tensor<4x4x4xf32>, vector<1x100x4x5xf32>
   return %r : vector<1x100x4x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @rank_0_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<f64>, %[[RHS:.*]]: vector<f64>)
+func.func @rank_0_shuffle_to_interleave(%arg0: vector<f64>, %arg1: vector<f64>) -> vector<2xf64>
+{
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<f64>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 1] : vector<f64>, vector<f64>
+  return %0 : vector<2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @rank_1_shuffle_to_interleave(
+// CHECK-SAME: %[[LHS:.*]]: vector<6xi32>, %[[RHS:.*]]: vector<6xi32>)
+func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi32>) -> vector<12xi32> {
+  // CHECK: %[[ZIP:.*]] = vector.interleave %[[LHS]], %[[RHS]] : vector<6xi32>
+  // CHECK: return %[[ZIP]]
+  %0 = vector.shuffle %arg0, %arg1 [0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11] : vector<6xi32>, vector<6xi32>
+  return %0 : vector<12xi32>
+}
diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir
index 2f8530e7c171a..79a80be4f8b20 100644
--- a/mlir/test/Dialect/Vector/ops.mlir
+++ b/mlir/test/Dialect/Vector/ops.mlir
@@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 {
   %min = vector.reduction <minnumf>, %x fastmath<reassoc,nnan> : vector<42xf32> into f32
   return %min: f32
 }
+
+// CHECK-LABEL: @interleave_0d
+func.func @interleave_0d(%a: vector<f32>, %b: vector<f32>) -> vector<2xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<f32>
+  %0 = vector.interleave %a, %b : vector<f32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: @interleave_1d
+func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32>
+  %0 = vector.interleave %a, %b : vector<4xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: @interleave_1d_scalable
+func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16>
+  %0 = vector.interleave %a, %b : vector<[8]xi16>
+  return %0 : vector<[16]xi16>
+}
+
+// CHECK-LABEL: @interleave_2d
+func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32>
+  %0 = vector.interleave %a, %b : vector<2x8xf32>
+  return %0 : vector<2x16xf32>
+}
+
+// CHECK-LABEL: @interleave_2d_scalable
+func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> {
+  // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64>
+  %0 = vector.interleave %a, %b : vector<2x[2]xf64>
+  return %0 : vector<2x[4]xf64>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
new file mode 100644
index 0000000000000..479e50123bc2b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-scalable-interleave.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN:   %mcr_aarch64_cmd -e entry -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_c_runner_utils,%mlir_arm_runner_utils | \
+// RUN:   FileCheck %s
+
+func.func @entry() {
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %v1 = vector.splat %f1 : vector<[4]xf32>
+  %v2 = vector.splat %f2 : vector<[4]xf32>
+  vector.print %v1 : vector<[4]xf32>
+  vector.print %v2 : vector<[4]xf32>
+  //
+  // Test vectors:
+  //
+  // CHECK: ( 1, 1, 1, 1
+  // CHECK: ( 2, 2, 2, 2
+
+  %v3 = vector.interleave %v1, %v2 : vector<[4]xf32>
+  vector.print %v3 : vector<[8]xf32>
+  // CHECK: ( 1, 2, 1, 2, 1, 2, 1, 2
+
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
new file mode 100644
index 0000000000000..69bf0320a3697
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/test-interleave.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -test-lower-to-llvm | \
+// RUN:   mlir-cpu-runner -e entry -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_c_runner_utils | \
+// RUN:   FileCheck %s
+
+func.func @entry() {
+  %f1 = arith.constant 1.0 : f32
+  %f2 = arith.constant 2.0 : f32
+  %v1 = vector.splat %f1 : vector<2x4xf32>
+  %v2 = vector.splat %f2 : vector<2x4xf32>
+  vector.print %v1 : vector<2x4xf32>
+  vector.print %v2 : vector<2x4xf32>
+  //
+  // Test vectors:
+  //
+  // CHECK: ( ( 1, 1, 1, 1 ), ( 1, 1, 1, 1 ) )
+  // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) )
+
+  %v3 = vector.interleave %v1, %v2 : vector<2x4xf32>
+  vector.print %v3 : vector<2x8xf32>
+  // CHECK: ( ( 1, 2, 1, 2, 1, 2, 1, 2 ), ( 1, 2, 1, 2, 1, 2, 1, 2 ) )
+
+  return
+}