Skip to content

Commit

Permalink
[mlir] add support for index type in vectors.
Browse files Browse the repository at this point in the history
The patch enables the use of index type in vectors. It is a prerequisite to support vectorization for indexed Linalg operations. This refactoring became possible due to the newly introduced data layout infrastructure. The data layout of a module defines the bitwidth of the index type needed to verify bitcasts and similar vector operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D99948
  • Loading branch information
Tobias Gysi committed Apr 8, 2021
1 parent 20105b6 commit b614ada
Show file tree
Hide file tree
Showing 20 changed files with 461 additions and 83 deletions.
62 changes: 31 additions & 31 deletions mlir/docs/Rationale/Rationale.md
Original file line number Diff line number Diff line change
Expand Up @@ -202,39 +202,39 @@ and described in
interest
[starts here](https://www.google.com/url?q=https://youtu.be/Ntj8ab-5cvE?t%3D596&sa=D&ust=1529450150971000&usg=AFQjCNFQHEWL7m8q3eO-1DiKw9zqC2v24Q).

### Index type disallowed in vector types

Index types are not allowed as elements of `vector` types. Index
types are intended to be used for platform-specific "size" values and may appear
in subscripts, sizes of aggregate types and affine expressions. They are also
tightly coupled with `affine.apply` and affine.load/store operations; having
`index` type is a necessary precondition of a value to be acceptable by these
operations.

We allow `index` types in tensors and memrefs as a code generation strategy has
to map `index` to an implementation type and hence needs to be able to
materialize corresponding values. However, the target might lack support for
### Index type usage and limitations

Index types are intended to be used for platform-specific "size" values and may
appear in subscripts, sizes of aggregate types and affine expressions. They are
also tightly coupled with `affine.apply` and affine.load/store operations;
having `index` type is a necessary precondition of a value to be acceptable by
these operations.

We allow `index` types in tensors, vectors, and memrefs as a code generation
strategy has to map `index` to an implementation type and hence needs to be able
to materialize corresponding values. However, the target might lack support for
`vector` values with the target specific equivalent of the `index` type.

### Bit width of a non-primitive type and `index` is undefined

The bit width of a compound type is not defined by MLIR, it may be defined by a
specific lowering pass. In MLIR, bit width is a property of certain primitive
_type_, in particular integers and floats. It is equal to the number that
appears in the type definition, e.g. the bit width of `i32` is `32`, so is the
bit width of `f32`. The bit width is not _necessarily_ related to the amount of
memory (in bytes) or the size of register (in bits) that is necessary to store
the value of the given type. These quantities are target and ABI-specific and
should be defined during the lowering process rather than imposed from above.
For example, `vector<3xi57>` is likely to be lowered to a vector of four 64-bit
integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes, rather
than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
bitwidth. Individual components of MLIR that allocate space for storing values
may use the bit size as the baseline and query the target description when it is
introduced.

The bit width is not defined for dialect-specific types at MLIR level. Dialects
are free to define their own quantities for type sizes.
### Data layout of non-primitive types

Data layout information such as the bit width or the alignment of types may be
target and ABI-specific and thus should be configurable rather than imposed by
the compiler. Especially, the layout of compound or `index` types may vary. MLIR
specifies default bit widths for certain primitive _types_, in particular for
integers and floats. It is equal to the number that appears in the type
definition, e.g. the bit width of `i32` is `32`, so is the bit width of `f32`.
The bit width is not _necessarily_ related to the amount of memory (in bytes) or
the register size (in bits) that is necessary to store the value of the given
type. For example, `vector<3xi57>` is likely to be lowered to a vector of four
64-bit integers, so that its storage requirement is `4 x 64 / 8 = 32` bytes,
rather than `(3 x 57) ceildiv 8 = 22` bytes as can be naively computed from the
bit width. MLIR makes such [data layout information](../DataLayout.md)
configurable using attributes that can be queried during lowering, for example,
when allocating a compound type.

The data layout of dialect-specific types is undefined at MLIR level. Yet
dialects are free to define their own quantities and make them available via the
data layout infrastructure.

### Integer signedness semantics

Expand Down
8 changes: 4 additions & 4 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -1738,8 +1738,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
let summary = "splat or broadcast operation";
let description = [{
Broadcast the operand to all elements of the result vector or tensor. The
operand has to be of either integer or float type. When the result is a
tensor, it has to be statically shaped.
operand has to be of integer/index/float type. When the result is a tensor,
it has to be statically shaped.

Example:

Expand All @@ -1761,8 +1761,8 @@ def SplatOp : Std_Op<"splat", [NoSideEffect,
```
}];

let arguments = (ins AnyTypeOf<[AnySignlessInteger, AnyFloat],
"integer or float type">:$input);
let arguments = (ins AnyTypeOf<[AnySignlessInteger, Index, AnyFloat],
"integer/index/float type">:$input);
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);

let builders = [
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Vector/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2307,13 +2307,13 @@ def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
Arguments<(
// TODO: tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$lhs,
VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
" MLIR vectors";
Expand Down Expand Up @@ -2370,11 +2370,11 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [NoSideEffect,
Arguments<(
// TODO: tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$matrix,
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$matrix,
I32Attr:$rows, I32Attr:$columns)>,
Results<(
outs VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)> {
[AnySignlessInteger, AnySignedInteger, Index, AnyFloat]>:$res)> {
let summary = "Vector matrix transposition on flattened 1-D MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.transpose in MLIR. It serves
Expand Down
7 changes: 4 additions & 3 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,7 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {

```
vector-type ::= `vector` `<` static-dimension-list vector-element-type `>`
vector-element-type ::= float-type | integer-type
vector-element-type ::= float-type | integer-type | index-type

static-dimension-list ::= (decimal-literal `x`)+
```
Expand Down Expand Up @@ -911,9 +911,10 @@ def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {
];
let extraClassDeclaration = [{
/// Returns true of the given type can be used as an element of a vector
/// type. In particular, vectors can consist of integer or float primitives.
/// type. In particular, vectors can consist of integer, index, or float
/// primitives.
static bool isValidElementType(Type t) {
return t.isa<IntegerType, FloatType>();
return t.isa<IntegerType, IndexType, FloatType>();
}

/// Get or create a new VectorType with the same shape as `this` and an
Expand Down
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -758,11 +758,11 @@ def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
"bool-like">;

// Type constraint for signless-integer-like types: signless integers, indices,
// vectors of signless integers, tensors of signless integers.
// vectors of signless integers or indices, tensors of signless integers.
def SignlessIntegerLike : TypeConstraint<Or<[
AnySignlessInteger.predicate, Index.predicate,
VectorOf<[AnySignlessInteger]>.predicate,
TensorOf<[AnySignlessInteger]>.predicate]>,
VectorOf<[AnySignlessInteger, Index]>.predicate,
TensorOf<[AnySignlessInteger, Index]>.predicate]>,
"signless-integer-like">;

// Type constraint for float-like types: floats, vectors or tensors thereof.
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ class DataLayout {
explicit DataLayout(DataLayoutOpInterface op);
explicit DataLayout(ModuleOp op);

/// Returns the layout of the closest parent operation carrying layout info.
static DataLayout closest(Operation *op);

/// Returns the size of the given type in the current scope.
unsigned getTypeSize(Type t) const;

Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,13 @@ replaceTransferOpWithMasked(ConversionPatternRewriter &rewriter,
LLVMTypeConverter &typeConverter, Location loc,
TransferReadOp xferOp, ArrayRef<Value> operands,
Value dataPtr, Value mask) {
VectorType fillType = xferOp.getVectorType();
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());

Type vecTy = typeConverter.convertType(xferOp.getVectorType());
if (!vecTy)
return failure();

auto adaptor = TransferReadOpAdaptor(operands, xferOp->getAttrDictionary());
Value fill = rewriter.create<SplatOp>(loc, vecTy, adaptor.padding());

unsigned align;
if (failed(getMemRefAlignment(
typeConverter, xferOp.getShapedType().cast<MemRefType>(), align)))
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ add_mlir_dialect_library(MLIRVector
MLIRMemRef
MLIRSCF
MLIRLoopAnalysis
MLIRDataLayoutInterfaces
MLIRSideEffectInterfaces
MLIRVectorInterfaces
)
19 changes: 12 additions & 7 deletions mlir/lib/Dialect/Vector/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2202,12 +2202,15 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
return op->emitOpError(
"requires source to be a memref or ranked tensor type");
auto elementType = shapedType.getElementType();
DataLayout dataLayout = DataLayout::closest(op);
if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
// Memref or tensor has vector element type.
unsigned sourceVecSize = vectorElementType.getElementTypeBitWidth() *
vectorElementType.getShape().back();
unsigned sourceVecSize =
dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
vectorElementType.getShape().back();
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
vectorType.getShape().back();
if (resultVecSize % sourceVecSize != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
Expand All @@ -2226,8 +2229,9 @@ static LogicalResult verifyTransferOp(Operation *op, ShapedType shapedType,
} else {
// Memref or tensor has scalar element type.
unsigned resultVecSize =
vectorType.getElementTypeBitWidth() * vectorType.getShape().back();
if (resultVecSize % elementType.getIntOrFloatBitWidth() != 0)
dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
vectorType.getShape().back();
if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
return op->emitOpError(
"requires the bitwidth of the minor 1-D vector to be an integral "
"multiple of the bitwidth of the source element type");
Expand Down Expand Up @@ -3233,9 +3237,10 @@ static LogicalResult verify(BitCastOp op) {
return op.emitOpError("dimension size mismatch at: ") << i;
}

if (sourceVectorType.getElementTypeBitWidth() *
DataLayout dataLayout = DataLayout::closest(op);
if (dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()) *
sourceVectorType.getShape().back() !=
resultVectorType.getElementTypeBitWidth() *
dataLayout.getTypeSizeInBits(resultVectorType.getElementType()) *
resultVectorType.getShape().back())
return op.emitOpError(
"source/result bitwidth of the minor 1-D vectors must be equal");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Vector/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1388,7 +1388,7 @@ class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
VectorType rhsType = op.getOperandTypeRHS().dyn_cast<VectorType>();
VectorType resType = op.getVectorType();
Type eltType = resType.getElementType();
bool isInt = eltType.isa<IntegerType>();
bool isInt = eltType.isa<IntegerType, IndexType>();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
vector::CombiningKind kind = op.kind();

Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
"expected attribute value to have element type");
if (eltType.isa<FloatType>())
intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
else if (eltType.isa<IntegerType>())
else if (eltType.isa<IntegerType, IndexType>())
intVal = values[i].cast<IntegerAttr>().getValue();
else
llvm_unreachable("unexpected element type");
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
return emitError() << "vector types must have at least one dimension";

if (!isValidElementType(elementType))
return emitError() << "vector elements must be int or float type";
return emitError() << "vector elements must be int/index/float type";

if (any_of(shape, [](int64_t i) { return i <= 0; }))
return emitError() << "vector types must have positive constant sizes";
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Interfaces/DataLayoutInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,19 @@ mlir::DataLayout::DataLayout(ModuleOp op)
#endif
}

mlir::DataLayout mlir::DataLayout::closest(Operation *op) {
// Search the closest parent either being a module operation or implementing
// the data layout interface.
while (op) {
if (auto module = dyn_cast<ModuleOp>(op))
return DataLayout(module);
if (auto iface = dyn_cast<DataLayoutOpInterface>(op))
return DataLayout(iface);
op = op->getParentOp();
}
return DataLayout();
}

void mlir::DataLayout::checkValid() const {
#ifndef NDEBUG
SmallVector<DataLayoutSpecInterface> specs;
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Parser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ VectorType Parser::parseVectorType() {
if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
return nullptr;
if (!VectorType::isValidElementType(elementType))
return emitError(typeLoc, "vector elements must be int or float type"),
return emitError(typeLoc, "vector elements must be int/index/float type"),
nullptr;

return VectorType::get(dimensions, elementType);
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,16 @@ func @fmaf(%arg0: f32, %arg1: vector<4xf32>) {
%1 = fmaf %arg1, %arg1, %arg1 : vector<4xf32>
std.return
}

// -----

// CHECK-LABEL: func @index_vector(
// CHECK-SAME: %[[ARG0:.*]]: vector<4xi64>
func @index_vector(%arg0: vector<4xindex>) {
// CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64>
%0 = constant dense<[0, 1, 2, 3]> : vector<4xindex>
// CHECK: %[[V:.*]] = llvm.add %[[ARG0]], %[[CST]] : vector<4xi64>
%1 = addi %arg0, %0 : vector<4xindex>
std.return
}

Loading

0 comments on commit b614ada

Please sign in to comment.