diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index abbee1419613..88087f8915ad 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -128,6 +128,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct", // for the struct size and alignment. mutable std::optional size{}, align{}; mutable std::optional padded{}; + mutable mlir::Type largestMember{}; void computeSizeAndAlignment(const ::mlir::DataLayout &dataLayout) const; public: void dropAst(); @@ -141,12 +142,15 @@ def CIR_StructType : CIR_Type<"Struct", "struct", case RecordKind::Class: return "class." + name; case RecordKind::Union: - return "union "+ name; + return "union." + name; case RecordKind::Struct: return "struct." + name; } } + /// Return the member with the largest bit-length. + mlir::Type getLargestMember(const ::mlir::DataLayout &dataLayout) const; + /// Return whether this is a class declaration. bool isClass() const { return getKind() == RecordKind::Class; } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 47691a7fc463..241f570b8719 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" @@ -94,6 +95,15 @@ void BoolType::print(mlir::AsmPrinter &printer) const {} // StructType Definitions //===----------------------------------------------------------------------===// +/// Return the largest member of in the type. +/// +/// Recurses into union members never returning a union as the largest member. +Type StructType::getLargestMember(const ::mlir::DataLayout &dataLayout) const { + if (!largestMember) + computeSizeAndAlignment(dataLayout); + return largestMember; +} + Type StructType::parse(mlir::AsmParser &parser) { const auto loc = parser.getCurrentLocation(); llvm::SmallVector members; @@ -278,7 +288,7 @@ void StructType::computeSizeAndAlignment( const ::mlir::DataLayout &dataLayout) const { assert(!isOpaque() && "Cannot get layout of opaque structs"); // Do not recompute. - if (size || align || padded) + if (size || align || padded || largestMember) return; // This is a similar algorithm to LLVM's StructLayout. @@ -287,11 +297,25 @@ void StructType::computeSizeAndAlignment( [[maybe_unused]] bool isPadded = false; unsigned numElements = getNumElements(); auto members = getMembers(); + unsigned largestMemberSize = 0; // Loop over each of the elements, placing them in memory. for (unsigned i = 0, e = numElements; i != e; ++i) { auto ty = members[i]; + // Found a nested union: recurse into it to fetch its largest member. + auto structMember = ty.dyn_cast(); + if (structMember && structMember.isUnion()) { + auto candidate = structMember.getLargestMember(dataLayout); + if (dataLayout.getTypeSize(candidate) > largestMemberSize) { + largestMember = candidate; + largestMemberSize = dataLayout.getTypeSize(largestMember); + } + } else if (dataLayout.getTypeSize(ty) > largestMemberSize) { + largestMember = ty; + largestMemberSize = dataLayout.getTypeSize(largestMember); + } + // This matches LLVM since it uses the ABI instead of preferred alignment. const llvm::Align tyAlign = llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty)); @@ -312,6 +336,14 @@ void StructType::computeSizeAndAlignment( structSize += dataLayout.getTypeSize(ty); } + // For unions, the size and aligment is that of the largest element. + if (isUnion()) { + size = largestMemberSize; + align = structAlignment.value(); + padded = false; + return; + } + // Add padding to the end of the struct so that it could be put in an array // and all array elements would be aligned correctly. if (!llvm::isAligned(structAlignment, structSize)) { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 011984399c83..ffa3f099d703 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -54,10 +54,12 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/Support/Casting.h" @@ -1306,8 +1308,8 @@ class CIRGlobalOpLowering sourceSymbol.getSymName()); llvm::SmallVector offset{0}; auto gepOp = rewriter.create( - loc, llvmType, sourceSymbol.getType(), - addressOfOp.getResult(), offset); + loc, llvmType, sourceSymbol.getType(), addressOfOp.getResult(), + offset); rewriter.create(loc, gepOp.getResult()); return mlir::success(); } else if (isa(init.value())) { @@ -1721,14 +1723,30 @@ class CIRGetMemberOpLowering matchAndRewrite(mlir::cir::GetMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto llResTy = getTypeConverter()->convertType(op.getType()); - // Since the base address is a pointer to structs, the first offset is - // always zero. The second offset tell us which member it will access. - llvm::SmallVector offset{0, op.getIndex()}; - const auto elementTy = getTypeConverter()->convertType( - op.getAddr().getType().getPointee()); - rewriter.replaceOpWithNewOp( - op, llResTy, elementTy, adaptor.getAddr(), offset); - return mlir::success(); + const auto structTy = + op.getAddrTy().getPointee().cast(); + assert(structTy && "expected struct type"); + + switch (structTy.getKind()) { + case mlir::cir::StructType::Struct: { + // Since the base address is a pointer to an aggregate, the first offset + // is always zero. The second offset tell us which member it will access. + llvm::SmallVector offset{0, op.getIndex()}; + const auto elementTy = getTypeConverter()->convertType(structTy); + rewriter.replaceOpWithNewOp(op, llResTy, elementTy, + adaptor.getAddr(), offset); + return mlir::success(); + } + case mlir::cir::StructType::Union: + // Union members share the address space, so we just need a bitcast to + // conform to type-checking. + rewriter.replaceOpWithNewOp(op, llResTy, + adaptor.getAddr()); + return mlir::success(); + default: + return op.emitError() + << "struct kind '" << structTy.getKind() << "' is NYI"; + } } }; @@ -1789,7 +1807,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, } namespace { -void prepareTypeConverter(mlir::LLVMTypeConverter &converter) { +void prepareTypeConverter(mlir::LLVMTypeConverter &converter, + mlir::DataLayout &dataLayout) { converter.addConversion([&](mlir::cir::PointerType type) -> mlir::Type { return mlir::LLVM::LLVMPointerType::get(&converter.getContext()); }); @@ -1814,9 +1833,24 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter) { return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg); }); converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type { + // FIXME(cir): create separate unions, struct, and classes types. + // Convert struct members. llvm::SmallVector llvmMembers; - for (auto ty : type.getMembers()) - llvmMembers.push_back(converter.convertType(ty)); + switch (type.getKind()) { + case mlir::cir::StructType::Class: + // TODO(cir): This should be properly validated. + case mlir::cir::StructType::Struct: + for (auto ty : type.getMembers()) + llvmMembers.push_back(converter.convertType(ty)); + break; + // Unions are lowered as only the largest member. + case mlir::cir::StructType::Union: { + auto largestMember = type.getLargestMember(dataLayout); + if (largestMember) + llvmMembers.push_back(converter.convertType(largestMember)); + break; + } + } // Struct has a name: lower as an identified struct. mlir::LLVM::LLVMStructType llvmStruct; @@ -1847,7 +1881,7 @@ static void buildCtorList(mlir::ModuleOp module) { assert(attr.isa() && "must be a GlobalCtorAttr"); if (auto ctorAttr = attr.cast()) { - // default priority is 65536 + // default priority is 65536 int priority = 65536; if (ctorAttr.getPriority()) priority = *ctorAttr.getPriority(); @@ -1885,15 +1919,15 @@ static void buildCtorList(mlir::ModuleOp module) { newGlobalOp.getRegion().push_back(new mlir::Block()); builder.setInsertionPointToEnd(newGlobalOp.getInitializerBlock()); - mlir::Value result = builder.create( - loc, CtorStructArrayTy); + mlir::Value result = + builder.create(loc, CtorStructArrayTy); for (uint64_t I = 0; I < globalCtors.size(); I++) { auto fn = globalCtors[I]; mlir::Value structInit = builder.create(loc, CtorStructTy); - mlir::Value initPriority = - builder.create(loc, CtorStructFields[0], fn.second); + mlir::Value initPriority = builder.create( + loc, CtorStructFields[0], fn.second); mlir::Value initFuncAddr = builder.create( loc, CtorStructFields[1], fn.first); mlir::Value initAssociate = @@ -1914,9 +1948,9 @@ static void buildCtorList(mlir::ModuleOp module) { void ConvertCIRToLLVMPass::runOnOperation() { auto module = getOperation(); - + mlir::DataLayout dataLayout(module); mlir::LLVMTypeConverter converter(&getContext()); - prepareTypeConverter(converter); + prepareTypeConverter(converter, dataLayout); mlir::RewritePatternSet patterns(&getContext()); diff --git a/clang/test/CIR/Lowering/unions.cir b/clang/test/CIR/Lowering/unions.cir new file mode 100644 index 000000000000..c5ee736c4a7d --- /dev/null +++ b/clang/test/CIR/Lowering/unions.cir @@ -0,0 +1,42 @@ +// RUN: cir-opt %s -cir-to-llvm -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s16i = !cir.int +!s32i = !cir.int +#true = #cir.bool : !cir.bool +!ty_22U122 = !cir.struct +!ty_22U222 = !cir.struct +!ty_22U322 = !cir.struct +module { + // Should lower union to struct with only the largest member. + cir.global external @u1 = #cir.zero : !ty_22U122 + // CHECK: llvm.mlir.global external @u1() {addr_space = 0 : i32} : !llvm.struct<"union.U1", (i32)> + + // Should recursively find the largest member if there are nested unions. + cir.global external @u2 = #cir.zero : !ty_22U222 + cir.global external @u3 = #cir.zero : !ty_22U322 + // CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)> + // CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)> + + // CHECK: llvm.func @test + cir.func @test(%arg0: !cir.ptr) { + + // Should store directly to the union's base address. + %5 = cir.const(#true) : !cir.bool + %6 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr -> !cir.ptr + cir.store %5, %6 : !cir.bool, cir.ptr + // CHECK: %[[#VAL:]] = llvm.mlir.constant(1 : i8) : i8 + // The bitcast it just to bypass the type checker. It will be replaced by an opaque pointer. + // CHECK: %[[#ADDR:]] = llvm.bitcast %{{.+}} : !llvm.ptr to !llvm.ptr + // CHECK: llvm.store %[[#VAL]], %[[#ADDR]] : i8, !llvm.ptr + + // Should load direclty from the union's base address. + %7 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr -> !cir.ptr + %8 = cir.load %7 : cir.ptr , !cir.bool + // The bitcast it just to bypass the type checker. It will be replaced by an opaque pointer. + // CHECK: %[[#BASE:]] = llvm.bitcast %{{.+}} : !llvm.ptr to !llvm.ptr + // CHECK: %{{.+}} = llvm.load %[[#BASE]] : !llvm.ptr + + cir.return + } +}