diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index ac773ba3d4d0..a1d89e01f449 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -73,13 +73,18 @@ def CastOp : CIR_Op<"cast", [Pure]> { let description = [{ Apply C/C++ usual conversions rules between values. Currently supported kinds: - - `int_to_bool` - - `ptr_to_bool` - `array_to_ptrdecay` - - `integral` - `bitcast` + - `integral` + - `int_to_bool` + - `int_to_float` - `floating` - `float_to_int` + - `float_to_bool` + - `ptr_to_int` + - `ptr_to_bool` + - `bool_to_int` + - `bool_to_float` This is effectively a subset of the rules from `llvm-project/clang/include/clang/AST/OperationKinds.def`; but note that some @@ -1648,6 +1653,54 @@ def GetMemberOp : CIR_Op<"get_member"> { let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// VecExtractOp +//===----------------------------------------------------------------------===// + +def VecExtractOp : CIR_Op<"vec.extract", [Pure, + TypesMatchWith<"type of 'result' matches element type of 'vec'", + "vec", "result", + "$_self.cast().getEltType()">]> { + + let summary = "Extract one element from a vector object"; + let description = [{ + The `cir.vec.extract` operation extracts the element at the given index + from a vector object. + }]; + + let arguments = (ins CIR_VectorType:$vec, CIR_IntType:$index); + let results = (outs AnyType:$result); + + let assemblyFormat = [{ + $vec `[` $index `:` type($index) `]` type($vec) `->` type($result) attr-dict + }]; + + let hasVerifier = 0; +} + +//===----------------------------------------------------------------------===// +// VecCreate +//===----------------------------------------------------------------------===// + +def VecCreateOp : CIR_Op<"vec.create", [Pure]> { + + let summary = "Create a vector value"; + let description = [{ + The `cir.vec.create` operation creates a vector value with the given element + values. The number of element arguments must match the number of elements + in the vector type. + }]; + + let arguments = (ins Variadic:$elements); + let results = (outs CIR_VectorType:$result); + + let assemblyFormat = [{ + `(` ($elements^ `:` type($elements))? `)` `:` type($result) attr-dict + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // BaseClassAddr //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index e9c60e763ba8..0d568c2d504c 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -149,6 +149,26 @@ def CIR_ArrayType : CIR_Type<"Array", "array", }]; } +//===----------------------------------------------------------------------===// +// VectorType (fixed size) +//===----------------------------------------------------------------------===// + +def CIR_VectorType : CIR_Type<"Vector", "vector", + [DeclareTypeInterfaceMethods]> { + + let summary = "CIR vector type"; + let description = [{ + `cir.vector' represents fixed-size vector types. The parameters are the + element type and the number of elements. + }]; + + let parameters = (ins "mlir::Type":$eltType, "uint64_t":$size); + + let assemblyFormat = [{ + `<` $eltType `x` $size `>` + }]; +} + //===----------------------------------------------------------------------===// // FuncType //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp index 739790b3d150..1fc2e923b2b1 100644 --- a/clang/lib/CIR/CodeGen/CIRGenDecl.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenDecl.cpp @@ -180,7 +180,7 @@ static void emitStoresForConstant(CIRGenModule &CGM, const VarDecl &D, if (!ConstantSize) return; assert(!UnimplementedFeature::addAutoInitAnnotation()); - assert(!UnimplementedFeature::cirVectorType()); + assert(!UnimplementedFeature::vectorConstants()); assert(!UnimplementedFeature::shouldUseBZeroPlusStoresToInitialize()); assert(!UnimplementedFeature::shouldUseMemSetToInitialize()); assert(!UnimplementedFeature::shouldSplitConstantStore()); @@ -1004,4 +1004,4 @@ void CIRGenFunction::pushEHDestroy(QualType::DestructionKind dtorKind, assert(needsEHCleanup(dtorKind)); pushDestroy(EHCleanup, addr, type, getDestroyer(dtorKind), true); -} \ No newline at end of file +} diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index ef76c91f0b69..5c130a6889ef 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -545,11 +545,9 @@ void CIRGenFunction::buildStoreOfScalar(mlir::Value Value, Address Addr, bool Volatile, QualType Ty, LValueBaseInfo BaseInfo, bool isInit, bool isNontemporal) { - if (!CGM.getCodeGenOpts().PreserveVec3Type) { - if (Ty->isVectorType()) { - llvm_unreachable("NYI"); - } - } + if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() && + Ty->castAs()->getNumElements() == 3) + llvm_unreachable("NYI: Special treatment of 3-element vectors"); Value = buildToMemory(Value, Ty); @@ -2358,11 +2356,9 @@ mlir::Value CIRGenFunction::buildLoadOfScalar(Address Addr, bool Volatile, QualType Ty, mlir::Location Loc, LValueBaseInfo BaseInfo, bool isNontemporal) { - if (!CGM.getCodeGenOpts().PreserveVec3Type) { - if (Ty->isVectorType()) { - llvm_unreachable("NYI"); - } - } + if (!CGM.getCodeGenOpts().PreserveVec3Type && Ty->isVectorType() && + Ty->castAs()->getNumElements() == 3) + llvm_unreachable("NYI: Special treatment of 3-element vectors"); // Atomic operations have to be done on integral types LValue AtomicLValue = LValue::makeAddr(Addr, Ty, getContext(), BaseInfo); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp index 3e18e033a641..6103570bb34d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp @@ -246,13 +246,19 @@ class ScalarExprEmitter : public StmtVisitor { } mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *E) { // Do we need anything like TestAndClearIgnoreResultAssign()? - assert(!E->getBase()->getType()->isVectorType() && - "vector types not implemented"); - // Emit subscript expressions in rvalue context's. For most cases, this - // just loads the lvalue formed by the subscript expr. However, we have to - // be careful, because the base of a vector subscript is occasionally an - // rvalue, so we can't get it as an lvalue. + if (E->getBase()->getType()->isVectorType()) { + assert(!UnimplementedFeature::scalableVectors() && + "NYI: index into scalable vector"); + // Subscript of vector type. This is handled differently, with a custom + // operation. + mlir::Value VecValue = Visit(E->getBase()); + mlir::Value IndexValue = Visit(E->getIdx()); + return CGF.builder.create( + CGF.getLoc(E->getSourceRange()), VecValue, IndexValue); + } + + // Just load the lvalue formed by the subscript expression. return buildLoadOfLValue(E); } @@ -919,6 +925,7 @@ class ScalarExprEmitter : public StmtVisitor { "Internal error: conversion between matrix type and scalar type"); // TODO(CIR): Support VectorTypes + assert(!UnimplementedFeature::cirVectorType() && "NYI: vector cast"); // Finally, we have the arithmetic types: real int/float. mlir::Value Res = nullptr; @@ -1579,8 +1586,18 @@ mlir::Value ScalarExprEmitter::VisitInitListExpr(InitListExpr *E) { if (E->hadArrayRangeDesignator()) llvm_unreachable("NYI"); - if (UnimplementedFeature::cirVectorType()) - llvm_unreachable("NYI"); + if (E->getType()->isVectorType()) { + assert(!UnimplementedFeature::scalableVectors() && + "NYI: scalable vector init"); + assert(!UnimplementedFeature::vectorConstants() && "NYI: vector constants"); + SmallVector Elements; + for (Expr *init : E->inits()) { + Elements.push_back(Visit(init)); + } + return CGF.getBuilder().create( + CGF.getLoc(E->getSourceRange()), CGF.getCIRType(E->getType()), + Elements); + } if (NumInitElements == 0) { // C++11 value-initialization for the scalar. diff --git a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp index d71f4fe5d59f..07535d459d34 100644 --- a/clang/lib/CIR/CodeGen/CIRGenTypes.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenTypes.cpp @@ -647,7 +647,10 @@ mlir::Type CIRGenTypes::ConvertType(QualType T) { } case Type::ExtVector: case Type::Vector: { - assert(0 && "not implemented"); + const VectorType *V = cast(Ty); + auto ElementType = convertTypeForMem(V->getElementType()); + ResultType = ::mlir::cir::VectorType::get(Builder.getContext(), ElementType, + V->getNumElements()); break; } case Type::ConstantMatrix: { diff --git a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h index 12f2b2037d61..ee3d643dd136 100644 --- a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h +++ b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h @@ -22,9 +22,13 @@ struct UnimplementedFeature { static bool buildTypeCheck() { return false; } static bool tbaa() { return false; } static bool cleanups() { return false; } - // This is for whether or not we've implemented a cir::VectorType - // corresponding to `llvm::VectorType` + + // cir::VectorType is in progress, so cirVectorType() will go away soon. + // Start adding feature flags for more advanced vector types and operations + // that will take longer to implement. static bool cirVectorType() { return false; } + static bool scalableVectors() { return false; } + static bool vectorConstants() { return false; } // Address space related static bool addressSpace() { return false; } diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 87bc36f33314..af6b0b85f3f5 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -422,6 +422,31 @@ LogicalResult CastOp::verify() { llvm_unreachable("Unknown CastOp kind?"); } +//===----------------------------------------------------------------------===// +// VecCreateOp +//===----------------------------------------------------------------------===// + +LogicalResult VecCreateOp::verify() { + // Verify that the number of arguments matches the number of elements in the + // vector, and that the type of all the arguments matches the type of the + // elements in the vector. + auto VecTy = getResult().getType(); + if (getElements().size() != VecTy.getSize()) { + return emitOpError() << "operand count of " << getElements().size() + << " doesn't match vector type " << VecTy + << " element count of " << VecTy.getSize(); + } + auto ElementType = VecTy.getEltType(); + for (auto Element : getElements()) { + if (Element.getType() != ElementType) { + return emitOpError() << "operand type " << Element.getType() + << " doesn't match vector element type " + << ElementType; + } + } + return success(); +} + //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index da3a7bbb5576..2eea669c77a5 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -411,6 +411,25 @@ ArrayType::getPreferredAlignment(const ::mlir::DataLayout &dataLayout, return dataLayout.getTypePreferredAlignment(getEltType()); } +llvm::TypeSize cir::VectorType::getTypeSizeInBits( + const ::mlir::DataLayout &dataLayout, + ::mlir::DataLayoutEntryListRef params) const { + return llvm::TypeSize::getFixed(getSize() * + dataLayout.getTypeSizeInBits(getEltType())); +} + +uint64_t +cir::VectorType::getABIAlignment(const ::mlir::DataLayout &dataLayout, + ::mlir::DataLayoutEntryListRef params) const { + return getSize() * dataLayout.getTypeABIAlignment(getEltType()); +} + +uint64_t cir::VectorType::getPreferredAlignment( + const ::mlir::DataLayout &dataLayout, + ::mlir::DataLayoutEntryListRef params) const { + return getSize() * dataLayout.getTypePreferredAlignment(getEltType()); +} + llvm::TypeSize StructType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout, ::mlir::DataLayoutEntryListRef params) const { @@ -605,9 +624,9 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const { return get(llvm::to_vector(inputs), results[0], isVarArg()); } -mlir::ParseResult -parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, - bool &isVarArg) { +mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p, + llvm::SmallVector ¶ms, + bool &isVarArg) { isVarArg = false; // `(` `)` if (succeeded(p.parseOptionalRParen())) @@ -637,9 +656,8 @@ parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector ¶ms, return p.parseRParen(); } -void printFuncTypeArgs(mlir::AsmPrinter &p, - mlir::ArrayRef params, - bool isVarArg) { +void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef params, + bool isVarArg) { llvm::interleaveComma(params, p, [&p](mlir::Type type) { p.printType(type); }); if (isVarArg) { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 182fb65d78e7..428f8f2211db 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -232,8 +232,7 @@ mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, } else if (auto fun = dyn_cast(sourceSymbol)) { sourceType = converter->convertType(fun.getFunctionType()); symName = fun.getSymName(); - } - else { + } else { llvm_unreachable("Unexpected GlobalOp type"); } @@ -1111,6 +1110,48 @@ class CIRConstantLowering } }; +class CIRVectorCreateLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::VecCreateOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + // Start with an 'undef' value for the vector. Then 'insertelement' for + // each of the vector elements. + auto vecTy = op.getType().dyn_cast(); + assert(vecTy && "result type of cir.vec op is not VectorType"); + auto llvmTy = typeConverter->convertType(vecTy); + auto loc = op.getLoc(); + mlir::Value result = rewriter.create(loc, llvmTy); + assert(vecTy.getSize() == op.getElements().size() && + "cir.vec operands count doesn't match vector type elements count"); + for (uint64_t i = 0; i < vecTy.getSize(); ++i) { + mlir::Value indexValue = rewriter.create( + loc, rewriter.getI64Type(), i); + result = rewriter.create( + loc, result, adaptor.getElements()[i], indexValue); + } + rewriter.replaceOp(op, result); + return mlir::success(); + } +}; + +class CIRVectorExtractLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::VecExtractOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, adaptor.getVec(), adaptor.getIndex()); + return mlir::success(); + } +}; + class CIRVAStartLowering : public mlir::OpConversionPattern { public: @@ -1615,13 +1656,17 @@ class CIRBinOpLowering : public mlir::OpConversionPattern { assert((op.getLhs().getType() == op.getRhs().getType()) && "inconsistent operands' types not supported yet"); mlir::Type type = op.getRhs().getType(); - assert((type.isa()) && + assert((type.isa()) && "operand type not supported yet"); auto llvmTy = getTypeConverter()->convertType(op.getType()); auto rhs = adaptor.getRhs(); auto lhs = adaptor.getLhs(); + if (type.isa()) + type = type.dyn_cast().getEltType(); + switch (op.getKind()) { case mlir::cir::BinOpKind::Add: if (type.isa()) @@ -2001,7 +2046,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering, - CIRFAbsOpLowering, CIRVTableAddrPointOpLowering>( + CIRFAbsOpLowering, CIRVTableAddrPointOpLowering, + CIRVectorCreateLowering, CIRVectorExtractLowering>( converter, patterns.getContext()); } @@ -2016,6 +2062,10 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter, auto ty = converter.convertType(type.getEltType()); return mlir::LLVM::LLVMArrayType::get(ty, type.getSize()); }); + converter.addConversion([&](mlir::cir::VectorType type) -> mlir::Type { + auto ty = converter.convertType(type.getEltType()); + return mlir::LLVM::getFixedVectorType(ty, type.getSize()); + }); converter.addConversion([&](mlir::cir::BoolType type) -> mlir::Type { return mlir::IntegerType::get(type.getContext(), 8, mlir::IntegerType::Signless); diff --git a/clang/test/CIR/CodeGen/vectype.cpp b/clang/test/CIR/CodeGen/vectype.cpp new file mode 100644 index 000000000000..aa4e481d1dbc --- /dev/null +++ b/clang/test/CIR/CodeGen/vectype.cpp @@ -0,0 +1,40 @@ +// RUN: %clang_cc1 -std=c++17 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o - | FileCheck %s + +typedef int int4 __attribute__((vector_size(16))); +int test_vector_basic(int x, int y, int z) { + int4 a = { 1, 2, 3, 4 }; + int4 b = { x, y, z, x + y + z }; + int4 c = a + b; + return c[1]; +} + +// CHECK: %4 = cir.alloca !cir.vector, cir.ptr >, ["a", init] {alignment = 16 : i64} +// CHECK: %5 = cir.alloca !cir.vector, cir.ptr >, ["b", init] {alignment = 16 : i64} +// CHECK: %6 = cir.alloca !cir.vector, cir.ptr >, ["c", init] {alignment = 16 : i64} + +// CHECK: %7 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK: %8 = cir.const(#cir.int<2> : !s32i) : !s32i +// CHECK: %9 = cir.const(#cir.int<3> : !s32i) : !s32i +// CHECK: %10 = cir.const(#cir.int<4> : !s32i) : !s32i +// CHECK: %11 = cir.vec.create(%7, %8, %9, %10 : !s32i, !s32i, !s32i, !s32i) : +// CHECK: cir.store %11, %4 : !cir.vector, cir.ptr > +// CHECK: %12 = cir.load %0 : cir.ptr , !s32i +// CHECK: %13 = cir.load %1 : cir.ptr , !s32i +// CHECK: %14 = cir.load %2 : cir.ptr , !s32i +// CHECK: %15 = cir.load %0 : cir.ptr , !s32i +// CHECK: %16 = cir.load %1 : cir.ptr , !s32i +// CHECK: %17 = cir.binop(add, %15, %16) : !s32i +// CHECK: %18 = cir.load %2 : cir.ptr , !s32i +// CHECK: %19 = cir.binop(add, %17, %18) : !s32i +// CHECK: %20 = cir.vec.create(%12, %13, %14, %19 : !s32i, !s32i, !s32i, !s32i) : +// CHECK: cir.store %20, %5 : !cir.vector, cir.ptr > +// CHECK: %21 = cir.load %4 : cir.ptr >, !cir.vector +// CHECK: %22 = cir.load %5 : cir.ptr >, !cir.vector +// CHECK: %23 = cir.binop(add, %21, %22) : !cir.vector +// CHECK: cir.store %23, %6 : !cir.vector, cir.ptr > +// CHECK: %24 = cir.load %6 : cir.ptr >, !cir.vector +// CHECK: %25 = cir.const(#cir.int<1> : !s32i) : !s32i +// CHECK: %26 = cir.vec.extract %24[%25 : !s32i] -> !s32i +// CHECK: cir.store %26, %3 : !s32i, cir.ptr +// CHECK: %27 = cir.load %3 : cir.ptr , !s32i +// CHECK: cir.return %27 : !s32i diff --git a/clang/test/CIR/IR/invalid.cir b/clang/test/CIR/IR/invalid.cir index 5571dd030f25..d122be4d0a34 100644 --- a/clang/test/CIR/IR/invalid.cir +++ b/clang/test/CIR/IR/invalid.cir @@ -419,6 +419,35 @@ module { // ----- +!s32i = !cir.int +cir.func @vec_op_size() { + %0 = cir.const(#cir.int<1> : !s32i) : !s32i + %1 = cir.vec.create(%0 : !s32i) : // expected-error {{'cir.vec.create' op operand count of 1 doesn't match vector type '!cir.vector x 2>' element count of 2}} +} + +// ----- + +!s32i = !cir.int +!u32i = !cir.int +cir.func @vec_op_type() { + %0 = cir.const(#cir.int<1> : !s32i) : !s32i + %1 = cir.const(#cir.int<2> : !u32i) : !u32i + %2 = cir.vec.create(%0, %1 : !s32i, !u32i) : // expected-error {{'cir.vec.create' op operand type '!cir.int' doesn't match vector element type '!cir.int'}} +} + +// ----- + +!s32i = !cir.int +!u32i = !cir.int +cir.func @vec_extract_type() { + %0 = cir.const(#cir.int<1> : !s32i) : !s32i + %1 = cir.const(#cir.int<2> : !s32i) : !s32i + %2 = cir.vec.create(%0, %1 : !s32i, !s32i) : + %3 = cir.vec.extract %2[%0 : !s32i] -> !u32i // expected-error {{'cir.vec.extract' op failed to verify that type of 'result' matches element type of 'vec'}} +} + +// ----- + cir.func coroutine @bad_task() { // expected-error {{coroutine body must use at least one cir.await op}} cir.return }