diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 56d866ac5b40c..c30996351c672 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -720,10 +720,9 @@ def Vector_ExtractOp : return getStaticPosition().size(); } + /// Return "true" if the op has at least one dynamic position. bool hasDynamicPosition() { - auto dynPos = getDynamicPosition(); - return std::any_of(dynPos.begin(), dynPos.end(), - [](Value operand) { return operand != nullptr; }); + return !getDynamicPosition().empty(); } }]; @@ -769,6 +768,41 @@ def Vector_FMAOp : }]; } +def Vector_FromElementsOp : Vector_Op<"from_elements", [ + Pure, + TypesMatchWith<"operand types match result element type", + "result", "elements", "SmallVector(" + "::llvm::cast($_self).getNumElements(), " + "::llvm::cast($_self).getElementType())">]> { + let summary = "operation that defines a vector from scalar elements"; + let description = [{ + This operation defines a vector from one or multiple scalar elements. The + number of elements must match the number of elements in the result type. + All elements must have the same type, which must match the element type of + the result vector type. + + `elements` are a flattened version of the result vector in row-major order. + + Example: + + ```mlir + // %f1 + %0 = vector.from_elements %f1 : vector + // [%f1, %f2] + %1 = vector.from_elements %f1, %f2 : vector<2xf32> + // [[%f1, %f2, %f3], [%f4, %f5, %f6]] + %2 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<2x3xf32> + // [[[%f1, %f2]], [[%f3, %f4]], [[%f5, %f6]]] + %3 = vector.from_elements %f1, %f2, %f3, %f4, %f5, %f6 : vector<3x1x2xf32> + ``` + }]; + + let arguments = (ins Variadic:$elements); + let results = (outs AnyVectorOfAnyRank:$result); + let assemblyFormat = "$elements attr-dict `:` type($result)"; + let hasCanonicalizer = 1; +} + def Vector_InsertElementOp : Vector_Op<"insertelement", [Pure, TypesMatchWith<"source operand type matches element type of result", diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 60f7e95ade689..0eac55255b133 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1836,6 +1836,30 @@ struct VectorDeinterleaveOpLowering } }; +/// Conversion pattern for a `vector.from_elements`. +struct VectorFromElementsLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = fromElementsOp.getLoc(); + VectorType vectorType = fromElementsOp.getType(); + // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>. + // Such ops should be handled in the same way as vector.insert. + if (vectorType.getRank() > 1) + return rewriter.notifyMatchFailure(fromElementsOp, + "rank > 1 vectors are not supported"); + Type llvmType = typeConverter->convertType(vectorType); + Value result = rewriter.create(loc, llvmType); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) + result = rewriter.create(loc, val, result, idx); + rewriter.replaceOp(fromElementsOp, result); + return success(); + } +}; + } // namespace /// Populate the given list with patterns that convert from Vector to LLVM. @@ -1861,7 +1885,8 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorSplatOpLowering, VectorSplatNdOpLowering, VectorScalableInsertOpLowering, VectorScalableExtractOpLowering, MaskedReductionOpConversion, VectorInterleaveOpLowering, - VectorDeinterleaveOpLowering>(converter); + VectorDeinterleaveOpLowering, VectorFromElementsLowering>( + converter); // Transfer ops with rank > 1 are handled by VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 2bf4f16f96e6a..89805d90ea1b0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1877,6 +1877,45 @@ static Value foldExtractStridedOpFromInsertChain(ExtractOp extractOp) { return Value(); } +/// Try to fold the extraction of a scalar from a vector defined by +/// vector.from_elements. E.g.: +/// +/// %0 = vector.from_elements %a, %b : vector<2xf32> +/// %1 = vector.extract %0[0] : f32 from vector<2xf32> +/// ==> fold to %a +static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { + // Dynamic extractions cannot be folded. + if (extractOp.hasDynamicPosition()) + return {}; + + // Look for extract(from_elements). + auto fromElementsOp = extractOp.getVector().getDefiningOp(); + if (!fromElementsOp) + return {}; + + // Scalable vectors are not supported. + auto vecType = llvm::cast(fromElementsOp.getType()); + if (vecType.isScalable()) + return {}; + + // Only extractions of scalars are supported. + int64_t rank = vecType.getRank(); + ArrayRef indices = extractOp.getStaticPosition(); + if (extractOp.getType() != vecType.getElementType()) + return {}; + assert(static_cast(indices.size()) == rank && + "unexpected number of indices"); + + // Compute flattened/linearized index and fold to operand. + int flatIndex = 0; + int stride = 1; + for (int i = rank - 1; i >= 0; --i) { + flatIndex += indices[i] * stride; + stride *= vecType.getDimSize(i); + } + return fromElementsOp.getElements()[flatIndex]; +} + OpFoldResult ExtractOp::fold(FoldAdaptor) { // Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. // Note: Do not fold "vector.extract %v[] : f32 from vector" (type @@ -1895,6 +1934,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) { return val; if (auto val = foldExtractStridedOpFromInsertChain(*this)) return val; + if (auto val = foldScalarExtractFromFromElements(*this)) + return val; return OpFoldResult(); } @@ -2099,6 +2140,52 @@ LogicalResult foldExtractFromShapeCastToShapeCast(ExtractOp extractOp, return success(); } +/// Try to canonicalize the extraction of a subvector from a vector defined by +/// vector.from_elements. E.g.: +/// +/// %0 = vector.from_elements %a, %b, %a, %a : vector<2x2xf32> +/// %1 = vector.extract %0[0] : vector<2xf32> from vector<2x2xf32> +/// ==> canonicalize to vector.from_elements %a, %b : vector<2xf32> +LogicalResult foldExtractFromFromElements(ExtractOp extractOp, + PatternRewriter &rewriter) { + // Dynamic positions are not supported. + if (extractOp.hasDynamicPosition()) + return failure(); + + // Scalar extracts are handled by the folder. + auto resultType = dyn_cast(extractOp.getType()); + if (!resultType) + return failure(); + + // Look for extracts from a from_elements op. + auto fromElementsOp = extractOp.getVector().getDefiningOp(); + if (!fromElementsOp) + return failure(); + VectorType inputType = fromElementsOp.getType(); + + // Scalable vectors are not supported. + if (resultType.isScalable() || inputType.isScalable()) + return failure(); + + // Compute the position of first extracted element and flatten/linearize the + // position. + SmallVector firstElementPos = + llvm::to_vector(extractOp.getStaticPosition()); + firstElementPos.append(/*NumInputs=*/resultType.getRank(), /*Elt=*/0); + int flatIndex = 0; + int stride = 1; + for (int64_t i = inputType.getRank() - 1; i >= 0; --i) { + flatIndex += firstElementPos[i] * stride; + stride *= inputType.getDimSize(i); + } + + // Replace the op with a smaller from_elements op. + rewriter.replaceOpWithNewOp( + extractOp, resultType, + fromElementsOp.getElements().slice(flatIndex, + resultType.getNumElements())); + return success(); +} } // namespace void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -2106,6 +2193,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, results.add(context); results.add(foldExtractFromShapeCastToShapeCast); + results.add(foldExtractFromFromElements); } static void populateFromInt64AttrArray(ArrayAttr arrayAttr, @@ -2122,6 +2210,29 @@ std::optional> FMAOp::getShapeForUnroll() { return llvm::to_vector<4>(getVectorType().getShape()); } +//===----------------------------------------------------------------------===// +// FromElementsOp +//===----------------------------------------------------------------------===// + +/// Rewrite a vector.from_elements into a vector.splat if all elements are the +/// same SSA value. E.g.: +/// +/// %0 = vector.from_elements %a, %a, %a : vector<3xf32> +/// ==> rewrite to vector.splat %a : vector<3xf32> +static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, + PatternRewriter &rewriter) { + if (!llvm::all_equal(fromElementsOp.getElements())) + return failure(); + rewriter.replaceOpWithNewOp(fromElementsOp, fromElementsOp.getType(), + fromElementsOp.getElements().front()); + return success(); +} + +void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(rewriteFromElementsAsSplat); +} + //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index bf4281ebcdec9..09b79708a9ab2 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -2590,3 +2590,34 @@ func.func @vector_bitcast_2d(%arg0: vector<2x4xi32>) -> vector<2x2xi64> { %0 = vector.bitcast %arg0 : vector<2x4xi32> to vector<2x2xi64> return %0 : vector<2x2xi64> } + +// ----- + +// CHECK-LABEL: func.func @vector_from_elements_1d( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<3xf32> +// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<3xf32> +// CHECK: %[[c1:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK: %[[insert1:.*]] = llvm.insertelement %[[b]], %[[insert0]][%[[c1]] : i64] : vector<3xf32> +// CHECK: %[[c2:.*]] = llvm.mlir.constant(2 : i64) : i64 +// CHECK: %[[insert2:.*]] = llvm.insertelement %[[a]], %[[insert1]][%[[c2]] : i64] : vector<3xf32> +// CHECK: return %[[insert2]] +func.func @vector_from_elements_1d(%a: f32, %b: f32) -> vector<3xf32> { + %0 = vector.from_elements %a, %b, %a : vector<3xf32> + return %0 : vector<3xf32> +} + +// ----- + +// CHECK-LABEL: func.func @vector_from_elements_0d( +// CHECK-SAME: %[[a:.*]]: f32) +// CHECK: %[[undef:.*]] = llvm.mlir.undef : vector<1xf32> +// CHECK: %[[c0:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[insert0:.*]] = llvm.insertelement %[[a]], %[[undef]][%[[c0]] : i64] : vector<1xf32> +// CHECK: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[insert0]] : vector<1xf32> to vector +// CHECK: return %[[cast]] +func.func @vector_from_elements_0d(%a: f32) -> vector { + %0 = vector.from_elements %a : vector + return %0 : vector +} diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index caccd1f1c9c24..8181f1a8c5d13 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2642,3 +2642,72 @@ func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector, // CHECK: return %[[a]], %[[a]], %[[extract1]], %[[a]], %[[a]], %[[extract2]], %[[extract3]] return %1, %3, %5, %7, %9, %10, %11 : f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32> } + +// ----- + +// CHECK-LABEL: func @extract_scalar_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_scalar_from_from_elements(%a: f32, %b: f32) -> (f32, f32, f32, f32, f32, f32, f32) { + // Extract from 0D. + %0 = vector.from_elements %a : vector + %1 = vector.extract %0[] : f32 from vector + + // Extract from 1D. + %2 = vector.from_elements %a : vector<1xf32> + %3 = vector.extract %2[0] : f32 from vector<1xf32> + %4 = vector.from_elements %a, %b, %a, %a, %b : vector<5xf32> + %5 = vector.extract %4[4] : f32 from vector<5xf32> + + // Extract from 2D. + %6 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + %7 = vector.extract %6[0, 0] : f32 from vector<2x3xf32> + %8 = vector.extract %6[0, 1] : f32 from vector<2x3xf32> + %9 = vector.extract %6[1, 1] : f32 from vector<2x3xf32> + %10 = vector.extract %6[1, 2] : f32 from vector<2x3xf32> + + // CHECK: return %[[a]], %[[a]], %[[b]], %[[a]], %[[a]], %[[b]], %[[b]] + return %1, %3, %5, %7, %8, %9, %10 : f32, f32, f32, f32, f32, f32, f32 +} + +// ----- + +// CHECK-LABEL: func @extract_1d_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_1d_from_from_elements(%a: f32, %b: f32) -> (vector<3xf32>, vector<3xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b : vector<2x3xf32> + // CHECK: %[[splat1:.*]] = vector.splat %[[a]] : vector<3xf32> + %1 = vector.extract %0[0] : vector<3xf32> from vector<2x3xf32> + // CHECK: %[[splat2:.*]] = vector.splat %[[b]] : vector<3xf32> + %2 = vector.extract %0[1] : vector<3xf32> from vector<2x3xf32> + // CHECK: return %[[splat1]], %[[splat2]] + return %1, %2 : vector<3xf32>, vector<3xf32> +} + +// ----- + +// CHECK-LABEL: func @extract_2d_from_from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @extract_2d_from_from_elements(%a: f32, %b: f32) -> (vector<2x2xf32>, vector<2x2xf32>) { + %0 = vector.from_elements %a, %a, %a, %b, %b, %b, %b, %a, %b, %a, %a, %b : vector<3x2x2xf32> + // CHECK: %[[splat1:.*]] = vector.from_elements %[[a]], %[[a]], %[[a]], %[[b]] : vector<2x2xf32> + %1 = vector.extract %0[0] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: %[[splat2:.*]] = vector.from_elements %[[b]], %[[b]], %[[b]], %[[a]] : vector<2x2xf32> + %2 = vector.extract %0[1] : vector<2x2xf32> from vector<3x2x2xf32> + // CHECK: return %[[splat1]], %[[splat2]] + return %1, %2 : vector<2x2xf32>, vector<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @from_elements_to_splat( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<2x3xf32>, vector) { + // CHECK: %[[splat:.*]] = vector.splat %[[a]] : vector<2x3xf32> + %0 = vector.from_elements %a, %a, %a, %a, %a, %a : vector<2x3xf32> + // CHECK: %[[from_el:.*]] = vector.from_elements {{.*}} : vector<2x3xf32> + %1 = vector.from_elements %a, %a, %a, %a, %b, %a : vector<2x3xf32> + // CHECK: %[[splat2:.*]] = vector.splat %[[a]] : vector + %2 = vector.from_elements %a : vector + // CHECK: return %[[splat]], %[[from_el]], %[[splat2]] + return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector +} diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index 1516f51fe1458..d0eaed8f98cc5 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1854,3 +1854,20 @@ func.func @deinterleave_scalable_rank_fail(%vec : vector<2x[4]xf32>) { %0, %1 = "vector.deinterleave" (%vec) : (vector<2x[4]xf32>) -> (vector<2x[2]xf32>, vector<[2]xf32>) return } + +// ----- + +func.func @invalid_from_elements(%a: f32) { + // expected-error @+1 {{'vector.from_elements' 1 operands present, but expected 2}} + vector.from_elements %a : vector<2xf32> + return +} + +// ----- + +// expected-note @+1 {{prior use here}} +func.func @invalid_from_elements(%a: f32, %b: i32) { + // expected-error @+1 {{use of value '%b' expects different type than prior uses: 'f32' vs 'i32'}} + vector.from_elements %a, %b : vector<2xf32> + return +} diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index c868c881d079a..4da09584db88b 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -1158,3 +1158,17 @@ func.func @deinterleave_nd_scalable(%arg:vector<2x3x4x[6]xf32>) -> (vector<2x3x4 %0, %1 = vector.deinterleave %arg : vector<2x3x4x[6]xf32> -> vector<2x3x4x[3]xf32> return %0, %1 : vector<2x3x4x[3]xf32>, vector<2x3x4x[3]xf32> } + +// CHECK-LABEL: func @from_elements( +// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: f32) +func.func @from_elements(%a: f32, %b: f32) -> (vector, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32>) { + // CHECK: vector.from_elements %[[a]] : vector + %0 = vector.from_elements %a : vector + // CHECK: vector.from_elements %[[a]] : vector<1xf32> + %1 = vector.from_elements %a : vector<1xf32> + // CHECK: vector.from_elements %[[a]], %[[b]] : vector<1x2xf32> + %2 = vector.from_elements %a, %b : vector<1x2xf32> + // CHECK: vector.from_elements %[[b]], %[[b]], %[[a]], %[[a]] : vector<2x2xf32> + %3 = vector.from_elements %b, %b, %a, %a : vector<2x2xf32> + return %0, %1, %2, %3 : vector, vector<1xf32>, vector<1x2xf32>, vector<2x2xf32> +} \ No newline at end of file