diff --git a/mlir/docs/Rationale/Rationale.md b/mlir/docs/Rationale/Rationale.md index f09f03b946a8..c159d3eda59f 100644 --- a/mlir/docs/Rationale/Rationale.md +++ b/mlir/docs/Rationale/Rationale.md @@ -202,39 +202,39 @@ and described in interest [starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q). -### Index type disallowed in vector types - -Index types are not allowed as elements of `vector` types. Index -types are intended to be used for platform-specific "size" values and may appear -in subscripts, sizes of aggregate types and affine expressions. They are also -tightly coupled with `affine.apply` and affine.load/store operations; having -`index` type is a necessary precondition of a value to be acceptable by these -operations. - -We allow `index` types in tensors and memrefs as a code generation strategy has -to map `index` to an implementation type and hence needs to be able to -materialize corresponding values. However, the target might lack support for +### Index type usage and limitations + +Index types are intended to be used for platform-specific "size" values and may +appear in subscripts, sizes of aggregate types and affine expressions. They are +also tightly coupled with `affine.apply` and affine.load/store operations; +having `index` type is a necessary precondition of a value to be acceptable by +these operations. + +We allow `index` types in tensors, vectors, and memrefs as a code generation +strategy has to map `index` to an implementation type and hence needs to be able +to materialize corresponding values. However, the target might lack support for `vector` values with the target specific equivalent of the `index` type. -### Bit width of a non-primitive type and `index` is undefined - -The bit width of a compound type is not defined by MLIR, it may be defined by a -specific lowering pass. In MLIR, bit width is a property of certain primitive -_type_, in particular integers and floats. It is equal to the number that -appears in the type definition, e.g. the bit width of `i32` is `32`, so is the -bit width of `f32`. The bit width is not _necessarily_ related to the amount of -memory (in bytes) or the size of register (in bits) that is necessary to store -the value of the given type. These quantities are target and ABI-specific and -should be defined during the lowering process rather than imposed from above. -For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit -integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, rather -than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the -bitwidth. Individual components of MLIR that allocate space for storing values -may use the bit size as the baseline and query the target description when it is -introduced. - -The bit width is not defined for dialect-specific types at MLIR level. Dialects -are free to define their own quantities for type sizes. +### Data layout of non-primitive types + +Data layout information such as the bit width or the alignment of types may be +target and ABI-specific and thus should be configurable rather than imposed by +the compiler. Especially, the layout of compound or `index` types may vary. MLIR +specifies default bit widths for certain primitive _types_, in particular for +integers and floats. It is equal to the number that appears in the type +definition, e.g. the bit width of `i32` is `32`, so is the bit width of `f32`. +The bit width is not _necessarily_ related to the amount of memory (in bytes) or +the register size (in bits) that is necessary to store the value of the given +type. For example, `vector<3xi57>` is likely to be lowered to a vector of four +64-bit integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, +rather than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the +bit width. MLIR makes such [data layout information](../DataLayout.md) +configurable using attributes that can be queried during lowering, for example, +when allocating a compound type. + +The data layout of dialect-specific types is undefined at MLIR level. Yet +dialects are free to define their own quantities and make them available via the +data layout infrastructure. ### Integer signedness semantics diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index fcfe8f1850e9..6d058f426141 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1738,8 +1738,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect, let summary = "splat or broadcast operation"; let description = [{ Broadcast the operand to all elements of the result vector or tensor. The - operand has to be of either integer or float type. When the result is a - tensor, it has to be statically shaped. + operand has to be of integer/index/float type. When the result is a tensor, + it has to be statically shaped. Example: @@ -1761,8 +1761,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect, ``` }]; - let arguments = (ins AnyTypeOf<[AnySignlessInteger, AnyFloat], - "integer or float type">:$input); + let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat], + "integer/index/float type">:$input); let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); let builders = [ diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td index 14afe9504806..0a1228599b60 100644 --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -2307,13 +2307,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect, Arguments<( // TODO: tighten vector element types that make sense. ins VectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs, + [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs, VectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs, + [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs, I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>, Results<( outs VectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> + [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { let summary = "Vector matrix multiplication op that operates on flattened 1-D" " MLIR vectors"; @@ -2370,11 +2370,11 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect, Arguments<( // TODO: tighten vector element types that make sense. ins VectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$matrix, + [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix, I32Attr:$rows, I32Attr:$columns)>, Results<( outs VectorOfRankAndType<[1], - [AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> { + [AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> { let summary = "Vector matrix transposition on flattened 1-D MLIR vectors"; let description = [{ This is the counterpart of llvm.matrix.transpose in MLIR. It serves diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index 22d194db3b68..f271c56f4162 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -874,7 +874,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> { ``` vector-type ::= `vector` `<` static-dimension-list vector-element-type `>` - vector-element-type ::= float-type | integer-type + vector-element-type ::= float-type | integer-type | index-type static-dimension-list ::= (decimal-literal `x`)+ ``` @@ -911,9 +911,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> { ]; let extraClassDeclaration = [{ /// Returns true of the given type can be used as an element of a vector - /// type. In particular, vectors can consist of integer or float primitives. + /// type. In particular, vectors can consist of integer, index, or float + /// primitives. static bool isValidElementType(Type t) { - return t.isa(); + return t.isa(); } /// Get or create a new VectorType with the same shape as `this` and an diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 3ea9bb41518e..a2469dc5bee3 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -758,11 +758,11 @@ def BoolLike : TypeConstraint.predicate, "bool-like">; // Type constraint for signless-integer-like types: signless integers, indices, -// vectors of signless integers, tensors of signless integers. +// vectors of signless integers or indices, tensors of signless integers. def SignlessIntegerLike : TypeConstraint.predicate, - TensorOf<[AnySignlessInteger]>.predicate]>, + VectorOf<[AnySignlessInteger, Index]>.predicate, + TensorOf<[AnySignlessInteger, Index]>.predicate]>, "signless-integer-like">; // Type constraint for float-like types: floats, vectors or tensors thereof. diff --git a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h index 87cac054c55e..0633eb341f11 100644 --- a/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h +++ b/mlir/include/mlir/Interfaces/DataLayoutInterfaces.h @@ -144,6 +144,9 @@ class DataLayout { explicit DataLayout(DataLayoutOpInterface op); explicit DataLayout(ModuleOp op); + /// Returns the layout of the closest parent operation carrying layout info. + static DataLayout closest(Operation *op); + /// Returns the size of the given type in the current scope. unsigned getTypeSize(Type t) const; diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 0c752c33ff16..9ecee857e2e5 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -163,13 +163,13 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp xferOp, ArrayRef operands, Value dataPtr, Value mask) { - VectorType fillType = xferOp.getVectorType(); - Value fill = rewriter.create(loc, fillType, xferOp.padding()); - Type vecTy = typeConverter.convertType(xferOp.getVectorType()); if (!vecTy) return failure(); + auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary()); + Value fill = rewriter.create(loc, vecTy, adaptor.padding()); + unsigned align; if (failed(getMemRefAlignment( typeConverter, xferOp.getShapedType().cast(), align))) diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt index 1c895f950c28..b74c9a5a823a 100644 --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRVector MLIRMemRef MLIRSCF MLIRLoopAnalysis + MLIRDataLayoutInterfaces MLIRSideEffectInterfaces MLIRVectorInterfaces ) diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index cff5fcb5649e..0ad89109b3ad 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -2202,12 +2202,15 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType, return op->emitOpError( "requires source to be a memref or ranked tensor type"); auto elementType = shapedType.getElementType(); + DataLayout dataLayout = DataLayout::closest(op); if (auto vectorElementType = elementType.dyn_cast()) { // Memref or tensor has vector element type. - unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() * - vectorElementType.getShape().back(); + unsigned sourceVecSize = + dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) * + vectorElementType.getShape().back(); unsigned resultVecSize = - vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); + dataLayout.getTypeSizeInBits(vectorType.getElementType()) * + vectorType.getShape().back(); if (resultVecSize % sourceVecSize != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " @@ -2226,8 +2229,9 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType, } else { // Memref or tensor has scalar element type. unsigned resultVecSize = - vectorType.getElementTypeBitWidth() * vectorType.getShape().back(); - if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0) + dataLayout.getTypeSizeInBits(vectorType.getElementType()) * + vectorType.getShape().back(); + if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0) return op->emitOpError( "requires the bitwidth of the minor 1-D vector to be an integral " "multiple of the bitwidth of the source element type"); @@ -3233,9 +3237,10 @@ static LogicalResult verify(BitCastOp op) { return op.emitOpError("dimension size mismatch at: ") << i; } - if (sourceVectorType.getElementTypeBitWidth() * + DataLayout dataLayout = DataLayout::closest(op); + if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) * sourceVectorType.getShape().back() != - resultVectorType.getElementTypeBitWidth() * + dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) * resultVectorType.getShape().back()) return op.emitOpError( "source/result bitwidth of the minor 1-D vectors must be equal"); diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp index ba8ca26b336e..3bb333cf786d 100644 --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1388,7 +1388,7 @@ class OuterProductOpLowering : public OpRewritePattern { VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); VectorType resType = op.getVectorType(); Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); + bool isInt = eltType.isa(); Value acc = (op.acc().empty()) ? nullptr : op.acc()[0]; vector::CombiningKind kind = op.kind(); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 8ef7c2674184..ce6d4b3a603b 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -693,7 +693,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, "expected attribute value to have element type"); if (eltType.isa()) intVal = values[i].cast().getValue().bitcastToAPInt(); - else if (eltType.isa()) + else if (eltType.isa()) intVal = values[i].cast().getValue(); else llvm_unreachable("unexpected element type"); diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index f792dfeacbf1..da1453367c7a 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -392,7 +392,7 @@ LogicalResult VectorType::verify(function_ref emitError, return emitError() << "vector types must have at least one dimension"; if (!isValidElementType(elementType)) - return emitError() << "vector elements must be int or float type"; + return emitError() << "vector elements must be int/index/float type"; if (any_of(shape, [](int64_t i) { return i <= 0; })) return emitError() << "vector types must have positive constant sizes"; diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp index 9f5c75a425fb..3369d61a834b 100644 --- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp +++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp @@ -264,6 +264,19 @@ mlir::DataLayout::DataLayout(ModuleOp op) #endif } +mlir::DataLayout mlir::DataLayout::closest(Operation *op) { + // Search the closest parent either being a module operation or implementing + // the data layout interface. + while (op) { + if (auto module = dyn_cast(op)) + return DataLayout(module); + if (auto iface = dyn_cast(op)) + return DataLayout(iface); + op = op->getParentOp(); + } + return DataLayout(); +} + void mlir::DataLayout::checkValid() const { #ifndef NDEBUG SmallVector specs; diff --git a/mlir/lib/Parser/TypeParser.cpp b/mlir/lib/Parser/TypeParser.cpp index 378b82f3bb1f..d81cb53060b1 100644 --- a/mlir/lib/Parser/TypeParser.cpp +++ b/mlir/lib/Parser/TypeParser.cpp @@ -472,7 +472,7 @@ VectorType Parser::parseVectorType() { if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) return nullptr; if (!VectorType::isValidElementType(elementType)) - return emitError(typeLoc, "vector elements must be int or float type"), + return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; return VectorType::get(dimensions, elementType); diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir index 5eca81dcad00..1d12eb937378 100644 --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -248,3 +248,16 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) { %1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32> std.return } + +// ----- + +// CHECK-LABEL: func @index_vector( +// CHECK-SAME: %[[ARG0:.*]]: vector<4xi64> +func @index_vector(%arg0: vector<4xindex>) { + // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64> + %0 = constant dense<[0, 1, 2, 3]> : vector<4xindex> + // CHECK: %[[V:.*]] = llvm.add %[[ARG0]], %[[CST]] : vector<4xi64> + %1 = addi %arg0, %0 : vector<4xindex> + std.return +} + diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index 9faf7caa3439..c3ca8ef095e5 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -23,18 +23,40 @@ func @bitcast_i8_to_f32_vector(%input: vector<64xi8>) -> vector<16xf32> { // ----- +func @bitcast_index_to_i8_vector(%input: vector<16xindex>) -> vector<128xi8> { + %0 = vector.bitcast %input : vector<16xindex> to vector<128xi8> + return %0 : vector<128xi8> +} + +// CHECK-LABEL: @bitcast_index_to_i8_vector +// CHECK-SAME: %[[input:.*]]: vector<16xindex> +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[input]] : vector<16xindex> to vector<16xi64> +// CHECK: llvm.bitcast %[[T0]] : vector<16xi64> to vector<128xi8> + +// ----- -func @broadcast_vec1d_from_scalar(%arg0: f32) -> vector<2xf32> { +func @broadcast_vec1d_from_f32(%arg0: f32) -> vector<2xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2xf32> return %0 : vector<2xf32> } -// CHECK-LABEL: @broadcast_vec1d_from_scalar +// CHECK-LABEL: @broadcast_vec1d_from_f32 // CHECK-SAME: %[[A:.*]]: f32) // CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xf32> // CHECK: return %[[T0]] : vector<2xf32> // ----- +func @broadcast_vec1d_from_index(%arg0: index) -> vector<2xindex> { + %0 = vector.broadcast %arg0 : index to vector<2xindex> + return %0 : vector<2xindex> +} +// CHECK-LABEL: @broadcast_vec1d_from_index +// CHECK-SAME: %[[A:.*]]: index) +// CHECK: %[[T0:.*]] = splat %[[A]] : vector<2xindex> +// CHECK: return %[[T0]] : vector<2xindex> + +// ----- + func @broadcast_vec2d_from_scalar(%arg0: f32) -> vector<2x3xf32> { %0 = vector.broadcast %arg0 : f32 to vector<2x3xf32> return %0 : vector<2x3xf32> @@ -83,6 +105,22 @@ func @broadcast_vec2d_from_vec1d(%arg0: vector<2xf32>) -> vector<3x2xf32> { // ----- +func @broadcast_vec2d_from_index_vec1d(%arg0: vector<2xindex>) -> vector<3x2xindex> { + %0 = vector.broadcast %arg0 : vector<2xindex> to vector<3x2xindex> + return %0 : vector<3x2xindex> +} +// CHECK-LABEL: @broadcast_vec2d_from_index_vec1d( +// CHECK-SAME: %[[A:.*]]: vector<2xindex>) +// CHECK: %[[T0:.*]] = constant dense<0> : vector<3x2xindex> +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>> +// CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>> + +// CHECK: %[[T4:.*]] = llvm.mlir.cast %{{.*}} : !llvm.array<3 x vector<2xi64>> to vector<3x2xindex> +// CHECK: return %[[T4]] : vector<3x2xindex> + +// ----- + func @broadcast_vec3d_from_vec1d(%arg0: vector<2xf32>) -> vector<4x3x2xf32> { %0 = vector.broadcast %arg0 : vector<2xf32> to vector<4x3x2xf32> return %0 : vector<4x3x2xf32> @@ -264,6 +302,26 @@ func @outerproduct(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<2x3xf32 // ----- +func @outerproduct_index(%arg0: vector<2xindex>, %arg1: vector<3xindex>) -> vector<2x3xindex> { + %2 = vector.outerproduct %arg0, %arg1 : vector<2xindex>, vector<3xindex> + return %2 : vector<2x3xindex> +} +// CHECK-LABEL: @outerproduct_index( +// CHECK-SAME: %[[A:.*]]: vector<2xindex>, +// CHECK-SAME: %[[B:.*]]: vector<3xindex>) +// CHECK: %[[T0:.*]] = constant dense<0> : vector<2x3xindex> +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64> +// CHECK: %[[T4:.*]] = llvm.mlir.cast %[[T3]] : i64 to index +// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex> +// CHECK: %[[T6:.*]] = muli %[[T5]], %[[B]] : vector<3xindex> +// CHECK: %[[T7:.*]] = llvm.mlir.cast %[[T6]] : vector<3xindex> to vector<3xi64> +// CHECK: %[[T8:.*]] = llvm.mlir.cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>> +// CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>> + +// ----- + func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { %2 = vector.outerproduct %arg0, %arg1, %arg2 : vector<2xf32>, vector<3xf32> return %2 : vector<2x3xf32> @@ -305,6 +363,21 @@ func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2x // ----- +func @shuffle_1D_index_direct(%arg0: vector<2xindex>, %arg1: vector<2xindex>) -> vector<2xindex> { + %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector<2xindex>, vector<2xindex> + return %1 : vector<2xindex> +} +// CHECK-LABEL: @shuffle_1D_index_direct( +// CHECK-SAME: %[[A:.*]]: vector<2xindex>, +// CHECK-SAME: %[[B:.*]]: vector<2xindex>) +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[B]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [0, 1] : vector<2xi64>, vector<2xi64> +// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : vector<2xi64> to vector<2xindex> +// CHECK: return %[[T3]] : vector<2xindex> + +// ----- + func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> { %1 = vector.shuffle %arg0, %arg1 [4, 3, 2, 1, 0] : vector<2xf32>, vector<3xf32> return %1 : vector<5xf32> @@ -382,6 +455,20 @@ func @extract_element_from_vec_1d(%arg0: vector<16xf32>) -> f32 { // ----- +func @extract_index_element_from_vec_1d(%arg0: vector<16xindex>) -> index { + %0 = vector.extract %arg0[15]: vector<16xindex> + return %0 : index +} +// CHECK-LABEL: @extract_index_element_from_vec_1d( +// CHECK-SAME: %[[A:.*]]: vector<16xindex>) +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64> +// CHECK: %[[T1:.*]] = llvm.mlir.constant(15 : i64) : i64 +// CHECK: %[[T2:.*]] = llvm.extractelement %[[T0]][%[[T1]] : i64] : vector<16xi64> +// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : i64 to index +// CHECK: return %[[T3]] : index + +// ----- + func @extract_vec_2d_from_vec_3d(%arg0: vector<4x3x16xf32>) -> vector<3x16xf32> { %0 = vector.extract %arg0[0]: vector<4x3x16xf32> return %0 : vector<3x16xf32> @@ -439,6 +526,22 @@ func @insert_element_into_vec_1d(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf // ----- +func @insert_index_element_into_vec_1d(%arg0: index, %arg1: vector<4xindex>) -> vector<4xindex> { + %0 = vector.insert %arg0, %arg1[3] : index into vector<4xindex> + return %0 : vector<4xindex> +} +// CHECK-LABEL: @insert_index_element_into_vec_1d( +// CHECK-SAME: %[[A:.*]]: index, +// CHECK-SAME: %[[B:.*]]: vector<4xindex>) +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : index to i64 +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[B]] : vector<4xindex> to vector<4xi64> +// CHECK: %[[T3:.*]] = llvm.mlir.constant(3 : i64) : i64 +// CHECK: %[[T4:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T3]] : i64] : vector<4xi64> +// CHECK: %[[T5:.*]] = llvm.mlir.cast %[[T4]] : vector<4xi64> to vector<4xindex> +// CHECK: return %[[T5]] : vector<4xindex> + +// ----- + func @insert_vec_2d_into_vec_3d(%arg0: vector<8x16xf32>, %arg1: vector<4x8x16xf32>) -> vector<4x8x16xf32> { %0 = vector.insert %arg0, %arg1[3] : vector<8x16xf32> into vector<4x8x16xf32> return %0 : vector<4x8x16xf32> @@ -489,6 +592,18 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref> { // ----- +func @vector_index_type_cast(%arg0: memref<8x8x8xindex>) -> memref> { + %0 = vector.type_cast %arg0: memref<8x8x8xindex> to memref> + return %0 : memref> +} +// CHECK-LABEL: @vector_index_type_cast( +// CHECK-SAME: %[[A:.*]]: memref<8x8x8xindex>) +// CHECK: %{{.*}} = llvm.mlir.cast %[[A]] : memref<8x8x8xindex> to !llvm.struct<(ptr, ptr, i64, array<3 x i64>, array<3 x i64>)> + +// CHECK: %{{.*}} = llvm.mlir.cast %{{.*}} : !llvm.struct<(ptr>>>, ptr>>>, i64)> to memref> + +// ----- + func @vector_type_cast_non_zero_addrspace(%arg0: memref<8x8x8xf32, 3>) -> memref, 3> { %0 = vector.type_cast %arg0: memref<8x8x8xf32, 3> to memref, 3> return %0 : memref, 3> @@ -723,6 +838,20 @@ func @extract_strided_slice1(%arg0: vector<4xf32>) -> vector<2xf32> { // ----- +func @extract_strided_index_slice1(%arg0: vector<4xindex>) -> vector<2xindex> { + %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4xindex> to vector<2xindex> + return %0 : vector<2xindex> +} +// CHECK-LABEL: @extract_strided_index_slice1( +// CHECK-SAME: %[[A:.*]]: vector<4xindex>) +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<4xindex> to vector<4xi64> +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[A]] : vector<4xindex> to vector<4xi64> +// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [2, 3] : vector<4xi64>, vector<4xi64> +// CHECK: %[[T3:.*]] = llvm.mlir.cast %[[T2]] : vector<2xi64> to vector<2xindex> +// CHECK: return %[[T3]] : vector<2xindex> + +// ----- + func @extract_strided_slice2(%arg0: vector<4x8xf32>) -> vector<2x8xf32> { %0 = vector.extract_strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8xf32> to vector<2x8xf32> return %0 : vector<2x8xf32> @@ -772,6 +901,16 @@ func @insert_strided_slice1(%b: vector<4x4xf32>, %c: vector<4x4x4xf32>) -> vecto // ----- +func @insert_strided_index_slice1(%b: vector<4x4xindex>, %c: vector<4x4x4xindex>) -> vector<4x4x4xindex> { + %0 = vector.insert_strided_slice %b, %c {offsets = [2, 0, 0], strides = [1, 1]} : vector<4x4xindex> into vector<4x4x4xindex> + return %0 : vector<4x4x4xindex> +} +// CHECK-LABEL: @insert_strided_index_slice1( +// CHECK: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>> +// CHECK: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x array<4 x vector<4xi64>>> + +// ----- + func @insert_strided_slice2(%a: vector<2x2xf32>, %b: vector<4x4xf32>) -> vector<4x4xf32> { %0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32> return %0 : vector<4x4xf32> @@ -1019,6 +1158,18 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 { // CHECK: %[[V:.*]] = "llvm.intr.vector.reduce.add"(%[[A]]) // CHECK: return %[[V]] : i64 +// ----- + +func @reduce_index(%arg0: vector<16xindex>) -> index { + %0 = vector.reduction "add", %arg0 : vector<16xindex> into index + return %0 : index +} +// CHECK-LABEL: @reduce_index( +// CHECK-SAME: %[[A:.*]]: vector<16xindex>) +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64> +// CHECK: %[[T1:.*]] = "llvm.intr.vector.reduce.add"(%[[T0]]) +// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T1]] : i64 to index +// CHECK: return %[[T2]] : index // 4x16 16x3 4x3 // ----- @@ -1036,6 +1187,19 @@ func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> { // ----- +func @matrix_ops_index(%A: vector<64xindex>, %B: vector<48xindex>) -> vector<12xindex> { + %C = vector.matrix_multiply %A, %B + { lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } : + (vector<64xindex>, vector<48xindex>) -> vector<12xindex> + return %C: vector<12xindex> +} +// CHECK-LABEL: @matrix_ops_index +// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} { +// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32 +// CHECK-SAME: } : (vector<64xi64>, vector<48xi64>) -> vector<12xi64> + +// ----- + func @transfer_read_1d(%A : memref, %base: index) -> vector<17xf32> { %f7 = constant 7.0: f32 %f = vector.transfer_read %A[%base], %f7 @@ -1108,6 +1272,29 @@ func @transfer_read_1d(%A : memref, %base: index) -> vector<17xf32> { // ----- +func @transfer_read_index_1d(%A : memref, %base: index) -> vector<17xindex> { + %f7 = constant 7: index + %f = vector.transfer_read %A[%base], %f7 + {permutation_map = affine_map<(d0) -> (d0)>} : + memref, vector<17xindex> + vector.transfer_write %f, %A[%base] + {permutation_map = affine_map<(d0) -> (d0)>} : + vector<17xindex>, memref + return %f: vector<17xindex> +} +// CHECK-LABEL: func @transfer_read_index_1d +// CHECK-SAME: %[[BASE:[a-zA-Z0-9]*]]: index) -> vector<17xindex> +// CHECK: %[[C7:.*]] = constant 7 +// CHECK: %{{.*}} = llvm.mlir.cast %[[C7]] : index to i64 + +// CHECK: %[[loaded:.*]] = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : +// CHECK-SAME: (!llvm.ptr>, vector<17xi1>, vector<17xi64>) -> vector<17xi64> + +// CHECK: llvm.intr.masked.store %[[loaded]], %{{.*}}, %{{.*}} {alignment = 8 : i32} : +// CHECK-SAME: vector<17xi64>, vector<17xi1> into !llvm.ptr> + +// ----- + func @transfer_read_2d_to_1d(%A : memref, %base0: index, %base1: index) -> vector<17xf32> { %f7 = constant 7.0: f32 %f = vector.transfer_read %A[%base0, %base1], %f7 @@ -1258,6 +1445,22 @@ func @flat_transpose(%arg0: vector<16xf32>) -> vector<16xf32> { // ----- +func @flat_transpose_index(%arg0: vector<16xindex>) -> vector<16xindex> { + %0 = vector.flat_transpose %arg0 { rows = 4: i32, columns = 4: i32 } + : vector<16xindex> -> vector<16xindex> + return %0 : vector<16xindex> +} +// CHECK-LABEL: func @flat_transpose_index +// CHECK-SAME: %[[A:.*]]: vector<16xindex> +// CHECK: %[[T0:.*]] = llvm.mlir.cast %[[A]] : vector<16xindex> to vector<16xi64> +// CHECK: %[[T1:.*]] = llvm.intr.matrix.transpose %[[T0]] +// CHECK-SAME: {columns = 4 : i32, rows = 4 : i32} : +// CHECK-SAME: vector<16xi64> into vector<16xi64> +// CHECK: %[[T2:.*]] = llvm.mlir.cast %[[T1]] : vector<16xi64> to vector<16xindex> +// CHECK: return %[[T2]] : vector<16xindex> + +// ----- + func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> vector<8xf32> { %0 = vector.load %memref[%i, %j] : memref<200x100xf32>, vector<8xf32> return %0 : vector<8xf32> @@ -1271,6 +1474,19 @@ func @vector_load_op(%memref : memref<200x100xf32>, %i : index, %j : index) -> v // CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> // CHECK: llvm.load %[[bcast]] {alignment = 4 : i64} : !llvm.ptr> +// ----- + +func @vector_load_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) -> vector<8xindex> { + %0 = vector.load %memref[%i, %j] : memref<200x100xindex>, vector<8xindex> + return %0 : vector<8xindex> +} +// CHECK-LABEL: func @vector_load_op_index +// CHECK: %[[T0:.*]] = llvm.load %{{.*}} {alignment = 8 : i64} : !llvm.ptr> +// CHECK: %[[T1:.*]] = llvm.mlir.cast %[[T0]] : vector<8xi64> to vector<8xindex> +// CHECK: return %[[T1]] : vector<8xindex> + +// ----- + func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) { %val = constant dense<11.0> : vector<4xf32> vector.store %val, %memref[%i, %j] : memref<200x100xf32>, vector<4xf32> @@ -1285,6 +1501,18 @@ func @vector_store_op(%memref : memref<200x100xf32>, %i : index, %j : index) { // CHECK: %[[bcast:.*]] = llvm.bitcast %[[gep]] : !llvm.ptr to !llvm.ptr> // CHECK: llvm.store %{{.*}}, %[[bcast]] {alignment = 4 : i64} : !llvm.ptr> +// ----- + +func @vector_store_op_index(%memref : memref<200x100xindex>, %i : index, %j : index) { + %val = constant dense<11> : vector<4xindex> + vector.store %val, %memref[%i, %j] : memref<200x100xindex>, vector<4xindex> + return +} +// CHECK-LABEL: func @vector_store_op_index +// CHECK: llvm.store %{{.*}}, %{{.*}} {alignment = 8 : i64} : !llvm.ptr> + +// ----- + func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) -> vector<16xf32> { %c0 = constant 0: index %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> into vector<16xf32> @@ -1301,6 +1529,16 @@ func @masked_load_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<1 // ----- +func @masked_load_op_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) -> vector<16xindex> { + %c0 = constant 0: index + %0 = vector.maskedload %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> into vector<16xindex> + return %0 : vector<16xindex> +} +// CHECK-LABEL: func @masked_load_op_index +// CHECK: %{{.*}} = llvm.intr.masked.load %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.ptr>, vector<16xi1>, vector<16xi64>) -> vector<16xi64> + +// ----- + func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xf32>) { %c0 = constant 0: index vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xf32> @@ -1316,6 +1554,16 @@ func @masked_store_op(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector< // ----- +func @masked_store_op_index(%arg0: memref, %arg1: vector<16xi1>, %arg2: vector<16xindex>) { + %c0 = constant 0: index + vector.maskedstore %arg0[%c0], %arg1, %arg2 : memref, vector<16xi1>, vector<16xindex> + return +} +// CHECK-LABEL: func @masked_store_op_index +// CHECK: llvm.intr.masked.store %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<16xi64>, vector<16xi1> into !llvm.ptr> + +// ----- + func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf32>) -> vector<3xf32> { %0 = constant 0: index %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xf32> into vector<3xf32> @@ -1329,6 +1577,16 @@ func @gather_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, // ----- +func @gather_op_index(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) -> vector<3xindex> { + %0 = constant 0: index + %1 = vector.gather %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xindex> into vector<3xindex> + return %1 : vector<3xindex> +} +// CHECK-LABEL: func @gather_op_index +// CHECK: %{{.*}} = llvm.intr.masked.gather %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : (!llvm.vec<3 x ptr>, vector<3xi1>, vector<3xi64>) -> vector<3xi64> + +// ----- + func @gather_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) -> vector<4xf32> { %0 = constant 3 : index %1 = vector.gather %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> into vector<4xf32> @@ -1355,6 +1613,16 @@ func @scatter_op(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1> // ----- +func @scatter_op_index(%arg0: memref, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xindex>) { + %0 = constant 0: index + vector.scatter %arg0[%0][%arg1], %arg2, %arg3 : memref, vector<3xi32>, vector<3xi1>, vector<3xindex> + return +} +// CHECK-LABEL: func @scatter_op_index +// CHECK: llvm.intr.masked.scatter %{{.*}}, %{{.*}}, %{{.*}} {alignment = 8 : i32} : vector<3xi64>, vector<3xi1> into !llvm.vec<3 x ptr> + +// ----- + func @scatter_2d_op(%arg0: memref<4x4xf32>, %arg1: vector<4xi32>, %arg2: vector<4xi1>, %arg3: vector<4xf32>) { %0 = constant 3 : index vector.scatter %arg0[%0, %0][%arg1], %arg2, %arg3 : memref<4x4xf32>, vector<4xi32>, vector<4xi1>, vector<4xf32> @@ -1383,6 +1651,16 @@ func @expand_load_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<1 // ----- +func @expand_load_op_index(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xindex>) -> vector<11xindex> { + %c0 = constant 0: index + %0 = vector.expandload %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xindex> into vector<11xindex> + return %0 : vector<11xindex> +} +// CHECK-LABEL: func @expand_load_op_index +// CHECK: %{{.*}} = "llvm.intr.masked.expandload"(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm.ptr, vector<11xi1>, vector<11xi64>) -> vector<11xi64> + +// ----- + func @compress_store_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xf32>) { %c0 = constant 0: index vector.compressstore %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xf32> @@ -1394,3 +1672,13 @@ func @compress_store_op(%arg0: memref, %arg1: vector<11xi1>, %arg2: vecto // CHECK: %[[C:.*]] = llvm.mlir.cast %[[CO]] : index to i64 // CHECK: %[[P:.*]] = llvm.getelementptr %{{.*}}[%[[C]]] : (!llvm.ptr, i64) -> !llvm.ptr // CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %[[P]], %{{.*}}) : (vector<11xf32>, !llvm.ptr, vector<11xi1>) -> () + +// ----- + +func @compress_store_op_index(%arg0: memref, %arg1: vector<11xi1>, %arg2: vector<11xindex>) { + %c0 = constant 0: index + vector.compressstore %arg0[%c0], %arg1, %arg2 : memref, vector<11xi1>, vector<11xindex> + return +} +// CHECK-LABEL: func @compress_store_op_index +// CHECK: "llvm.intr.masked.compressstore"(%{{.*}}, %{{.*}}, %{{.*}}) : (vector<11xi64>, !llvm.ptr, vector<11xi1>) -> () diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir index 43bef97f799e..fd5c0c8ac67e 100644 --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -3,14 +3,18 @@ // CHECK-LABEL: func @vector_transfer_ops( func @vector_transfer_ops(%arg0: memref, %arg1 : memref>, - %arg2 : memref>) { + %arg2 : memref>, + %arg3 : memref>) { // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 %c0 = constant 0 : i32 + %i0 = constant 0 : index + %vf0 = splat %f0 : vector<4x3xf32> %v0 = splat %c0 : vector<4x3xi32> + %vi0 = splat %i0 : vector<4x3xindex> %m = constant dense<[0, 0, 1, 0, 1]> : vector<5xi1> // @@ -28,8 +32,10 @@ func @vector_transfer_ops(%arg0: memref, %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : memref>, vector<1x1x4x3xf32> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref>, vector<5x24xi8> %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : memref>, vector<5x24xi8> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : memref>, vector<5x48xi8> + %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : memref>, vector<5x48xi8> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}}, %{{.*}} : memref, vector<5xf32> - %7 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref, vector<5xf32> + %8 = vector.transfer_read %arg0[%c3, %c3], %f0, %m : memref, vector<5xf32> // CHECK: vector.transfer_write vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, memref @@ -41,8 +47,11 @@ func @vector_transfer_ops(%arg0: memref, vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, memref> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, memref> vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, memref> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x48xi8>, memref> + vector.transfer_write %7, %arg3[%c3, %c3] : vector<5x48xi8>, memref> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : vector<5xf32>, memref - vector.transfer_write %7, %arg0[%c3, %c3], %m : vector<5xf32>, memref + vector.transfer_write %8, %arg0[%c3, %c3], %m : vector<5xf32>, memref + return } @@ -50,16 +59,21 @@ func @vector_transfer_ops(%arg0: memref, // CHECK-LABEL: func @vector_transfer_ops_tensor( func @vector_transfer_ops_tensor(%arg0: tensor, %arg1 : tensor>, - %arg2 : tensor>) -> + %arg2 : tensor>, + %arg3 : tensor>) -> (tensor, tensor, tensor>, - tensor>, tensor>){ + tensor>, tensor>, + tensor>){ // CHECK: %[[C3:.*]] = constant 3 : index %c3 = constant 3 : index %cst = constant 3.0 : f32 %f0 = constant 0.0 : f32 %c0 = constant 0 : i32 + %i0 = constant 0 : index + %vf0 = splat %f0 : vector<4x3xf32> %v0 = splat %c0 : vector<4x3xi32> + %vi0 = splat %i0 : vector<4x3xindex> // // CHECK: vector.transfer_read @@ -76,22 +90,27 @@ func @vector_transfer_ops_tensor(%arg0: tensor, %5 = vector.transfer_read %arg1[%c3, %c3], %vf0 {in_bounds = [false, true]} : tensor>, vector<1x1x4x3xf32> // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor>, vector<5x24xi8> %6 = vector.transfer_read %arg2[%c3, %c3], %v0 : tensor>, vector<5x24xi8> + // CHECK: vector.transfer_read %{{.*}}[%[[C3]], %[[C3]]], %{{.*}} : tensor>, vector<5x48xi8> + %7 = vector.transfer_read %arg3[%c3, %c3], %vi0 : tensor>, vector<5x48xi8> // CHECK: vector.transfer_write - %7 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor + %8 = vector.transfer_write %0, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0)>} : vector<128xf32>, tensor // CHECK: vector.transfer_write - %8 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor + %9 = vector.transfer_write %1, %arg0[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d1, d0)>} : vector<3x7xf32>, tensor // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor> - %9 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor> + %10 = vector.transfer_write %4, %arg1[%c3, %c3] {permutation_map = affine_map<(d0, d1)->(d0, d1)>} : vector<1x1x4x3xf32>, tensor> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<1x1x4x3xf32>, tensor> - %10 = vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor> + %11 = vector.transfer_write %5, %arg1[%c3, %c3] {in_bounds = [false, false]} : vector<1x1x4x3xf32>, tensor> // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x24xi8>, tensor> - %11 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor> + %12 = vector.transfer_write %6, %arg2[%c3, %c3] : vector<5x24xi8>, tensor> + // CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[C3]], %[[C3]]] : vector<5x48xi8>, tensor> + %13 = vector.transfer_write %7, %arg3[%c3, %c3] : vector<5x48xi8>, tensor> - return %7, %8, %9, %10, %11 : + return %8, %9, %10, %11, %12, %13 : tensor, tensor, tensor>, - tensor>, tensor> + tensor>, tensor>, + tensor> } // CHECK-LABEL: @vector_broadcast @@ -381,8 +400,9 @@ func @shape_cast(%arg0 : vector<5x1x3x2xf32>, // CHECK-LABEL: @bitcast func @bitcast(%arg0 : vector<5x1x3x2xf32>, %arg1 : vector<8x1xi32>, - %arg2 : vector<16x1x8xi8>) - -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>) { + %arg2 : vector<16x1x8xi8>, + %arg3 : vector<8x2x1xindex>) + -> (vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32>) { // CHECK: vector.bitcast %{{.*}} : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> %0 = vector.bitcast %arg0 : vector<5x1x3x2xf32> to vector<5x1x3x4xf16> @@ -402,7 +422,13 @@ func @bitcast(%arg0 : vector<5x1x3x2xf32>, // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x4xi16> %5 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x4xi16> - return %0, %1, %2, %3, %4, %5 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16> + // CHECK-NEXT: vector.bitcast %{{.*}} : vector<16x1x8xi8> to vector<16x1x1xindex> + %6 = vector.bitcast %arg2 : vector<16x1x8xi8> to vector<16x1x1xindex> + + // CHECK-NEXT: vector.bitcast %{{.*}} : vector<8x2x1xindex> to vector<8x2x2xf32> + %7 = vector.bitcast %arg3 : vector<8x2x1xindex> to vector<8x2x2xf32> + + return %0, %1, %2, %3, %4, %5, %6, %7 : vector<5x1x3x4xf16>, vector<5x1x3x8xi8>, vector<8x4xi8>, vector<8x1xf32>, vector<16x1x2xi32>, vector<16x1x4xi16>, vector<16x1x1xindex>, vector<8x2x2xf32> } // CHECK-LABEL: @vector_fma diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 3747039944b5..a3240704f431 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -842,7 +842,7 @@ func @invalid_splat(%v : f32) { func @invalid_splat(%v : vector<8xf32>) { %w = splat %v : tensor<8xvector<8xf32>> - // expected-error@-1 {{must be integer or float type}} + // expected-error@-1 {{must be integer/index/float type}} return } diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir index 39e72d2cb8da..220f46d5b344 100644 --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -23,10 +23,6 @@ func @illegalmemrefelementtype(memref>) -> () // expected-error {{i func @illegalunrankedmemrefelementtype(memref<*xtensor>) -> () // expected-error {{invalid memref element type}} -// ----- - -func @indexvector(vector<4 x index>) -> () // expected-error {{vector elements must be int or float type}} - // ----- // Test no map in memref type. func @memrefs(memref<2x4xi8, >) // expected-error {{expected list element}} @@ -387,7 +383,7 @@ func @succ_arg_type_mismatch() { // Test no nested vector. func @vectors(vector<1 x vector<1xi32>>, vector<2x4xf32>) -// expected-error@-1 {{vector elements must be int or float type}} +// expected-error@-1 {{vector elements must be int/index/float type}} // ----- diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir new file mode 100644 index 000000000000..2ae8a92c61f9 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-index-vectors.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm -convert-std-to-llvm | \ +// RUN: mlir-cpu-runner -e entry -entry-point-result=void \ +// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @entry() { + %c0 = constant dense<[0, 1, 2, 3]>: vector<4xindex> + %c1 = constant dense<[0, 1]>: vector<2xindex> + %c2 = constant 2 : index + + %v1 = vector.broadcast %c0 : vector<4xindex> to vector<2x4xindex> + %v2 = vector.broadcast %c1 : vector<2xindex> to vector<4x2xindex> + %v3 = vector.transpose %v2, [1, 0] : vector<4x2xindex> to vector<2x4xindex> + %v4 = vector.broadcast %c2 : index to vector<2x4xindex> + + %v5 = addi %v1, %v3 : vector<2x4xindex> + + vector.print %v1 : vector<2x4xindex> + vector.print %v3 : vector<2x4xindex> + vector.print %v4 : vector<2x4xindex> + vector.print %v5 : vector<2x4xindex> + + // + // created index vectors: + // + // CHECK: ( ( 0, 1, 2, 3 ), ( 0, 1, 2, 3 ) ) + // CHECK: ( ( 0, 0, 0, 0 ), ( 1, 1, 1, 1 ) ) + // CHECK: ( ( 2, 2, 2, 2 ), ( 2, 2, 2, 2 ) ) + // CHECK: ( ( 0, 1, 2, 3 ), ( 1, 2, 3, 4 ) ) + + return +}