Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Add FuncAttrs to cir.calls #637

Merged
merged 2 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,57 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
mlir::ValueRange value = {}) {
return create<mlir::cir::YieldOp>(loc, value);
}

mlir::cir::CallOp
createCallOp(mlir::Location loc,
mlir::SymbolRefAttr callee = mlir::SymbolRefAttr(),
mlir::Type returnType = mlir::cir::VoidType(),
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {

mlir::cir::CallOp callOp =
create<mlir::cir::CallOp>(loc, callee, returnType, operands);

if (extraFnAttr) {
callOp->setAttr("extra_attrs", extraFnAttr);
} else {
mlir::NamedAttrList empty;
callOp->setAttr("extra_attrs",
mlir::cir::ExtraFuncAttributesAttr::get(
getContext(), empty.getDictionary(getContext())));
}
return callOp;
}

mlir::cir::CallOp
createCallOp(mlir::Location loc, mlir::cir::FuncOp callee,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createCallOp(loc, mlir::SymbolRefAttr::get(callee),
callee.getFunctionType().getReturnType(), operands,
extraFnAttr);
}

mlir::cir::CallOp
createIndirectCallOp(mlir::Location loc, mlir::Value ind_target,
mlir::cir::FuncType fn_type,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {

llvm::SmallVector<mlir::Value, 4> resOperands({ind_target});
resOperands.append(operands.begin(), operands.end());

return createCallOp(loc, mlir::SymbolRefAttr(), fn_type.getReturnType(),
resOperands, extraFnAttr);
}

mlir::cir::CallOp
createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee,
mlir::ValueRange operands = mlir::ValueRange(),
mlir::cir::ExtraFuncAttributesAttr extraFnAttr = {}) {
return createCallOp(loc, callee, mlir::cir::VoidType(), operands,
extraFnAttr);
}
};

} // namespace cir
Expand Down
25 changes: 9 additions & 16 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2785,6 +2785,7 @@ class CIR_CallOp<string mnemonic, list<Trait> extra_traits = []> :
dag commonArgs = (ins
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<CIR_AnyType>:$arg_ops,
ExtraFuncAttr:$extra_attrs,
OptionalAttr<ASTCallExprInterface>:$ast
);
}
Expand Down Expand Up @@ -2820,12 +2821,16 @@ def CallOp : CIR_CallOp<"call"> {
let arguments = commonArgs;
let results = (outs Optional<CIR_AnyType>:$result);

let skipDefaultBuilders = 1;

let builders = [
OpBuilder<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{
bcardosolopes marked this conversation as resolved.
Show resolved Hide resolved
OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", SymbolRefAttr::get(callee));
if (!callee.getFunctionType().isVoid())
$_state.addTypes(callee.getFunctionType().getReturnType());
if (callee)
$_state.addAttribute("callee", callee);
if (resType && !resType.isa<VoidType>())
$_state.addTypes(resType);
}]>,
OpBuilder<(ins "Value":$ind_target,
"FuncType":$fn_type,
Expand All @@ -2834,18 +2839,6 @@ def CallOp : CIR_CallOp<"call"> {
$_state.addOperands(operands);
if (!fn_type.isVoid())
$_state.addTypes(fn_type.getReturnType());
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee, "mlir::Type":$resType,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
if (resType && !resType.isa<VoidType>())
$_state.addTypes(resType);
}]>,
OpBuilder<(ins "SymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$operands), [{
$_state.addOperands(operands);
$_state.addAttribute("callee", callee);
}]>
];
}
Expand Down
32 changes: 20 additions & 12 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,22 +447,29 @@ static mlir::cir::CIRCallOpInterface
buildCallLikeOp(CIRGenFunction &CGF, mlir::Location callLoc,
mlir::cir::FuncType indirectFuncTy, mlir::Value indirectFuncVal,
mlir::cir::FuncOp directFuncOp,
SmallVectorImpl<mlir::Value> &CIRCallArgs, bool InvokeDest) {
SmallVectorImpl<mlir::Value> &CIRCallArgs, bool InvokeDest,
mlir::cir::ExtraFuncAttributesAttr extraFnAttrs) {
auto &builder = CGF.getBuilder();

if (InvokeDest) {
auto addr = CGF.currLexScope->getExceptionInfo().addr;
if (indirectFuncTy)
return builder.create<mlir::cir::TryCallOp>(

mlir::cir::TryCallOp tryCallOp;
if (indirectFuncTy) {
tryCallOp = builder.create<mlir::cir::TryCallOp>(
callLoc, addr, indirectFuncVal, indirectFuncTy, CIRCallArgs);
return builder.create<mlir::cir::TryCallOp>(callLoc, directFuncOp, addr,
CIRCallArgs);
} else {
tryCallOp = builder.create<mlir::cir::TryCallOp>(callLoc, directFuncOp,
addr, CIRCallArgs);
}
tryCallOp->setAttr("extra_attrs", extraFnAttrs);
return tryCallOp;
}

if (indirectFuncTy)
return builder.create<mlir::cir::CallOp>(callLoc, indirectFuncVal,
indirectFuncTy, CIRCallArgs);
return builder.create<mlir::cir::CallOp>(callLoc, directFuncOp, CIRCallArgs);
return builder.createIndirectCallOp(
callLoc, indirectFuncVal, indirectFuncTy, CIRCallArgs, extraFnAttrs);
return builder.createCallOp(callLoc, directFuncOp, CIRCallArgs, extraFnAttrs);
}

RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
Expand Down Expand Up @@ -735,9 +742,10 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
indirectFuncVal = CalleePtr->getResult(0);
}

mlir::cir::CIRCallOpInterface callLikeOp =
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
directFuncOp, CIRCallArgs, InvokeDest);
mlir::cir::CIRCallOpInterface callLikeOp = buildCallLikeOp(
*this, callLoc, indirectFuncTy, indirectFuncVal, directFuncOp,
CIRCallArgs, InvokeDest,
mlir::cir::ExtraFuncAttributesAttr::get(builder.getContext(), Attrs));

if (E)
callLikeOp->setAttr(
Expand Down Expand Up @@ -844,7 +852,7 @@ mlir::Value CIRGenFunction::buildRuntimeCall(mlir::Location loc,
// TODO(cir): set the calling convention to this runtime call.
assert(!MissingFeatures::setCallingConv());

auto call = builder.create<mlir::cir::CallOp>(loc, callee, args);
auto call = builder.createCallOp(loc, callee, args);
assert(call->getNumResults() <= 1 &&
"runtime functions have at most 1 result");

Expand Down
13 changes: 6 additions & 7 deletions clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,9 @@ mlir::cir::CallOp CIRGenFunction::buildCoroIDBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
loc, fnOp,
mlir::ValueRange{builder.getUInt32(NewAlign, loc), nullPtr, nullPtr,
nullPtr});
return builder.createCallOp(loc, fnOp,
mlir::ValueRange{builder.getUInt32(NewAlign, loc),
nullPtr, nullPtr, nullPtr});
}

mlir::cir::CallOp
Expand All @@ -202,7 +201,7 @@ CIRGenFunction::buildCoroAllocBuiltinCall(mlir::Location loc) {
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp, mlir::ValueRange{CurCoro.Data->CoroId.getResult()});
}

Expand All @@ -223,7 +222,7 @@ CIRGenFunction::buildCoroBeginBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp,
mlir::ValueRange{CurCoro.Data->CoroId.getResult(), coroframeAddr});
}
Expand All @@ -244,7 +243,7 @@ mlir::cir::CallOp CIRGenFunction::buildCoroEndBuiltinCall(mlir::Location loc,
} else
fnOp = cast<mlir::cir::FuncOp>(builtin);

return builder.create<mlir::cir::CallOp>(
return builder.createCallOp(
loc, fnOp, mlir::ValueRange{nullPtr, builder.getBool(false, loc)});
}

Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1717,8 +1717,8 @@ void CIRGenModule::ReplaceUsesOfNonProtoTypeWithRealFunction(
builder.setInsertionPoint(noProtoCallOp);

// Patch call type with the real function type.
auto realCallOp = builder.create<mlir::cir::CallOp>(
noProtoCallOp.getLoc(), NewFn, noProtoCallOp.getOperands());
auto realCallOp = builder.createCallOp(noProtoCallOp.getLoc(), NewFn,
noProtoCallOp.getOperands());

// Replace old no proto call with fixed call.
noProtoCallOp.replaceAllUsesWith(realCallOp);
Expand Down
39 changes: 34 additions & 5 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2310,6 +2310,7 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {

static ::mlir::ParseResult parseCallCommon(
::mlir::OpAsmParser &parser, ::mlir::OperationState &result,
llvm::StringRef extraAttrsAttrName,
llvm::function_ref<::mlir::ParseResult(::mlir::OpAsmParser &,
::mlir::OperationState &)>
customOpHandler =
Expand Down Expand Up @@ -2344,6 +2345,23 @@ static ::mlir::ParseResult parseCallCommon(
return ::mlir::failure();
if (parser.parseRParen())
return ::mlir::failure();

auto &builder = parser.getBuilder();
Attribute extraAttrs;
if (::mlir::succeeded(parser.parseOptionalKeyword("extra"))) {
if (parser.parseLParen().failed())
return failure();
if (parser.parseAttribute(extraAttrs).failed())
return failure();
if (parser.parseRParen().failed())
return failure();
} else {
NamedAttrList empty;
extraAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), empty.getDictionary(builder.getContext()));
}
result.addAttribute(extraAttrsAttrName, extraAttrs);

if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (parser.parseColon())
Expand All @@ -2364,6 +2382,7 @@ static ::mlir::ParseResult parseCallCommon(
void printCallCommon(
Operation *op, mlir::Value indirectCallee, mlir::FlatSymbolRefAttr flatSym,
::mlir::OpAsmPrinter &state,
::mlir::cir::ExtraFuncAttributesAttr extraAttrs,
llvm::function_ref<void()> customOpHandler = []() {}) {
state << ' ';

Expand All @@ -2379,13 +2398,20 @@ void printCallCommon(
state << "(";
state << ops;
state << ")";
llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;

llvm::SmallVector<::llvm::StringRef, 4> elidedAttrs;
elidedAttrs.push_back("callee");
elidedAttrs.push_back("ast");
elidedAttrs.push_back("extra_attrs");
state.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
state << ' ' << ":";
state << ' ';
state.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes());
if (!extraAttrs.getElements().empty()) {
state << " extra(";
state.printAttributeWithoutType(extraAttrs);
state << ")";
}
}

LogicalResult
Expand All @@ -2395,12 +2421,14 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);

return parseCallCommon(parser, result, getExtraAttrsAttrName(result.name));
}

void CallOp::print(::mlir::OpAsmPrinter &state) {
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
getExtraAttrs());
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2457,7 +2485,7 @@ LogicalResult cir::TryCallOp::verify() { return mlir::success(); }
::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(
parser, result,
parser, result, getExtraAttrsAttrName(result.name),
[](::mlir::OpAsmParser &parser,
::mlir::OperationState &result) -> ::mlir::ParseResult {
::mlir::OpAsmParser::UnresolvedOperand exceptionRawOperands[1];
Expand Down Expand Up @@ -2499,7 +2527,8 @@ void TryCallOp::print(::mlir::OpAsmPrinter &state) {
state << getExceptionInfo();
state << ")";
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
getExtraAttrs());
}

//===----------------------------------------------------------------------===//
Expand Down
25 changes: 12 additions & 13 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
}

// Create a variable initialization function.
mlir::OpBuilder builder(&getContext());
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op);
auto voidTy = ::mlir::cir::VoidType::get(builder.getContext());
auto fnType = mlir::cir::FuncType::get({}, voidTy);
Expand Down Expand Up @@ -264,7 +264,7 @@ FuncOp LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(GlobalOp op) {
dtorCall.getArgOperand(0));
args[2] = builder.create<mlir::cir::GetGlobalOp>(
Handle.getLoc(), HandlePtrTy, Handle.getSymName());
builder.create<mlir::cir::CallOp>(dtorCall.getLoc(), fnAtExit, args);
builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
dtorCall->erase();
entryBB->getOperations().splice(entryBB->end(), dtorBlock.getOperations(),
dtorBlock.begin(),
Expand Down Expand Up @@ -481,7 +481,7 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
fnName += getTransformedFileName(theModule);
}

mlir::OpBuilder builder(&getContext());
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointToEnd(&theModule.getBodyRegion().back());
auto fnType = mlir::cir::FuncType::get(
{}, mlir::cir::VoidType::get(builder.getContext()));
Expand All @@ -490,7 +490,7 @@ void LoweringPreparePass::buildCXXGlobalInitFunc() {
mlir::cir::GlobalLinkageKind::ExternalLinkage);
builder.setInsertionPointToStart(f.addEntryBlock());
for (auto &f : dynamicInitializers) {
builder.create<mlir::cir::CallOp>(f.getLoc(), f);
builder.createCallOp(f.getLoc(), f);
}

builder.create<ReturnOp>(f.getLoc());
Expand Down Expand Up @@ -597,7 +597,7 @@ void LoweringPreparePass::lowerArrayCtor(ArrayCtor op) {
void LoweringPreparePass::lowerStdFindOp(StdFindOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
auto call = builder.createCallOp(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand(0), op.getOperand(1), op.getOperand(2)});

Expand All @@ -608,9 +608,9 @@ void LoweringPreparePass::lowerStdFindOp(StdFindOp op) {
void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});

op.replaceAllUsesWith(call);
op.erase();
Expand All @@ -619,9 +619,9 @@ void LoweringPreparePass::lowerIterBeginOp(IterBeginOp op) {
void LoweringPreparePass::lowerIterEndOp(IterEndOp op) {
CIRBaseBuilderTy builder(getContext());
builder.setInsertionPointAfter(op.getOperation());
auto call = builder.create<mlir::cir::CallOp>(
op.getLoc(), op.getOriginalFnAttr(), op.getResult().getType(),
mlir::ValueRange{op.getOperand()});
auto call = builder.createCallOp(op.getLoc(), op.getOriginalFnAttr(),
op.getResult().getType(),
mlir::ValueRange{op.getOperand()});

op.replaceAllUsesWith(call);
op.erase();
Expand Down Expand Up @@ -712,8 +712,7 @@ void LoweringPreparePass::runOnMathOp(Operation *op) {
buildRuntimeFunction(builder, rtFuncName, op->getLoc(), rtFuncTy);

builder.setInsertionPointAfter(op);
auto call = builder.create<mlir::cir::CallOp>(op->getLoc(), rtFunc,
op->getOperands());
auto call = builder.createCallOp(op->getLoc(), rtFunc, op->getOperands());

op->replaceAllUsesWith(call);
op->erase();
Expand Down
Loading
Loading