Skip to content

Commit

Permalink
[CIR] Vector types - part 1 (#347)
Browse files Browse the repository at this point in the history
This is the first part of implementing vector types and vector
operations in ClangIR, issue #284. This is enough to compile this test
program. I haven't tried to do anything beyond that yet.
```
typedef int int4 __attribute__((vector_size(16)));
int main(int argc, char** argv) {
  int4 a = { 1, argc, argc + 1, 4 };
  int4 b = { 5, argc + 2, argc + 3, 8 };
  int4 c = a + b;
  return c[1];
}
```

This change includes:

* Fixed-sized vector types which are parameterized on the element type
and the number of elements. For example, `!cir.vector<s32i x 4>`. (No
scalable vector types yet; those will come later.)

* New operation `cir.vec` which creates an object of a vector type with
the given operands.

* New operation `cir.vec_elem` which extracts an element from a vector.
(The array subscript operation doesn't work here because the result is
an rvalue, not an lvalue.)

* Basic binary arithmetic operations on vector types, though only
addition has been tested.

There are no unary operators, comparison operators, casts, or shuffle
operations yet. Those will all come later.
  • Loading branch information
dkolsen-pgi authored and lanza committed Jan 29, 2024
1 parent b6e2936 commit b908ad7
Show file tree
Hide file tree
Showing 12 changed files with 291 additions and 36 deletions.
59 changes: 56 additions & 3 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,18 @@ def CastOp : CIR_Op<"cast", [Pure]> {
let description = [{
Apply C/C++ usual conversions rules between values. Currently supported kinds:

- `int_to_bool`
- `ptr_to_bool`
- `array_to_ptrdecay`
- `integral`
- `bitcast`
- `integral`
- `int_to_bool`
- `int_to_float`
- `floating`
- `float_to_int`
- `float_to_bool`
- `ptr_to_int`
- `ptr_to_bool`
- `bool_to_int`
- `bool_to_float`

This is effectively a subset of the rules from
`llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some
Expand Down Expand Up @@ -1648,6 +1653,54 @@ def GetMemberOp : CIR_Op<"get_member"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecExtractOp
//===----------------------------------------------------------------------===//

def VecExtractOp : CIR_Op<"vec.extract", [Pure,
TypesMatchWith<"type of 'result' matches element type of 'vec'",
"vec", "result",
"$_self.cast<VectorType>().getEltType()">]> {

let summary = "Extract one element from a vector object";
let description = [{
The `cir.vec.extract` operation extracts the element at the given index
from a vector object.
}];

let arguments = (ins CIR_VectorType:$vec, CIR_IntType:$index);
let results = (outs AnyType:$result);

let assemblyFormat = [{
$vec `[` $index `:` type($index) `]` type($vec) `->` type($result) attr-dict
}];

let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// VecCreate
//===----------------------------------------------------------------------===//

def VecCreateOp : CIR_Op<"vec.create", [Pure]> {

let summary = "Create a vector value";
let description = [{
The `cir.vec.create` operation creates a vector value with the given element
values. The number of element arguments must match the number of elements
in the vector type.
}];

let arguments = (ins Variadic<AnyType>:$elements);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
`(` ($elements^ `:` type($elements))? `)` `:` type($result) attr-dict
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// BaseClassAddr
//===----------------------------------------------------------------------===//
Expand Down
20 changes: 20 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,26 @@ def CIR_ArrayType : CIR_Type<"Array", "array",
}];
}

//===----------------------------------------------------------------------===//
// VectorType (fixed size)
//===----------------------------------------------------------------------===//

def CIR_VectorType : CIR_Type<"Vector", "vector",
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {

let summary = "CIR vector type";
let description = [{
`cir.vector' represents fixed-size vector types. The parameters are the
element type and the number of elements.
}];

let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size);

let assemblyFormat = [{
`<` $eltType `x` $size `>`
}];
}

//===----------------------------------------------------------------------===//
// FuncType
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ static void emitStoresForConstant(CIRGenModule &CGM, const VarDecl &D,
if (!ConstantSize)
return;
assert(!UnimplementedFeature::addAutoInitAnnotation());
assert(!UnimplementedFeature::cirVectorType());
assert(!UnimplementedFeature::vectorConstants());
assert(!UnimplementedFeature::shouldUseBZeroPlusStoresToInitialize());
assert(!UnimplementedFeature::shouldUseMemSetToInitialize());
assert(!UnimplementedFeature::shouldSplitConstantStore());
Expand Down Expand Up @@ -1004,4 +1004,4 @@ void CIRGenFunction::pushEHDestroy(QualType::DestructionKind dtorKind,
assert(needsEHCleanup(dtorKind));

pushDestroy(EHCleanup, addr, type, getDestroyer(dtorKind), true);
}
}
16 changes: 6 additions & 10 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,9 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value Value, Address Addr,
bool Volatile, QualType Ty,
LValueBaseInfo BaseInfo, bool isInit,
bool isNontemporal) {
if (!CGM.getCodeGenOpts().PreserveVec3Type) {
if (Ty->isVectorType()) {
llvm_unreachable("NYI");
}
}
if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() &&
Ty->castAs<clang::VectorType>()->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vectors");

Value = buildToMemory(Value, Ty);

Expand Down Expand Up @@ -2358,11 +2356,9 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile,
QualType Ty, mlir::Location Loc,
LValueBaseInfo BaseInfo,
bool isNontemporal) {
if (!CGM.getCodeGenOpts().PreserveVec3Type) {
if (Ty->isVectorType()) {
llvm_unreachable("NYI");
}
}
if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() &&
Ty->castAs<clang::VectorType>()->getNumElements() == 3)
llvm_unreachable("NYI: Special treatment of 3-element vectors");

// Atomic operations have to be done on integral types
LValue AtomicLValue = LValue::makeAddr(Addr, Ty, getContext(), BaseInfo);
Expand Down
33 changes: 25 additions & 8 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,13 +246,19 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
}
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
// Do we need anything like TestAndClearIgnoreResultAssign()?
assert(!E->getBase()->getType()->isVectorType() &&
"vector types not implemented");

// Emit subscript expressions in rvalue context's. For most cases, this
// just loads the lvalue formed by the subscript expr. However, we have to
// be careful, because the base of a vector subscript is occasionally an
// rvalue, so we can't get it as an lvalue.
if (E->getBase()->getType()->isVectorType()) {
assert(!UnimplementedFeature::scalableVectors() &&
"NYI: index into scalable vector");
// Subscript of vector type. This is handled differently, with a custom
// operation.
mlir::Value VecValue = Visit(E->getBase());
mlir::Value IndexValue = Visit(E->getIdx());
return CGF.builder.create<mlir::cir::VecExtractOp>(
CGF.getLoc(E->getSourceRange()), VecValue, IndexValue);
}

// Just load the lvalue formed by the subscript expression.
return buildLoadOfLValue(E);
}

Expand Down Expand Up @@ -919,6 +925,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
"Internal error: conversion between matrix type and scalar type");

// TODO(CIR): Support VectorTypes
assert(!UnimplementedFeature::cirVectorType() && "NYI: vector cast");

// Finally, we have the arithmetic types: real int/float.
mlir::Value Res = nullptr;
Expand Down Expand Up @@ -1579,8 +1586,18 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) {
if (E->hadArrayRangeDesignator())
llvm_unreachable("NYI");

if (UnimplementedFeature::cirVectorType())
llvm_unreachable("NYI");
if (E->getType()->isVectorType()) {
assert(!UnimplementedFeature::scalableVectors() &&
"NYI: scalable vector init");
assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants");
SmallVector<mlir::Value, 16> Elements;
for (Expr *init : E->inits()) {
Elements.push_back(Visit(init));
}
return CGF.getBuilder().create<mlir::cir::VecCreateOp>(
CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()),
Elements);
}

if (NumInitElements == 0) {
// C++11 value-initialization for the scalar.
Expand Down
5 changes: 4 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,10 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) {
}
case Type::ExtVector:
case Type::Vector: {
assert(0 && "not implemented");
const VectorType *V = cast<VectorType>(Ty);
auto ElementType = convertTypeForMem(V->getElementType());
ResultType = ::mlir::cir::VectorType::get(Builder.getContext(), ElementType,
V->getNumElements());
break;
}
case Type::ConstantMatrix: {
Expand Down
8 changes: 6 additions & 2 deletions clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ struct UnimplementedFeature {
static bool buildTypeCheck() { return false; }
static bool tbaa() { return false; }
static bool cleanups() { return false; }
// This is for whether or not we've implemented a cir::VectorType
// corresponding to `llvm::VectorType`

// cir::VectorType is in progress, so cirVectorType() will go away soon.
// Start adding feature flags for more advanced vector types and operations
// that will take longer to implement.
static bool cirVectorType() { return false; }
static bool scalableVectors() { return false; }
static bool vectorConstants() { return false; }

// Address space related
static bool addressSpace() { return false; }
Expand Down
25 changes: 25 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,31 @@ LogicalResult CastOp::verify() {
llvm_unreachable("Unknown CastOp kind?");
}

//===----------------------------------------------------------------------===//
// VecCreateOp
//===----------------------------------------------------------------------===//

LogicalResult VecCreateOp::verify() {
// Verify that the number of arguments matches the number of elements in the
// vector, and that the type of all the arguments matches the type of the
// elements in the vector.
auto VecTy = getResult().getType();
if (getElements().size() != VecTy.getSize()) {
return emitOpError() << "operand count of " << getElements().size()
<< " doesn't match vector type " << VecTy
<< " element count of " << VecTy.getSize();
}
auto ElementType = VecTy.getEltType();
for (auto Element : getElements()) {
if (Element.getType() != ElementType) {
return emitOpError() << "operand type " << Element.getType()
<< " doesn't match vector element type "
<< ElementType;
}
}
return success();
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 24 additions & 6 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,25 @@ ArrayType::getPreferredAlignment(const ::mlir::DataLayout &dataLayout,
return dataLayout.getTypePreferredAlignment(getEltType());
}

llvm::TypeSize cir::VectorType::getTypeSizeInBits(
const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
return llvm::TypeSize::getFixed(getSize() *
dataLayout.getTypeSizeInBits(getEltType()));
}

uint64_t
cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
return getSize() * dataLayout.getTypeABIAlignment(getEltType());
}

uint64_t cir::VectorType::getPreferredAlignment(
const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
return getSize() * dataLayout.getTypePreferredAlignment(getEltType());
}

llvm::TypeSize
StructType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
::mlir::DataLayoutEntryListRef params) const {
Expand Down Expand Up @@ -605,9 +624,9 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
Expand Down Expand Up @@ -637,9 +656,8 @@ parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand Down
Loading

0 comments on commit b908ad7

Please sign in to comment.