diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index b25c5e07f9f0..fbcc3cadb855 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -261,7 +261,6 @@ struct MissingFeatures { static bool X86TypeClassification() { return false; } static bool ABIClangTypeKind() { return false; } - static bool ABIEnterStructForCoercedAccess() { return false; } static bool ABIFuncPtr() { return false; } static bool ABIInRegAttribute() { return false; } static bool ABINestedRecordLayout() { return false; } diff --git a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp index e579fe4c2f0c..2e262478a733 100644 --- a/clang/lib/CIR/Dialect/IR/CIRTypes.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRTypes.cpp @@ -519,13 +519,12 @@ void StructType::computeSizeAndAlignment( // Found a nested union: recurse into it to fetch its largest member. auto structMember = mlir::dyn_cast(ty); - if (structMember && structMember.isUnion()) { - auto candidate = structMember.getLargestMember(dataLayout); - if (dataLayout.getTypeSize(candidate) > largestMemberSize) { - largestMember = candidate; - largestMemberSize = dataLayout.getTypeSize(largestMember); - } - } else if (dataLayout.getTypeSize(ty) > largestMemberSize) { + if (!largestMember || + dataLayout.getTypeABIAlignment(ty) > + dataLayout.getTypeABIAlignment(largestMember) || + (dataLayout.getTypeABIAlignment(ty) == + dataLayout.getTypeABIAlignment(largestMember) && + dataLayout.getTypeSize(ty) > largestMemberSize)) { largestMember = ty; largestMemberSize = dataLayout.getTypeSize(largestMember); } diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp index cf2fdda5b483..704242a73b8c 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp @@ -53,6 +53,12 @@ mlir::Value createCoercedBitcast(mlir::Value Src, mlir::Type DestTy, CastKind::bitcast, Src); } +// FIXME(cir): Create a custom rewriter class to abstract this away. +mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) { + return LF.getRewriter().create(Src.getLoc(), Ty, CastKind::bitcast, + Src); +} + /// Given a struct pointer that we are accessing some number of bytes out of it, /// try to gep into the struct to get at its inner goodness. Dive as deep as /// possible without entering an element with an in-memory size smaller than @@ -67,6 +73,9 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr, mlir::Type FirstElt = SrcSTy.getMembers()[0]; + if (SrcSTy.isUnion()) + FirstElt = SrcSTy.getLargestMember(CGF.LM.getDataLayout().layout); + // If the first elt is at least as large as what we're looking for, or if the // first element is the same size as the whole struct, we can enter it. The // comparison must be made on the store size and not the alloca size. Using @@ -76,10 +85,26 @@ mlir::Value enterStructPointerForCoercedAccess(mlir::Value SrcPtr, FirstEltSize < CGF.LM.getDataLayout().getTypeStoreSize(SrcSTy)) return SrcPtr; - cir_cconv_assert_or_abort( - !cir::MissingFeatures::ABIEnterStructForCoercedAccess(), "NYI"); - return SrcPtr; // FIXME: This is a temporary workaround for the assertion - // above. + auto &rw = CGF.getRewriter(); + auto *ctxt = rw.getContext(); + auto ptrTy = PointerType::get(ctxt, FirstElt); + if (mlir::isa(SrcPtr.getType())) { + auto addr = SrcPtr; + if (auto load = mlir::dyn_cast(SrcPtr.getDefiningOp())) + addr = load.getAddr(); + cir_cconv_assert(mlir::isa(addr.getType())); + // we can not use getMemberOp here since we need a pointer to the first + // element. And in the case of unions we pick a type of the largest elt, + // that may or may not be the first one. Thus, getMemberOp verification + // may fail. + auto cast = createBitcast(addr, ptrTy, CGF); + SrcPtr = rw.create(SrcPtr.getLoc(), cast); + } + + if (auto sty = mlir::dyn_cast(SrcPtr.getType())) + return enterStructPointerForCoercedAccess(SrcPtr, sty, DstSize, CGF); + + return SrcPtr; } /// Convert a value Val to the specific Ty where both @@ -141,12 +166,6 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ, return val; } -// FIXME(cir): Create a custom rewriter class to abstract this away. -mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) { - return LF.getRewriter().create(Src.getLoc(), Ty, CastKind::bitcast, - Src); -} - AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) { auto &rw = LF.getRewriter(); auto *ctxt = rw.getContext(); @@ -302,7 +321,7 @@ mlir::Value createCoercedValue(mlir::Value Src, mlir::Type Ty, // extension or truncation to the desired type. if ((mlir::isa(Ty) || mlir::isa(Ty)) && (mlir::isa(SrcTy) || mlir::isa(SrcTy))) { - cir_cconv_unreachable("NYI"); + return coerceIntOrPtrToIntOrPtr(Src, Ty, CGF); } // If load is legal, just bitcast the src pointer. diff --git a/clang/test/CIR/CallConvLowering/AArch64/union.c b/clang/test/CIR/CallConvLowering/AArch64/union.c index 4f622f0215c3..ed02e9aded7a 100644 --- a/clang/test/CIR/CallConvLowering/AArch64/union.c +++ b/clang/test/CIR/CallConvLowering/AArch64/union.c @@ -38,4 +38,34 @@ void foo(U u) {} U init() { U u; return u; -} \ No newline at end of file +} + +typedef union { + + struct { + short a; + char b; + char c; + }; + + int x; +} A; + +void passA(A x) {} + +// CIR: cir.func {{.*@callA}}() +// CIR: %[[#V0:]] = cir.alloca !ty_A, !cir.ptr, ["x"] {alignment = 4 : i64} +// CIR: %[[#V1:]] = cir.cast(bitcast, %[[#V0:]] : !cir.ptr), !cir.ptr +// CIR: %[[#V2:]] = cir.load %[[#V1]] : !cir.ptr, !s32i +// CIR: %[[#V3:]] = cir.cast(integral, %[[#V2]] : !s32i), !u64i +// CIR: cir.call @passA(%[[#V3]]) : (!u64i) -> () + +// LLVM: void @callA() +// LLVM: %[[#V0:]] = alloca %union.A, i64 1, align 4 +// LLVM: %[[#V1:]] = load i32, ptr %[[#V0]], align 4 +// LLVM: %[[#V2:]] = sext i32 %[[#V1]] to i64 +// LLVM: call void @passA(i64 %[[#V2]]) +void callA() { + A x; + passA(x); +} diff --git a/clang/test/CIR/Lowering/unions.cir b/clang/test/CIR/Lowering/unions.cir index 0cc9d1d15749..fe56e2af7527 100644 --- a/clang/test/CIR/Lowering/unions.cir +++ b/clang/test/CIR/Lowering/unions.cir @@ -16,7 +16,7 @@ module { cir.global external @u2 = #cir.zero : !ty_U2_ cir.global external @u3 = #cir.zero : !ty_U3_ // CHECK: llvm.mlir.global external @u2() {addr_space = 0 : i32} : !llvm.struct<"union.U2", (f64)> - // CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (i32)> + // CHECK: llvm.mlir.global external @u3() {addr_space = 0 : i32} : !llvm.struct<"union.U3", (struct<"union.U1", (i32)>)> // CHECK: llvm.func @test cir.func @test(%arg0: !cir.ptr) {