Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR][ABI][AArch64] Support struct passing with coercion through memory #1111

Merged
merged 5 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
// Block handling helpers
// ----------------------
//
OpBuilder::InsertPoint getBestAllocaInsertPoint(mlir::Block *block) {
static OpBuilder::InsertPoint getBestAllocaInsertPoint(mlir::Block *block) {
auto last =
std::find_if(block->rbegin(), block->rend(), [](mlir::Operation &op) {
return mlir::isa<cir::AllocaOp, cir::LabelOp>(&op);
Expand Down
119 changes: 76 additions & 43 deletions clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
#include "clang/CIR/ABIArgInfo.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/IR/CIRTypes.h"
Expand Down Expand Up @@ -140,6 +141,76 @@ static mlir::Value coerceIntOrPtrToIntOrPtr(mlir::Value val, mlir::Type typ,
return val;
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved

AllocaOp createTmpAlloca(LowerFunction &LF, mlir::Location loc, mlir::Type ty) {
auto &rw = LF.getRewriter();
auto *ctxt = rw.getContext();
mlir::PatternRewriter::InsertionGuard guard(rw);

// find function's entry block and use it to find a best place for alloca
auto *blk = rw.getBlock();
auto *op = blk->getParentOp();
FuncOp fun = mlir::dyn_cast<FuncOp>(op);
if (!fun)
fun = op->getParentOfType<FuncOp>();
auto &entry = fun.getBody().front();

auto ip = CIRBaseBuilderTy::getBestAllocaInsertPoint(&entry);
rw.restoreInsertionPoint(ip);

auto align = LF.LM.getDataLayout().getABITypeAlign(ty);
auto alignAttr = rw.getI64IntegerAttr(align.value());
auto ptrTy = PointerType::get(ctxt, ty);
return rw.create<AllocaOp>(loc, ptrTy, ty, "tmp", alignAttr);
}

bool isVoidPtr(mlir::Value v) {
if (auto p = mlir::dyn_cast<PointerType>(v.getType()))
return mlir::isa<VoidType>(p.getPointee());
return false;
}

MemCpyOp createMemCpy(LowerFunction &LF, mlir::Value dst, mlir::Value src,
uint64_t len) {
cir_cconv_assert(mlir::isa<PointerType>(src.getType()));
cir_cconv_assert(mlir::isa<PointerType>(dst.getType()));

auto *ctxt = LF.getRewriter().getContext();
auto &rw = LF.getRewriter();
auto voidPtr = PointerType::get(ctxt, cir::VoidType::get(ctxt));

if (!isVoidPtr(src))
src = createBitcast(src, voidPtr, LF);
if (!isVoidPtr(dst))
dst = createBitcast(dst, voidPtr, LF);

auto i64Ty = IntType::get(ctxt, 64, false);
auto length = rw.create<ConstantOp>(src.getLoc(), IntAttr::get(i64Ty, len));
return rw.create<MemCpyOp>(src.getLoc(), dst, src, length);
}

cir::AllocaOp findAlloca(mlir::Operation *op) {
if (!op)
return {};

if (auto al = mlir::dyn_cast<cir::AllocaOp>(op)) {
return al;
} else if (auto ret = mlir::dyn_cast<cir::ReturnOp>(op)) {
auto vals = ret.getInput();
if (vals.size() == 1)
return findAlloca(vals[0].getDefiningOp());
} else if (auto load = mlir::dyn_cast<cir::LoadOp>(op)) {
return findAlloca(load.getAddr().getDefiningOp());
}

return {};
}

bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
/// Create a store to \param Dst from \param Src where the source and
/// destination may have different types.
///
Expand Down Expand Up @@ -187,16 +258,12 @@ void createCoercedStore(mlir::Value Src, mlir::Value Dst, bool DstIsVolatile,
auto addr = bld.create<CastOp>(Dst.getLoc(), ptrTy, CastKind::bitcast, Dst);
bld.create<StoreOp>(Dst.getLoc(), Src, addr);
} else {
cir_cconv_unreachable("NYI");
auto tmp = createTmpAlloca(CGF, Src.getLoc(), SrcTy);
CGF.getRewriter().create<StoreOp>(Src.getLoc(), Src, tmp);
createMemCpy(CGF, Dst, tmp, DstSize.getFixedValue());
}
}

// FIXME(cir): Create a custom rewriter class to abstract this away.
mlir::Value createBitcast(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
return LF.getRewriter().create<CastOp>(Src.getLoc(), Ty, CastKind::bitcast,
Src);
}

/// Coerces a \param Src value to a value of type \param Ty.
///
/// This safely handles the case when the src type is smaller than the
Expand Down Expand Up @@ -261,23 +328,6 @@ mlir::Value emitAddressAtOffset(LowerFunction &LF, mlir::Value addr,
return addr;
}

cir::AllocaOp findAlloca(mlir::Operation *op) {
if (!op)
return {};

if (auto al = mlir::dyn_cast<cir::AllocaOp>(op)) {
return al;
} else if (auto ret = mlir::dyn_cast<cir::ReturnOp>(op)) {
auto vals = ret.getInput();
if (vals.size() == 1)
return findAlloca(vals[0].getDefiningOp());
} else if (auto load = mlir::dyn_cast<cir::LoadOp>(op)) {
return findAlloca(load.getAddr().getDefiningOp());
}

return {};
}

/// After the calling convention is lowered, an ABI-agnostic type might have to
/// be loaded back to its ABI-aware couterpart so it may be returned. If they
/// differ, we have to do a coerced load. A coerced load, which means to load a
Expand Down Expand Up @@ -329,25 +379,8 @@ mlir::Value castReturnValue(mlir::Value Src, mlir::Type Ty, LowerFunction &LF) {
// Otherwise do coercion through memory.
if (auto addr = findAlloca(Src.getDefiningOp())) {
auto &rewriter = LF.getRewriter();
auto *ctxt = LF.LM.getMLIRContext();
auto ptrTy = PointerType::get(ctxt, Ty);
auto voidPtr = PointerType::get(ctxt, cir::VoidType::get(ctxt));

// insert alloca near the previuos one
auto point = rewriter.saveInsertionPoint();
rewriter.setInsertionPointAfter(addr);
auto align = LF.LM.getDataLayout().getABITypeAlign(Ty);
auto alignAttr = rewriter.getI64IntegerAttr(align.value());
auto tmp =
rewriter.create<AllocaOp>(Src.getLoc(), ptrTy, Ty, "tmp", alignAttr);
rewriter.restoreInsertionPoint(point);

auto srcVoidPtr = createBitcast(addr, voidPtr, LF);
auto dstVoidPtr = createBitcast(tmp, voidPtr, LF);
auto i64Ty = IntType::get(ctxt, 64, false);
auto len = rewriter.create<ConstantOp>(
Src.getLoc(), IntAttr::get(i64Ty, SrcSize.getFixedValue()));
rewriter.create<MemCpyOp>(Src.getLoc(), dstVoidPtr, srcVoidPtr, len);
auto tmp = createTmpAlloca(LF, Src.getLoc(), Ty);
createMemCpy(LF, tmp, addr, SrcSize.getFixedValue());
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
return rewriter.create<LoadOp>(Src.getLoc(), tmp.getResult());
}

Expand Down
16 changes: 16 additions & 0 deletions clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,19 @@ void pass_eq_128(EQ_128 s) {}
// LLVM: store ptr %0, ptr %[[#V1]], align 8
// LLVM: %[[#V2:]] = load ptr, ptr %[[#V1]], align 8
void pass_gt_128(GT_128 s) {}

// CHECK: cir.func @passS(%arg0: !cir.array<!u64i x 2>
// CHECK: %[[#V0:]] = cir.alloca !ty_S, !cir.ptr<!ty_S>, [""] {alignment = 4 : i64}
// CHECK: %[[#V1:]] = cir.alloca !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>, ["tmp"] {alignment = 8 : i64}
// CHECK: cir.store %arg0, %[[#V1]] : !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>
// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>), !cir.ptr<!void>
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S>), !cir.ptr<!void>
// CHECK: %[[#V4:]] = cir.const #cir.int<12> : !u64i
// CHECK: cir.libc.memcpy %[[#V4]] bytes from %[[#V2]] to %[[#V3]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>

// LLVM: void @passS([2 x i64] %[[#ARG:]])
// LLVM: %[[#V1:]] = alloca %struct.S, i64 1, align 4
// LLVM: %[[#V2:]] = alloca [2 x i64], i64 1, align 8
// LLVM: store [2 x i64] %[[#ARG]], ptr %[[#V2]], align 8
// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V1]], ptr %[[#V2]], i64 12, i1 false)
void passS(S s) {}