Skip to content

Commit

Permalink
[CIR] Add FuncAttrs to cir.calls (llvm#637)
Browse files Browse the repository at this point in the history
Some function attributes are also callsite attributes, for instance,
nothrow. This means they are going to show up in both. We don't support
that just yet, hence the PR.

CIR has an attribute `ExtraFuncAttr` that we current use as part of
`FuncOp`, see CIROps.td. This attribute also needs to be added to
`CallOp` and `TryCalOp`.

Right now, In `CIRGenCall.cpp: AddAttributesFromFunctionProtoType` fills
in `FuncAttrs`, but doesn't use it for anything. We should use the
`FuncAttrs` result to populate constructing a `ExtraFuncAttr` and add it
to the aforementioned call operations.
  • Loading branch information
roro47 authored and lanza committed Oct 12, 2024
1 parent 9cd8efb commit 3ec2afd
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 60 deletions.
51 changes: 51 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,57 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::ValueRange value = {}) {
return create<mlir::cir::YieldOp>(loc, value);
}

mlir::cir::CallOp
createCallOp(mlir::Location loc,
mlir::SymbolRefAttr callee = mlir::SymbolRefAttr(),
mlir::Type returnType = mlir::cir::VoidType(),
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {

mlir::cir::CallOp callOp =
create<mlir::cir::CallOp>(loc, callee, returnType, operands);

if (extraFnAttr) {
callOp->setAttr("extra_attrs", extraFnAttr);
} else {
mlir::NamedAttrList empty;
callOp->setAttr("extra_attrs",
mlir::cir::ExtraFuncAttributesAttr::get(
getContext(), empty.getDictionary(getContext())));
}
return callOp;
}

mlir::cir::CallOp
createCallOp(mlir::Location loc, mlir::cir::FuncOp callee,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
callee.getFunctionType().getReturnType(), operands,
extraFnAttr);
}

mlir::cir::CallOp
createIndirectCallOp(mlir::Location loc, mlir::Value ind_target,
mlir::cir::FuncType fn_type,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {

llvm::SmallVector<mlir::Value, 4> resOperands({ind_target});
resOperands.append(operands.begin(), operands.end());

return createCallOp(loc, mlir::SymbolRefAttr(), fn_type.getReturnType(),
resOperands, extraFnAttr);
}

mlir::cir::CallOp
createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createCallOp(loc, callee, mlir::cir::VoidType(), operands,
extraFnAttr);
}
};

} // namespace cir
Expand Down
25 changes: 9 additions & 16 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2787,6 +2787,7 @@ class CIR_CallOp<string mnemonic, list<Trait> extra_traits = []> :
dag commonArgs = (ins
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$arg_ops,
ExtraFuncAttr:$extra_attrs,
OptionalAttr<ASTCallExprInterface>:$ast
);
}
Expand Down Expand Up @@ -2822,12 +2823,16 @@ def CallOp : CIR_CallOp<"call"> {
let arguments = commonArgs;
let results = (outs Optional<CIR_AnyType>:$result);

let skipDefaultBuilders = 1;

let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
if (!callee.getFunctionType().isVoid())
$_state.addTypes(callee.getFunctionType().getReturnType());
if (callee)
$_state.addAttribute("callee", callee);
if (resType && !resType.isa<VoidType>())
$_state.addTypes(resType);
}]>,
OpBuilder<(ins "Value":$ind_target,
"FuncType":$fn_type,
Expand All @@ -2836,18 +2841,6 @@ def CallOp : CIR_CallOp<"call"> {
$_state.addOperands(operands);
if (!fn_type.isVoid())
$_state.addTypes(fn_type.getReturnType());
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
if (resType && !resType.isa<VoidType>())
$_state.addTypes(resType);
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
}]>
];
}
Expand Down
32 changes: 20 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,22 +447,29 @@ static mlir::cir::CIRCallOpInterface
buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
mlir::cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
mlir::cir::FuncOp directFuncOp,
SmallVectorImpl<mlir::Value> &CIRCallArgs, bool InvokeDest) {
SmallVectorImpl<mlir::Value> &CIRCallArgs, bool InvokeDest,
mlir::cir::ExtraFuncAttributesAttr extraFnAttrs) {
auto &builder = CGF.getBuilder();

if (InvokeDest) {
auto addr = CGF.currLexScope->getExceptionInfo().addr;
if (indirectFuncTy)
return builder.create<mlir::cir::TryCallOp>(

mlir::cir::TryCallOp tryCallOp;
if (indirectFuncTy) {
tryCallOp = builder.create<mlir::cir::TryCallOp>(
callLoc, addr, indirectFuncVal, indirectFuncTy, CIRCallArgs);
return builder.create<mlir::cir::TryCallOp>(callLoc, directFuncOp, addr,
CIRCallArgs);
} else {
tryCallOp = builder.create<mlir::cir::TryCallOp>(callLoc, directFuncOp,
addr, CIRCallArgs);
}
tryCallOp->setAttr("extra_attrs", extraFnAttrs);
return tryCallOp;
}

if (indirectFuncTy)
return builder.create<mlir::cir::CallOp>(callLoc, indirectFuncVal,
indirectFuncTy, CIRCallArgs);
return builder.create<mlir::cir::CallOp>(callLoc, directFuncOp, CIRCallArgs);
return builder.createIndirectCallOp(
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs, extraFnAttrs);
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, extraFnAttrs);
}

RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
Expand Down Expand Up @@ -735,9 +742,10 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
indirectFuncVal = CalleePtr->getResult(0);
}

mlir::cir::CIRCallOpInterface callLikeOp =
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
directFuncOp, CIRCallArgs, InvokeDest);
mlir::cir::CIRCallOpInterface callLikeOp = buildCallLikeOp(
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
CIRCallArgs, InvokeDest,
mlir::cir::ExtraFuncAttributesAttr::get(builder.getContext(), Attrs));

if (E)
callLikeOp->setAttr(
Expand Down Expand Up @@ -844,7 +852,7 @@ mlir::Value CIRGenFunction::buildRuntimeCall(mlir::Location loc,
// TODO(cir): set the calling convention to this runtime call.
assert(!MissingFeatures::setCallingConv());

auto call = builder.create<mlir::cir::CallOp>(loc, callee, args);
auto call = builder.createCallOp(loc, callee, args);
assert(call->getNumResults() <= 1 &&
"runtime functions have at most 1 result");

Expand Down
13 changes: 6 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ mlir::cir::CallOp CIRGenFunction::buildCoroIDBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
loc, fnOp,
mlir::ValueRange{builder.getUInt32(NewAlign, loc), nullPtr, nullPtr,
nullPtr});
return builder.createCallOp(loc, fnOp,
mlir::ValueRange{builder.getUInt32(NewAlign, loc),
nullPtr, nullPtr, nullPtr});
}

mlir::cir::CallOp
Expand All @@ -202,7 +201,7 @@ CIRGenFunction::buildCoroAllocBuiltinCall(mlir::Location loc) {
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp, mlir::ValueRange{CurCoro.Data->CoroId.getResult()});
}

Expand All @@ -223,7 +222,7 @@ CIRGenFunction::buildCoroBeginBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp,
mlir::ValueRange{CurCoro.Data->CoroId.getResult(), coroframeAddr});
}
Expand All @@ -244,7 +243,7 @@ mlir::cir::CallOp CIRGenFunction::buildCoroEndBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp, mlir::ValueRange{nullPtr, builder.getBool(false, loc)});
}

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1718,8 +1718,8 @@ void CIRGenModule::ReplaceUsesOfNonProtoTypeWithRealFunction(
builder.setInsertionPoint(noProtoCallOp);

// Patch call type with the real function type.
auto realCallOp = builder.create<mlir::cir::CallOp>(
noProtoCallOp.getLoc(), NewFn, noProtoCallOp.getOperands());
auto realCallOp = builder.createCallOp(noProtoCallOp.getLoc(), NewFn,
noProtoCallOp.getOperands());

// Replace old no proto call with fixed call.
noProtoCallOp.replaceAllUsesWith(realCallOp);
Expand Down
39 changes: 34 additions & 5 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2346,6 +2346,7 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {

static ::mlir::ParseResult parseCallCommon(
::mlir::OpAsmParser &parser, ::mlir::OperationState &result,
llvm::StringRef extraAttrsAttrName,
llvm::function_ref<::mlir::ParseResult(::mlir::OpAsmParser &,
::mlir::OperationState &)>
customOpHandler =
Expand Down Expand Up @@ -2380,6 +2381,23 @@ static ::mlir::ParseResult parseCallCommon(
return ::mlir::failure();
if (parser.parseRParen())
return ::mlir::failure();

auto &builder = parser.getBuilder();
Attribute extraAttrs;
if (::mlir::succeeded(parser.parseOptionalKeyword("extra"))) {
if (parser.parseLParen().failed())
return failure();
if (parser.parseAttribute(extraAttrs).failed())
return failure();
if (parser.parseRParen().failed())
return failure();
} else {
NamedAttrList empty;
extraAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), empty.getDictionary(builder.getContext()));
}
result.addAttribute(extraAttrsAttrName, extraAttrs);

if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (parser.parseColon())
Expand All @@ -2400,6 +2418,7 @@ static ::mlir::ParseResult parseCallCommon(
void printCallCommon(
Operation *op, mlir::Value indirectCallee, mlir::FlatSymbolRefAttr flatSym,
::mlir::OpAsmPrinter &state,
::mlir::cir::ExtraFuncAttributesAttr extraAttrs,
llvm::function_ref<void()> customOpHandler = []() {}) {
state << ' ';

Expand All @@ -2415,13 +2434,20 @@ void printCallCommon(
state << "(";
state << ops;
state << ")";
llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;

llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs;
elidedAttrs.push_back("callee");
elidedAttrs.push_back("ast");
elidedAttrs.push_back("extra_attrs");
state.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
state << ' ' << ":";
state << ' ';
state.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes());
if (!extraAttrs.getElements().empty()) {
state << " extra(";
state.printAttributeWithoutType(extraAttrs);
state << ")";
}
}

LogicalResult
Expand All @@ -2431,12 +2457,14 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);

return parseCallCommon(parser, result, getExtraAttrsAttrName(result.name));
}

void CallOp::print(::mlir::OpAsmPrinter &state) {
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
getExtraAttrs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2493,7 +2521,7 @@ LogicalResult cir::TryCallOp::verify() { return mlir::success(); }
::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(
parser, result,
parser, result, getExtraAttrsAttrName(result.name),
[](::mlir::OpAsmParser &parser,
::mlir::OperationState &result) -> ::mlir::ParseResult {
::mlir::OpAsmParser::UnresolvedOperand exceptionRawOperands[1];
Expand Down Expand Up @@ -2535,7 +2563,8 @@ void TryCallOp::print(::mlir::OpAsmPrinter &state) {
state << getExceptionInfo();
state << ")";
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
getExtraAttrs());
}

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 12 additions & 13 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
}

// Create a variable initialization function.
mlir::OpBuilder builder(&getContext());
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
auto voidTy = ::mlir::cir::VoidType::get(builder.getContext());
auto fnType = mlir::cir::FuncType::get({}, voidTy);
Expand Down Expand Up @@ -264,7 +264,7 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
dtorCall.getArgOperand(0));
args[2] = builder.create<mlir::cir::GetGlobalOp>(
Handle.getLoc(), HandlePtrTy, Handle.getSymName());
builder.create<mlir::cir::CallOp>(dtorCall.getLoc(), fnAtExit, args);
builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
dtorCall->erase();
entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
dtorBlock.begin(),
Expand Down Expand Up @@ -481,7 +481,7 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
fnName += getTransformedFileName(theModule);
}

mlir::OpBuilder builder(&getContext());
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointToEnd(&theModule.getBodyRegion().back());
auto fnType = mlir::cir::FuncType::get(
{}, mlir::cir::VoidType::get(builder.getContext()));
Expand All @@ -490,7 +490,7 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
mlir::cir::GlobalLinkageKind::ExternalLinkage);
builder.setInsertionPointToStart(f.addEntryBlock());
for (auto &f : dynamicInitializers) {
builder.create<mlir::cir::CallOp>(f.getLoc(), f);
builder.createCallOp(f.getLoc(), f);
}

builder.create<ReturnOp>(f.getLoc());
Expand Down Expand Up @@ -597,7 +597,7 @@ void LoweringPreparePass::lowerArrayCtor(ArrayCtor op) {
void LoweringPreparePass::lowerStdFindOp(StdFindOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
auto call = builder.createCallOp(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand(0), op.getOperand(1), op.getOperand(2)});

Expand All @@ -608,9 +608,9 @@ void LoweringPreparePass::lowerStdFindOp(StdFindOp op) {
void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});

op.replaceAllUsesWith(call);
op.erase();
Expand All @@ -619,9 +619,9 @@ void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});

op.replaceAllUsesWith(call);
op.erase();
Expand Down Expand Up @@ -712,8 +712,7 @@ void LoweringPreparePass::runOnMathOp(Operation *op) {
buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

builder.setInsertionPointAfter(op);
auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
op->getOperands());
auto call = builder.createCallOp(op->getLoc(), rtFunc, op->getOperands());

op->replaceAllUsesWith(call);
op->erase();
Expand Down
Loading

0 comments on commit 3ec2afd

Please sign in to comment.