diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index bc08f8d07fb0d..6d50b0654bc57 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -478,6 +478,69 @@ def Vector_ShuffleOp : let hasCanonicalizer = 1; } +def Vector_InterleaveOp : + Vector_Op<"interleave", [Pure, + AllTypesMatch<["lhs", "rhs"]>, + TypesMatchWith< + "type of 'result' is double the width of the inputs", + "lhs", "result", + [{ + [&]() -> ::mlir::VectorType { + auto vectorType = ::llvm::cast($_self); + ::mlir::VectorType::Builder builder(vectorType); + if (vectorType.getRank() == 0) { + static constexpr int64_t v2xty_shape[] = { 2 }; + return builder.setShape(v2xty_shape); + } + auto lastDim = vectorType.getRank() - 1; + return builder.setDim(lastDim, vectorType.getDimSize(lastDim) * 2); + }() + }]>]> { + let summary = "constructs a vector by interleaving two input vectors"; + let description = [{ + The interleave operation constructs a new vector by interleaving the + elements from the trailing (or final) dimension of two input vectors, + returning a new vector where the trailing dimension is twice the size. + + Note that for the n-D case this differs from the interleaving possible with + `vector.shuffle`, which would only operate on the leading dimension. + + Another key difference is this operation supports scalable vectors, though + currently a general LLVM lowering is limited to the case where only the + trailing dimension is scalable. + + Example: + ```mlir + %0 = vector.interleave %a, %b + : vector<[4]xi32> ; yields vector<[8]xi32> + %1 = vector.interleave %c, %d + : vector<8xi8> ; yields vector<16xi8> + %2 = vector.interleave %e, %f + : vector ; yields vector<2xf16> + %3 = vector.interleave %g, %h + : vector<2x4x[2]xf64> ; yields vector<2x4x[4]xf64> + %4 = vector.interleave %i, %j + : vector<6x3xf32> ; yields vector<6x6xf32> + ``` + }]; + + let arguments = (ins AnyVectorOfAnyRank:$lhs, AnyVectorOfAnyRank:$rhs); + let results = (outs AnyVector:$result); + + let assemblyFormat = [{ + $lhs `,` $rhs attr-dict `:` type($lhs) + }]; + + let extraClassDeclaration = [{ + VectorType getSourceVectorType() { + return ::llvm::cast(getLhs().getType()); + } + VectorType getResultVectorType() { + return ::llvm::cast(getResult().getType()); + } + }]; +} + def Vector_ExtractElementOp : Vector_Op<"extractelement", [Pure, TypesMatchWith<"result type matches element type of vector operand", diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 2f8530e7c171a..79a80be4f8b20 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1081,3 +1081,38 @@ func.func @fastmath(%x: vector<42xf32>) -> f32 { %min = vector.reduction , %x fastmath : vector<42xf32> into f32 return %min: f32 } + +// CHECK-LABEL: @interleave_0d +func.func @interleave_0d(%a: vector, %b: vector) -> vector<2xf32> { + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector + %0 = vector.interleave %a, %b : vector + return %0 : vector<2xf32> +} + +// CHECK-LABEL: @interleave_1d +func.func @interleave_1d(%a: vector<4xf32>, %b: vector<4xf32>) -> vector<8xf32> { + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<4xf32> + %0 = vector.interleave %a, %b : vector<4xf32> + return %0 : vector<8xf32> +} + +// CHECK-LABEL: @interleave_1d_scalable +func.func @interleave_1d_scalable(%a: vector<[8]xi16>, %b: vector<[8]xi16>) -> vector<[16]xi16> { + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<[8]xi16> + %0 = vector.interleave %a, %b : vector<[8]xi16> + return %0 : vector<[16]xi16> +} + +// CHECK-LABEL: @interleave_2d +func.func @interleave_2d(%a: vector<2x8xf32>, %b: vector<2x8xf32>) -> vector<2x16xf32> { + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x8xf32> + %0 = vector.interleave %a, %b : vector<2x8xf32> + return %0 : vector<2x16xf32> +} + +// CHECK-LABEL: @interleave_2d_scalable +func.func @interleave_2d_scalable(%a: vector<2x[2]xf64>, %b: vector<2x[2]xf64>) -> vector<2x[4]xf64> { + // CHECK: vector.interleave %{{.*}}, %{{.*}} : vector<2x[2]xf64> + %0 = vector.interleave %a, %b : vector<2x[2]xf64> + return %0 : vector<2x[4]xf64> +}