From 57fa182f422fa6630abb758a037dfcc3200c8b15 Mon Sep 17 00:00:00 2001 From: Vinicius Couto Espindola Date: Tue, 22 Aug 2023 16:54:14 -0300 Subject: [PATCH] [CIR][Lowering] Lower unions Converts a union to a struct containing only its largest element. GetMemberOp for unions is lowered as bitcasts instead of GEPs, since union members share the same address space. ghstack-source-id: 791a944c5df3103c5671ba061fe4c7c3a56317fa Pull Request resolved: https://github.com/llvm/clangir/pull/230 --- .../include/clang/CIR/Dialect/IR/CIRTypes.td | 5 +- clang/lib/CIR/Dialect/IR/CIRTypes.cpp | 36 ++++++++++++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 57 +++++++++++++++---- clang/test/CIR/Lowering/unions.cir | 42 ++++++++++++++ 4 files changed, 128 insertions(+), 12 deletions(-) create mode 100644 clang/test/CIR/Lowering/unions.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td index abbee1419613..7c4a15f2430e 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypes.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypes.td @@ -141,12 +141,15 @@ def CIR_StructType : CIR_Type<"Struct", "struct", case RecordKind::Class: return "class." + name; case RecordKind::Union: - return "union "+ name; + return "union." + name; case RecordKind::Struct: return "struct." + name; } } + /// Return the member with the largest bit-length. + mlir::Type getLargestMember(const ::mlir::DataLayout &dataLayout) const; + /// Return whether this is a class declaration. 
bool isClass() const { return getKind() == RecordKind::Class; } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index 4acb3bc38de2..55fa7e46e0f4 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -18,6 +18,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Support/LogicalResult.h" #include "llvm/ADT/STLExtras.h" @@ -94,6 +95,29 @@ void BoolType::print(mlir::AsmPrinter &printer) const {} // StructType Definitions //===----------------------------------------------------------------------===// +/// Return the largest member in the type. +/// +/// Recurses into union members, never returning a union as the largest member. +Type StructType::getLargestMember(const ::mlir::DataLayout &dataLayout) const { + unsigned largestTySize = 0; + mlir::Type largestTy = nullptr; + + for (auto type : getMembers()) { + // Found a nested union: fetch its largest member. + if (auto structMember = type.dyn_cast()) + if (structMember.isUnion()) + type = structMember.getLargestMember(dataLayout); + + // Update largest member. + if (dataLayout.getTypeSize(type) > largestTySize) { + largestTy = type; + largestTySize = dataLayout.getTypeSize(largestTy); + } + } + + return largestTy; +} + Type StructType::parse(mlir::AsmParser &parser) { const auto loc = parser.getCurrentLocation(); llvm::SmallVector members; @@ -288,6 +312,18 @@ void StructType::computeSizeAndAlignment( unsigned numElements = getNumElements(); auto members = getMembers(); + // For unions, the size and alignment are those of the largest element. 
+ if (isUnion()) { + for (auto ty : members) { + if (dataLayout.getTypeSize(ty) > structSize) { + size = dataLayout.getTypeSize(ty); + align = structAlignment.value(); + padded = false; + } + } + return; + } + // Loop over each of the elements, placing them in memory. for (unsigned i = 0, e = numElements; i != e; ++i) { auto ty = members[i]; diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 64641e7cf0e1..9d903ed2c892 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -54,9 +54,11 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/Support/Casting.h" @@ -1706,12 +1708,29 @@ class CIRGetMemberOpLowering matchAndRewrite(mlir::cir::GetMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { auto llResTy = getTypeConverter()->convertType(op.getType()); - // Since the base address is a pointer to structs, the first offset is - // always zero. The second offset tell us which member it will access. - llvm::SmallVector offset{0, op.getIndex()}; - rewriter.replaceOpWithNewOp(op, llResTy, - adaptor.getAddr(), offset); - return mlir::success(); + const auto structTy = + op.getAddrTy().getPointee().cast(); + assert(structTy && "expected struct type"); + + switch (structTy.getKind()) { + case mlir::cir::StructType::Struct: { + // Since the base address is a pointer to an aggregate, the first offset + // is always zero. The second offset tells us which member it will access. 
+ llvm::SmallVector offset{0, op.getIndex()}; + rewriter.replaceOpWithNewOp(op, llResTy, + adaptor.getAddr(), offset); + return mlir::success(); + } + case mlir::cir::StructType::Union: + // Union members share the same address, so we just need a bitcast to + // conform to type-checking. + rewriter.replaceOpWithNewOp(op, llResTy, + adaptor.getAddr()); + return mlir::success(); + default: + return op.emitError() + << "struct kind '" << structTy.getKind() << "' is NYI"; + } } }; @@ -1772,7 +1791,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, } namespace { -mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) { +mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx, + mlir::DataLayout &dataLayout) { mlir::LLVMTypeConverter converter(ctx); converter.addConversion([&](mlir::cir::PointerType type) -> mlir::Type { auto ty = converter.convertType(type.getPointee()); @@ -1802,9 +1822,24 @@ mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) { return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg); }); converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type { + // FIXME(cir): create separate union, struct, and class types. + // Convert struct members. llvm::SmallVector llvmMembers; - for (auto ty : type.getMembers()) - llvmMembers.push_back(converter.convertType(ty)); + switch (type.getKind()) { + case mlir::cir::StructType::Class: + // TODO(cir): This should be properly validated. + case mlir::cir::StructType::Struct: + for (auto ty : type.getMembers()) + llvmMembers.push_back(converter.convertType(ty)); + break; + // Unions are lowered as only the largest member. + case mlir::cir::StructType::Union: { + auto largestMember = type.getLargestMember(dataLayout); + if (largestMember) + llvmMembers.push_back(converter.convertType(largestMember)); + break; + } + } // Struct has a name: lower as an identified struct. 
mlir::LLVM::LLVMStructType llvmStruct; @@ -1831,8 +1866,8 @@ mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) { void ConvertCIRToLLVMPass::runOnOperation() { auto module = getOperation(); - - auto converter = prepareTypeConverter(&getContext()); + mlir::DataLayout dataLayout(module); + auto converter = prepareTypeConverter(&getContext(), dataLayout); mlir::RewritePatternSet patterns(&getContext()); diff --git a/clang/test/CIR/Lowering/unions.cir b/clang/test/CIR/Lowering/unions.cir new file mode 100644 index 000000000000..3e303c377f7e --- /dev/null +++ b/clang/test/CIR/Lowering/unions.cir @@ -0,0 +1,42 @@ +// RUN: cir-opt %s -cir-to-llvm -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s16i = !cir.int +!s32i = !cir.int +#true = #cir.bool : !cir.bool +!ty_22U122 = !cir.struct +!ty_22U222 = !cir.struct +!ty_22U322 = !cir.struct +module { + // Should lower union to struct with only the largest member. + cir.global external @u1 = #cir.zero : !ty_22U122 + // CHECK: llvm.mlir.global external @u1() {addr_space = 0 : i32} : !llvm.struct<"union.U1", (i32)> + + // Should recursively find the largest member if there are nested unions. + cir.global external @u2 = #cir.zero : !ty_22U222 + cir.global external @u3 = #cir.zero : !ty_22U322 + // CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)> + // CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)> + + // CHECK: llvm.func @test + cir.func @test(%arg0: !cir.ptr) { + + // Should store directly to the union's base address. + %5 = cir.const(#true) : !cir.bool + %6 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr -> !cir.ptr + cir.store %5, %6 : !cir.bool, cir.ptr + // CHECK: %[[#VAL:]] = llvm.mlir.constant(1 : i8) : i8 + // The bitcast is just to bypass the type checker. It will be replaced by an opaque pointer. 
+ // CHECK: %[[#ADDR:]] = llvm.bitcast %{{.+}} : !llvm.ptr> to !llvm.ptr + // CHECK: llvm.store %[[#VAL]], %[[#ADDR]] : !llvm.ptr + + // Should load directly from the union's base address. + %7 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr -> !cir.ptr + %8 = cir.load %7 : cir.ptr , !cir.bool + // The bitcast is just to bypass the type checker. It will be replaced by an opaque pointer. + // CHECK: %[[#BASE:]] = llvm.bitcast %{{.+}} : !llvm.ptr> to !llvm.ptr + // CHECK: %{{.+}} = llvm.load %[[#BASE]] : !llvm.ptr + + cir.return + } +}