Skip to content

Commit

Permalink
[CIR][Lowering] Lower unions
Browse files Browse the repository at this point in the history
Converts a union to a struct containing only its largest member.
GetMemberOp on a union is lowered to a bitcast instead of a GEP, since
all union members share the same base address.

ghstack-source-id: 791a944c5df3103c5671ba061fe4c7c3a56317fa
Pull Request resolved: #230
  • Loading branch information
sitio-couto committed Aug 22, 2023
1 parent 1eeb982 commit 57fa182
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 12 deletions.
5 changes: 4 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,15 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
case RecordKind::Class:
return "class." + name;
case RecordKind::Union:
return "union "+ name;
return "union." + name;
case RecordKind::Struct:
return "struct." + name;
}
}

/// Return the member with the largest bit-length.
mlir::Type getLargestMember(const ::mlir::DataLayout &dataLayout) const;

/// Return whether this is a class declaration.
bool isClass() const { return getKind() == RecordKind::Class; }

Expand Down
36 changes: 36 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Support/LogicalResult.h"

#include "llvm/ADT/STLExtras.h"
Expand Down Expand Up @@ -94,6 +95,29 @@ void BoolType::print(mlir::AsmPrinter &printer) const {}
// StructType Definitions
//===----------------------------------------------------------------------===//

/// Return the largest member of this type, using the sizes reported by
/// `dataLayout.getTypeSize`.
///
/// Nested unions are never returned themselves: each one is replaced by its
/// own largest member (found recursively) before sizes are compared. When
/// several members tie for the largest size, the first one wins. Returns a
/// null type if there are no members, or if no member has a non-zero size.
Type StructType::getLargestMember(const ::mlir::DataLayout &dataLayout) const {
  unsigned largestTySize = 0;
  mlir::Type largestTy = nullptr;

  for (auto type : getMembers()) {
    // Found a nested union: fetch its largest member.
    if (auto structMember = type.dyn_cast<StructType>())
      if (structMember.isUnion())
        type = structMember.getLargestMember(dataLayout);

    // A nested union without any sized member yields a null type. It cannot
    // be the largest member and must not be handed to the data layout.
    if (!type)
      continue;

    // Query the size once and reuse it for the comparison and the update.
    const auto typeSize = dataLayout.getTypeSize(type);
    if (typeSize > largestTySize) {
      largestTy = type;
      largestTySize = typeSize;
    }
  }

  return largestTy;
}

Type StructType::parse(mlir::AsmParser &parser) {
const auto loc = parser.getCurrentLocation();
llvm::SmallVector<mlir::Type> members;
Expand Down Expand Up @@ -288,6 +312,18 @@ void StructType::computeSizeAndAlignment(
unsigned numElements = getNumElements();
auto members = getMembers();

// For unions, the size and alignment is that of the largest element.
if (isUnion()) {
for (auto ty : members) {
if (dataLayout.getTypeSize(ty) > structSize) {
size = dataLayout.getTypeSize(ty);
align = structAlignment.value();
padded = false;
}
}
return;
}

// Loop over each of the elements, placing them in memory.
for (unsigned i = 0, e = numElements; i != e; ++i) {
auto ty = members[i];
Expand Down
57 changes: 46 additions & 11 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,11 @@
#include "clang/CIR/Dialect/IR/CIRTypes.h"
#include "clang/CIR/Passes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -1706,12 +1708,29 @@ class CIRGetMemberOpLowering
matchAndRewrite(mlir::cir::GetMemberOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto llResTy = getTypeConverter()->convertType(op.getType());
// Since the base address is a pointer to structs, the first offset is
// always zero. The second offset tell us which member it will access.
llvm::SmallVector<mlir::LLVM::GEPArg> offset{0, op.getIndex()};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy,
adaptor.getAddr(), offset);
return mlir::success();
const auto structTy =
op.getAddrTy().getPointee().cast<mlir::cir::StructType>();
assert(structTy && "expected struct type");

switch (structTy.getKind()) {
case mlir::cir::StructType::Struct: {
// Since the base address is a pointer to an aggregate, the first offset
// is always zero. The second offset tells us which member it will access.
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy,
adaptor.getAddr(), offset);
return mlir::success();
}
case mlir::cir::StructType::Union:
// All union members start at the union's base address, so we just need a
// bitcast to conform to type-checking.
rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
adaptor.getAddr());
return mlir::success();
default:
return op.emitError()
<< "struct kind '" << structTy.getKind() << "' is NYI";
}
}
};

Expand Down Expand Up @@ -1772,7 +1791,8 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
}

namespace {
mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) {
mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx,
mlir::DataLayout &dataLayout) {
mlir::LLVMTypeConverter converter(ctx);
converter.addConversion([&](mlir::cir::PointerType type) -> mlir::Type {
auto ty = converter.convertType(type.getPointee());
Expand Down Expand Up @@ -1802,9 +1822,24 @@ mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) {
return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
});
converter.addConversion([&](mlir::cir::StructType type) -> mlir::Type {
// FIXME(cir): create separate unions, struct, and classes types.
// Convert struct members.
llvm::SmallVector<mlir::Type> llvmMembers;
for (auto ty : type.getMembers())
llvmMembers.push_back(converter.convertType(ty));
switch (type.getKind()) {
case mlir::cir::StructType::Class:
// TODO(cir): This should be properly validated.
case mlir::cir::StructType::Struct:
for (auto ty : type.getMembers())
llvmMembers.push_back(converter.convertType(ty));
break;
// Unions are lowered as only the largest member.
case mlir::cir::StructType::Union: {
auto largestMember = type.getLargestMember(dataLayout);
if (largestMember)
llvmMembers.push_back(converter.convertType(largestMember));
break;
}
}

// Struct has a name: lower as an identified struct.
mlir::LLVM::LLVMStructType llvmStruct;
Expand All @@ -1831,8 +1866,8 @@ mlir::LLVMTypeConverter prepareTypeConverter(mlir::MLIRContext *ctx) {

void ConvertCIRToLLVMPass::runOnOperation() {
auto module = getOperation();

auto converter = prepareTypeConverter(&getContext());
mlir::DataLayout dataLayout(module);
auto converter = prepareTypeConverter(&getContext(), dataLayout);

mlir::RewritePatternSet patterns(&getContext());

Expand Down
42 changes: 42 additions & 0 deletions clang/test/CIR/Lowering/unions.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
// RUN: FileCheck --input-file=%t.mlir %s

!s16i = !cir.int<s, 16>
!s32i = !cir.int<s, 32>
#true = #cir.bool<true> : !cir.bool
!ty_22U122 = !cir.struct<union "U1" {!cir.bool, !s16i, !s32i} #cir.recdecl.ast>
!ty_22U222 = !cir.struct<union "U2" {f64, !ty_22U122} #cir.recdecl.ast>
!ty_22U322 = !cir.struct<union "U3" {!s16i, !ty_22U122} #cir.recdecl.ast>
module {
// Should lower a union to a struct containing only its largest member
// (for U1, the 32-bit integer).
cir.global external @u1 = #cir.zero : !ty_22U122
// CHECK: llvm.mlir.global external @u1() {addr_space = 0 : i32} : !llvm.struct<"union.U1", (i32)>

// Should recursively find the largest member if there are nested unions:
// U2 keeps its own f64, while U3 takes the i32 found inside the nested U1.
cir.global external @u2 = #cir.zero : !ty_22U222
cir.global external @u3 = #cir.zero : !ty_22U322
// CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)>
// CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)>

// CHECK: llvm.func @test
cir.func @test(%arg0: !cir.ptr<!ty_22U122>) {

// Should store directly to the union's base address.
%5 = cir.const(#true) : !cir.bool
%6 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr<!ty_22U122> -> !cir.ptr<!cir.bool>
cir.store %5, %6 : !cir.bool, cir.ptr <!cir.bool>
// CHECK: %[[#VAL:]] = llvm.mlir.constant(1 : i8) : i8
// The bitcast is just to bypass the type checker. It will be replaced by an opaque pointer.
// CHECK: %[[#ADDR:]] = llvm.bitcast %{{.+}} : !llvm.ptr<struct<"union.U1", (i32)>> to !llvm.ptr<i8>
// CHECK: llvm.store %[[#VAL]], %[[#ADDR]] : !llvm.ptr<i8>

// Should load directly from the union's base address.
%7 = cir.get_member %arg0[0] {name = "b"} : !cir.ptr<!ty_22U122> -> !cir.ptr<!cir.bool>
%8 = cir.load %7 : cir.ptr <!cir.bool>, !cir.bool
// The bitcast is just to bypass the type checker. It will be replaced by an opaque pointer.
// CHECK: %[[#BASE:]] = llvm.bitcast %{{.+}} : !llvm.ptr<struct<"union.U1", (i32)>> to !llvm.ptr<i8>
// CHECK: %{{.+}} = llvm.load %[[#BASE]] : !llvm.ptr<i8>

cir.return
}
}

0 comments on commit 57fa182

Please sign in to comment.