Skip to content

[CIR][CodeGen][LowerToLLVM] Set calling convention for call ops #836

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
mlir::cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
mlir::cir::FuncOp directFuncOp,
SmallVectorImpl<mlir::Value> &CIRCallArgs,
mlir::Operation *InvokeDest,
mlir::Operation *InvokeDest, mlir::cir::CallingConv callingConv,
mlir::cir::ExtraFuncAttributesAttr extraFnAttrs) {
auto &builder = CGF.getBuilder();

Expand All @@ -468,6 +468,8 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
}

mlir::cir::CallOp tryCallOp;
// TODO(cir): Set calling convention for `cir.try_call`.
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
if (indirectFuncTy) {
tryCallOp = builder.createIndirectTryCallOp(callLoc, indirectFuncVal,
indirectFuncTy, CIRCallArgs);
Expand All @@ -484,12 +486,15 @@ buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
}

assert(builder.getInsertionBlock() && "expected valid basic block");
if (indirectFuncTy)
if (indirectFuncTy) {
// TODO(cir): Set calling convention for indirect calls.
assert(callingConv == mlir::cir::CallingConv::C && "NYI");
return builder.createIndirectCallOp(
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs,
mlir::cir::CallingConv::C, extraFnAttrs);
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs,
mlir::cir::CallingConv::C, extraFnAttrs);
}
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, callingConv,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested by updated spir-calling-conv.cl.

extraFnAttrs);
}

RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
Expand Down Expand Up @@ -765,9 +770,9 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
auto extraFnAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), Attrs.getDictionary(builder.getContext()));

mlir::cir::CIRCallOpInterface callLikeOp =
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
directFuncOp, CIRCallArgs, InvokeDest, extraFnAttrs);
mlir::cir::CIRCallOpInterface callLikeOp = buildCallLikeOp(
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
CIRCallArgs, InvokeDest, callingConv, extraFnAttrs);

if (E)
callLikeOp->setAttr(
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2704,6 +2704,12 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {
<< op->getOperand(i).getType() << " for operand number " << i;
}

// Calling convention must match.
if (callIf.getCallingConv() != fn.getCallingConv())
return op->emitOpError("calling convention mismatch: expected ")
<< stringifyCallingConv(fn.getCallingConv()) << ", but provided "
<< stringifyCallingConv(callIf.getCallingConv());

// Void function must not return any results.
if (fnType.isVoid() && op->getNumResults() != 0)
return op->emitOpError("callee returns void but call has results");
Expand Down
29 changes: 21 additions & 8 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,18 +875,24 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
mlir::Block *landingPadBlock = nullptr) {
llvm::SmallVector<mlir::Type, 8> llvmResults;
auto cirResults = op->getResultTypes();
auto callIf = cast<mlir::cir::CIRCallOpInterface>(op);

if (converter->convertTypes(cirResults, llvmResults).failed())
return mlir::failure();

auto cconv = convertCallingConv(callIf.getCallingConv());

if (calleeAttr) { // direct call
if (landingPadBlock)
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
if (landingPadBlock) {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
op, llvmResults, calleeAttr, callOperands, continueBlock,
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
else
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(op, llvmResults,
calleeAttr, callOperands);
newOp.setCConv(cconv);
Comment on lines +887 to +890
Copy link
Collaborator Author

@seven-mile seven-mile Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have an idea on how to avoid 4 dups of saving newOp and setting CConv. Different from CIRCallOpInterface, LLVM uses CallOpInterface from MLIR std, which is not aware of calling convention.

If we don't consider hacky op->setAttr('calling_conv', ...), it seems we have to do a dynamic dispatch of at least size 2 (invoke + call) after rewriting. I chose to just leave it clear. It might be more appropriate to refactor this part after we have more similar logic like setCConv.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't worry too much about this. Since we don't need the result for anything else (yet), my suggestion would be:

rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
          op, llvmResults, calleeAttr, callOperands, continueBlock,
          mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{}).setCConv(cconv);

} else {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op, llvmResults, calleeAttr, callOperands);
newOp.setCConv(cconv);
}
} else { // indirect call
assert(op->getOperands().size() &&
"operands list must no be empty for the indirect call");
Expand All @@ -899,14 +905,17 @@ rewriteToCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
if (landingPadBlock) {
auto llvmFnTy =
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp));
rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::InvokeOp>(
op, llvmFnTy, mlir::FlatSymbolRefAttr{}, callOperands, continueBlock,
mlir::ValueRange{}, landingPadBlock, mlir::ValueRange{});
} else
rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
newOp.setCConv(cconv);
} else {
auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
op,
dyn_cast<mlir::LLVM::LLVMFunctionType>(converter->convertType(ftyp)),
callOperands);
newOp.setCConv(cconv);
}
}
return mlir::success();
}
Expand All @@ -932,6 +941,10 @@ class CIRTryCallLowering
mlir::LogicalResult
matchAndRewrite(mlir::cir::TryCallOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
if (op.getCallingConv() != mlir::cir::CallingConv::C) {
return op.emitError(
"non-C calling convention is not implemented for try_call");
}
return rewriteToCallOrInvoke(
op.getOperation(), adaptor.getOperands(), rewriter, getTypeConverter(),
op.getCalleeAttr(), op.getCont(), op.getLandingPad());
Expand Down
4 changes: 4 additions & 0 deletions clang/test/CIR/CodeGen/OpenCL/spir-calling-conv.cl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ kernel void bar(global int *A);
// LLVM-DAG: define{{.*}} spir_kernel void @foo(
kernel void foo(global int *A) {
int id = get_dummy_id(0);
// CIR: %{{[0-9]+}} = cir.call @get_dummy_id(%2) : (!s32i) -> !s32i cc(spir_function)
// LLVM: %{{[a-z0-9_]+}} = call spir_func i32 @get_dummy_id(
A[id] = id;
bar(A);
// CIR: cir.call @bar(%8) : (!cir.ptr<!s32i, addrspace(offload_global)>) -> () cc(spir_kernel)
// LLVM: call spir_kernel void @bar(ptr addrspace(1)
}
15 changes: 15 additions & 0 deletions clang/test/CIR/IR/invalid.cir
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,22 @@ module {
!s32i = !cir.int<s, 32>

module {
cir.func @subroutine() cc(spir_function) {
cir.return
}

cir.func @call_conv_match() {
// expected-error@+1 {{'cir.call' op calling convention mismatch: expected spir_function, but provided spir_kernel}}
cir.call @subroutine(): () -> !cir.void cc(spir_kernel)
cir.return
}
}

// -----

!s32i = !cir.int<s, 32>

module {
cir.func @test_bitcast_addrspace() {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["tmp"] {alignment = 4 : i64}
// expected-error@+1 {{'cir.cast' op result type address space does not match the address space of the operand}}
Expand Down
22 changes: 22 additions & 0 deletions clang/test/CIR/Lowering/call-op-call-conv.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: cir-translate -cir-to-llvmir %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll %s --check-prefix=LLVM

!s32i = !cir.int<s, 32>
!fnptr = !cir.ptr<!cir.func<!s32i (!s32i)>>

module {
cir.func private @my_add(%a: !s32i, %b: !s32i) -> !s32i cc(spir_function)

cir.func @ind(%fnptr: !fnptr, %a : !s32i) {
%1 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_kernel)
// LLVM: %{{[0-9]+}} = call spir_kernel i32 %{{[0-9]+}}(i32 %{{[0-9]+}})

%2 = cir.call %fnptr(%a) : (!fnptr, !s32i) -> !s32i cc(spir_function)
// LLVM: %{{[0-9]+}} = call spir_func i32 %{{[0-9]+}}(i32 %{{[0-9]+}})

%3 = cir.call @my_add(%1, %2) : (!s32i, !s32i) -> !s32i cc(spir_function)
// LLVM: %{{[0-9]+}} = call spir_func i32 @my_add(i32 %{{[0-9]+}}, i32 %{{[0-9]+}})

cir.return
}
}