Skip to content

Commit

Permalink
[CIR] Pass in ExtraFuncAttrs to CallOp and TryCallOp
Browse files Browse the repository at this point in the history
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
roro47 committed May 29, 2024
1 parent 6fa3821 commit c79abfe
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 5 deletions.
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,9 @@ RValue CIRGenFunction::buildCall(const CIRGenFunctionInfo &CallInfo,
buildCallLikeOp(*this, callLoc, indirectFuncTy, indirectFuncVal,
directFuncOp, CIRCallArgs, InvokeDest);

callLikeOp->setAttr("extra_attrs", mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), Attrs));

if (E)
callLikeOp->setAttr(
"ast", mlir::cir::ASTCallExprAttr::get(builder.getContext(), *E));
Expand Down
47 changes: 42 additions & 5 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2293,6 +2293,7 @@ verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) {

static ::mlir::ParseResult parseCallCommon(
::mlir::OpAsmParser &parser, ::mlir::OperationState &result,
llvm::function_ref<::mlir::StringAttr(::mlir::OperationName)> getExtraAttrsAttrName,
llvm::function_ref<::mlir::ParseResult(::mlir::OpAsmParser &,
::mlir::OperationState &)>
customOpHandler =
Expand Down Expand Up @@ -2327,6 +2328,23 @@ static ::mlir::ParseResult parseCallCommon(
return ::mlir::failure();
if (parser.parseRParen())
return ::mlir::failure();

auto &builder = parser.getBuilder();
Attribute extraAttrs;
if (::mlir::succeeded(parser.parseOptionalKeyword("extra"))) {
if (parser.parseLParen().failed())
return failure();
if (parser.parseAttribute(extraAttrs).failed())
return failure();
if (parser.parseRParen().failed())
return failure();
} else {
NamedAttrList empty;
extraAttrs = mlir::cir::ExtraFuncAttributesAttr::get(
builder.getContext(), empty.getDictionary(builder.getContext()));
}
result.addAttribute(getExtraAttrsAttrName(result.name), extraAttrs);

if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (parser.parseColon())
Expand All @@ -2347,6 +2365,7 @@ static ::mlir::ParseResult parseCallCommon(
void printCallCommon(
Operation *op, mlir::Value indirectCallee, mlir::FlatSymbolRefAttr flatSym,
::mlir::OpAsmPrinter &state,
llvm::function_ref<::mlir::cir::ExtraFuncAttributesAttr()> getExtraAttrs,
llvm::function_ref<void()> customOpHandler = []() {}) {
state << ' ';

Expand All @@ -2362,13 +2381,20 @@ void printCallCommon(
state << "(";
state << ops;
state << ")";
llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;

llvm::SmallVector<::llvm::StringRef, 3> elidedAttrs;
elidedAttrs.push_back("callee");
elidedAttrs.push_back("ast");
elidedAttrs.push_back("extra_attrs");
state.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
state << ' ' << ":";
state << ' ';
state.printFunctionalType(op->getOperands().getTypes(), op->getResultTypes());
if (!getExtraAttrs().getElements().empty()) {
state << " extra(";
state.printAttributeWithoutType(getExtraAttrs());
state << ")";
}
}

LogicalResult
Expand All @@ -2378,12 +2404,18 @@ cir::CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {

::mlir::ParseResult CallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(parser, result);

return parseCallCommon(parser, result, [](::mlir::OperationName name) -> ::mlir::StringAttr {
return getExtraAttrsAttrName(name);
});
}

void CallOp::print(::mlir::OpAsmPrinter &state) {
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
[this]() -> ::mlir::cir::ExtraFuncAttributesAttr {
return getExtraAttrs();
});
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2440,7 +2472,9 @@ LogicalResult cir::TryCallOp::verify() { return mlir::success(); }
::mlir::ParseResult TryCallOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
return parseCallCommon(
parser, result,
parser, result, [](::mlir::OperationName name) -> ::mlir::StringAttr {
return getExtraAttrsAttrName(name);
},
[](::mlir::OpAsmParser &parser,
::mlir::OperationState &result) -> ::mlir::ParseResult {
::mlir::OpAsmParser::UnresolvedOperand exceptionRawOperands[1];
Expand Down Expand Up @@ -2482,7 +2516,10 @@ void TryCallOp::print(::mlir::OpAsmPrinter &state) {
state << getExceptionInfo();
state << ")";
mlir::Value indirectCallee = isIndirect() ? getIndirectCall() : nullptr;
printCallCommon(*this, indirectCallee, getCalleeAttr(), state);
printCallCommon(*this, indirectCallee, getCalleeAttr(), state,
[this]() -> ::mlir::cir::ExtraFuncAttributesAttr {
return getExtraAttrs();
});
}

//===----------------------------------------------------------------------===//
Expand Down
23 changes: 23 additions & 0 deletions clang/test/CIR/CodeGen/call-extra-attrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: %clang_cc1 -O2 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
// RUN: FileCheck --input-file=%t.cir %s -check-prefix=CIR

__attribute__((nothrow))
int s0(int a, int b) {
int x = a + b;
return x;
}

__attribute__((noinline))
int s1(int a, int b) {
return s0(a,b);
}

int s2(int a, int b) {
return s1(a, b);
}

// CIR: #fn_attr = #cir<extra({nothrow = #cir.nothrow})>
// CIR: #fn_attr1 = #cir<extra({inline = #cir.inline<no>, nothrow = #cir.nothrow})>

// CIR: cir.call @_Z2s0ii(%{{.*}}, %{{.*}}) : {{.*}}, {{.*}}) -> {{.*}} extra(#fn_attr)
// CIR: cir.call @_Z2s1ii(%{{.*}}, %{{.*}}) : ({{.*}}, {{.*}}) -> {{.*}} loc(#loc{{.*}})

0 comments on commit c79abfe

Please sign in to comment.