Skip to content

Commit

Permalink
[CIR][Lowering] Lower unions
Browse files Browse the repository at this point in the history
Converts a union to a struct containing only its largest element.
GetMemberOp for unions is lowered as bitcasts instead of GEPs, since
union members share the same address space.

ghstack-source-id: 744ac312675b8f3225ccc459fcd09474bcfcfe81
Pull Request resolved: #230
  • Loading branch information
sitio-couto authored and lanza committed Apr 29, 2024
1 parent ef2a9fe commit 57e081f
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 22 deletions.
6 changes: 5 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
// for the struct size and alignment.
mutable std::optional<unsigned> size{}, align{};
mutable std::optional<bool> padded{};
mutable mlir::Type largestMember{};
void computeSizeAndAlignment(const ::mlir::DataLayout &dataLayout) const;
public:
void dropAst();
Expand All @@ -141,12 +142,15 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
case RecordKind::Class:
return "class." + name;
case RecordKind::Union:
return "union "+ name;
return "union." + name;
case RecordKind::Struct:
return "struct." + name;
}
}

/// Return the member with the largest bit-length.
mlir::Type getLargestMember(const ::mlir::DataLayout &dataLayout) const;

/// Return whether this is a class declaration.
bool isClass() const { return getKind() == RecordKind::Class; }

Expand Down
34 changes: 33 additions & 1 deletion clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -94,6 +95,15 @@ void BoolType::print(mlir::AsmPrinter &printer) const {}
// StructType Definitions
//===----------------------------------------------------------------------===//

/// Return the largest member of in the type.
///
/// Recurses into union members never returning a union as the largest member.
Type StructType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
if (!largestMember)
computeSizeAndAlignment(dataLayout);
return largestMember;
}

Type StructType::parse(mlir::AsmParser &parser) {
const auto loc = parser.getCurrentLocation();
llvm::SmallVector<mlir::Type> members;
Expand Down Expand Up @@ -278,7 +288,7 @@ void StructType::computeSizeAndAlignment(
const ::mlir::DataLayout &dataLayout) const {
assert(!isOpaque() && "Cannot get layout of opaque structs");
// Do not recompute.
if (size || align || padded)
if (size || align || padded || largestMember)
return;

// This is a similar algorithm to LLVM's StructLayout.
Expand All @@ -287,11 +297,25 @@ void StructType::computeSizeAndAlignment(
[[maybe_unused]] bool isPadded = false;
unsigned numElements = getNumElements();
auto members = getMembers();
unsigned largestMemberSize = 0;

// Loop over each of the elements, placing them in memory.
for (unsigned i = 0, e = numElements; i != e; ++i) {
auto ty = members[i];

// Found a nested union: recurse into it to fetch its largest member.
auto structMember = ty.dyn_cast<StructType>();
if (structMember && structMember.isUnion()) {
auto candidate = structMember.getLargestMember(dataLayout);
if (dataLayout.getTypeSize(candidate) > largestMemberSize) {
largestMember = candidate;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}
} else if (dataLayout.getTypeSize(ty) > largestMemberSize) {
largestMember = ty;
largestMemberSize = dataLayout.getTypeSize(largestMember);
}

// This matches LLVM since it uses the ABI instead of preferred alignment.
const llvm::Align tyAlign =
llvm::Align(getPacked() ? 1 : dataLayout.getTypeABIAlignment(ty));
Expand All @@ -312,6 +336,14 @@ void StructType::computeSizeAndAlignment(
structSize += dataLayout.getTypeSize(ty);
}

// For unions, the size and aligment is that of the largest element.
if (isUnion()) {
size = largestMemberSize;
align = structAlignment.value();
padded = false;
return;
}

// Add padding to the end of the struct so that it could be put in an array
// and all array elements would be aligned correctly.
if (!llvm::isAligned(structAlignment, structSize)) {
Expand Down
74 changes: 54 additions & 20 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1306,8 +1308,8 @@ class CIRGlobalOpLowering
sourceSymbol.getSymName());
llvm::SmallVector<mlir::LLVM::GEPArg> offset{0};
auto gepOp = rewriter.create<mlir::LLVM::GEPOp>(
loc, llvmType, sourceSymbol.getType(),
addressOfOp.getResult(), offset);
loc, llvmType, sourceSymbol.getType(), addressOfOp.getResult(),
offset);
rewriter.create<mlir::LLVM::ReturnOp>(loc, gepOp.getResult());
return mlir::success();
} else if (isa<mlir::cir::ZeroAttr, mlir::cir::NullAttr>(init.value())) {
Expand Down Expand Up @@ -1721,14 +1723,30 @@ class CIRGetMemberOpLowering
matchAndRewrite(mlir::cir::GetMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto llResTy = getTypeConverter()->convertType(op.getType());
// Since the base address is a pointer to structs, the first offset is
// always zero. The second offset tell us which member it will access.
llvm::SmallVector<mlir::LLVM::GEPArg> offset{0, op.getIndex()};
const auto elementTy = getTypeConverter()->convertType(
op.getAddr().getType().getPointee());
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
op, llResTy, elementTy, adaptor.getAddr(), offset);
return mlir::success();
const auto structTy =
op.getAddrTy().getPointee().cast<mlir::cir::StructType>();
assert(structTy && "expected struct type");

switch (structTy.getKind()) {
case mlir::cir::StructType::Struct: {
// Since the base address is a pointer to an aggregate, the first offset
// is always zero. The second offset tell us which member it will access.
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
const auto elementTy = getTypeConverter()->convertType(structTy);
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy,
adaptor.getAddr(), offset);
return mlir::success();
}
case mlir::cir::StructType::Union:
// Union members share the address space, so we just need a bitcast to
// conform to type-checking.
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
adaptor.getAddr());
return mlir::success();
default:
return op.emitError()
<< "struct kind '" << structTy.getKind() << "' is NYI";
}
}
};

Expand Down Expand Up @@ -1789,7 +1807,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
}

namespace {
void prepareTypeConverter(mlir::LLVMTypeConverter &converter) {
void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
mlir::DataLayout &dataLayout) {
converter.addConversion([&](mlir::cir::PointerType type) -> mlir::Type {
return mlir::LLVM::LLVMPointerType::get(&converter.getContext());
});
Expand All @@ -1814,9 +1833,24 @@ void prepareTypeConverter(mlir::LLVMTypeConverter &converter) {
return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
});
converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type {
// FIXME(cir): create separate unions, struct, and classes types.
// Convert struct members.
llvm::SmallVector<mlir::Type> llvmMembers;
for (auto ty : type.getMembers())
llvmMembers.push_back(converter.convertType(ty));
switch (type.getKind()) {
case mlir::cir::StructType::Class:
// TODO(cir): This should be properly validated.
case mlir::cir::StructType::Struct:
for (auto ty : type.getMembers())
llvmMembers.push_back(converter.convertType(ty));
break;
// Unions are lowered as only the largest member.
case mlir::cir::StructType::Union: {
auto largestMember = type.getLargestMember(dataLayout);
if (largestMember)
llvmMembers.push_back(converter.convertType(largestMember));
break;
}
}

// Struct has a name: lower as an identified struct.
mlir::LLVM::LLVMStructType llvmStruct;
Expand Down Expand Up @@ -1847,7 +1881,7 @@ static void buildCtorList(mlir::ModuleOp module) {
assert(attr.isa<mlir::cir::GlobalCtorAttr>() &&
"must be a GlobalCtorAttr");
if (auto ctorAttr = attr.cast<mlir::cir::GlobalCtorAttr>()) {
// default priority is 65536
// default priority is 65536
int priority = 65536;
if (ctorAttr.getPriority())
priority = *ctorAttr.getPriority();
Expand Down Expand Up @@ -1885,15 +1919,15 @@ static void buildCtorList(mlir::ModuleOp module) {
newGlobalOp.getRegion().push_back(new mlir::Block());
builder.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());

mlir::Value result = builder.create<mlir::LLVM::UndefOp>(
loc, CtorStructArrayTy);
mlir::Value result =
builder.create<mlir::LLVM::UndefOp>(loc, CtorStructArrayTy);

for (uint64_t I = 0; I < globalCtors.size(); I++) {
auto fn = globalCtors[I];
mlir::Value structInit =
builder.create<mlir::LLVM::UndefOp>(loc, CtorStructTy);
mlir::Value initPriority =
builder.create<mlir::LLVM::ConstantOp>(loc, CtorStructFields[0], fn.second);
mlir::Value initPriority = builder.create<mlir::LLVM::ConstantOp>(
loc, CtorStructFields[0], fn.second);
mlir::Value initFuncAddr = builder.create<mlir::LLVM::AddressOfOp>(
loc, CtorStructFields[1], fn.first);
mlir::Value initAssociate =
Expand All @@ -1914,9 +1948,9 @@ static void buildCtorList(mlir::ModuleOp module) {

void ConvertCIRToLLVMPass::runOnOperation() {
auto module = getOperation();

mlir::DataLayout dataLayout(module);
mlir::LLVMTypeConverter converter(&getContext());
prepareTypeConverter(converter);
prepareTypeConverter(converter, dataLayout);

mlir::RewritePatternSet patterns(&getContext());

Expand Down
42 changes: 42 additions & 0 deletions clang/test/CIR/Lowering/unions.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

!s16i = !cir.int<s, 16>
!s32i = !cir.int<s, 32>
#true = #cir.bool<true> : !cir.bool
!ty_22U122 = !cir.struct<union "U1" {!cir.bool, !s16i, !s32i} #cir.recdecl.ast>
!ty_22U222 = !cir.struct<union "U2" {f64, !ty_22U122} #cir.recdecl.ast>
!ty_22U322 = !cir.struct<union "U3" {!s16i, !ty_22U122} #cir.recdecl.ast>
module {
// Should lower union to struct with only the largest member.
cir.global external @u1 = #cir.zero : !ty_22U122
// CHECK: llvm.mlir.global external @u1() {addr_space = 0 : i32} : !llvm.struct<"union.U1", (i32)>

// Should recursively find the largest member if there are nested unions.
cir.global external @u2 = #cir.zero : !ty_22U222
cir.global external @u3 = #cir.zero : !ty_22U322
// CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)>

// CHECK: llvm.func @test
cir.func @test(%arg0: !cir.ptr<!ty_22U122>) {

// Should store directly to the union's base address.
%5 = cir.const(#true) : !cir.bool
%6 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr<!ty_22U122> -> !cir.ptr<!cir.bool>
cir.store %5, %6 : !cir.bool, cir.ptr <!cir.bool>
// CHECK: %[[#VAL:]] = llvm.mlir.constant(1 : i8) : i8
// The bitcast it just to bypass the type checker. It will be replaced by an opaque pointer.
// CHECK: %[[#ADDR:]] = llvm.bitcast %{{.+}} : !llvm.ptr to !llvm.ptr
// CHECK: llvm.store %[[#VAL]], %[[#ADDR]] : i8, !llvm.ptr

// Should load direclty from the union's base address.
%7 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr<!ty_22U122> -> !cir.ptr<!cir.bool>
%8 = cir.load %7 : cir.ptr <!cir.bool>, !cir.bool
// The bitcast it just to bypass the type checker. It will be replaced by an opaque pointer.
// CHECK: %[[#BASE:]] = llvm.bitcast %{{.+}} : !llvm.ptr to !llvm.ptr
// CHECK: %{{.+}} = llvm.load %[[#BASE]] : !llvm.ptr

cir.return
}
}

0 comments on commit 57e081f

Please sign in to comment.