Skip to content

Commit 832d40f

Browse files
committed
[CIR] Refactor AST record interface and make it more type safe.
1 parent e6d0f0e commit 832d40f

File tree

22 files changed

+58
-144
lines changed

22 files changed

+58
-144
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrs.td

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -443,24 +443,9 @@ def ASTTypeDeclAttr: ASTDecl<"TypeDecl", "type.decl",
443443
def ASTTagDeclAttr : ASTDecl<"TagDecl", "tag.decl",
444444
[ASTTagDeclInterface]>;
445445

446-
def ASTEnumDeclAttr : ASTDecl<"EnumDecl", "enum.decl",
447-
[ASTEnumDeclInterface]>;
448-
449446
def ASTRecordDeclAttr : ASTDecl<"RecordDecl", "record.decl",
450447
[ASTRecordDeclInterface]>;
451448

452-
def ASTCXXRecordDeclAttr : ASTDecl<"CXXRecordDecl", "cxxrecord.decl",
453-
[ASTCXXRecordDeclInterface]>;
454-
455-
def ASTClassTemplateSpecializationDeclAttr :
456-
ASTDecl<"ClassTemplateSpecializationDecl", "class.template.spec.decl",
457-
[ASTClassTemplateSpecializationDeclInterface]>;
458-
459-
def ASTClassTemplatePartialSpecializationDeclAttr :
460-
ASTDecl<"ClassTemplatePartialSpecializationDecl",
461-
"class.template.partial.spec.decl",
462-
[ASTClassTemplatePartialSpecializationDeclInterface]>;
463-
464449
//===----------------------------------------------------------------------===//
465450
// ExtraFuncAttr
466451
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
112112
"bool":$body,
113113
"bool":$packed,
114114
"mlir::cir::StructType::RecordKind":$kind,
115-
OptionalParameter<"Attribute">:$ast
115+
OptionalParameter<"std::optional<ASTRecordDeclInterface>">:$ast
116116
);
117117

118118
let hasCustomAssemblyFormat = 1;
@@ -164,7 +164,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
164164

165165
let extraClassDefinition = [{
166166
void $cppClass::dropAst() {
167-
getImpl()->ast = Attribute();
167+
getImpl()->ast = std::nullopt;
168168
}
169169
}];
170170
}

clang/include/clang/CIR/Interfaces/ASTAttrInterfaces.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
namespace mlir {
1919
namespace cir {
2020

21-
mlir::Attribute makeAstDeclAttr(const clang::Decl *decl,
22-
mlir::MLIRContext *ctx);
23-
2421
mlir::Attribute makeFuncDeclAttr(const clang::Decl *decl,
2522
mlir::MLIRContext *ctx);
2623

clang/include/clang/CIR/Interfaces/ASTAttrInterfaces.td

Lines changed: 13 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,17 @@ let cppNamespace = "::mlir::cir" in {
1616
let methods = [
1717
InterfaceMethod<"", "bool", "hasOwnerAttr", (ins), [{}],
1818
/*defaultImplementation=*/ [{
19-
return $_attr.getAstDecl()->template hasAttr< clang::OwnerAttr >();
19+
return $_attr.getAstDecl()->template hasAttr<clang::OwnerAttr>();
2020
}]
2121
>,
2222
InterfaceMethod<"", "bool", "hasPointerAttr", (ins), [{}],
2323
/*defaultImplementation=*/ [{
24-
return $_attr.getAstDecl()->template hasAttr< clang::PointerAttr >();
24+
return $_attr.getAstDecl()->template hasAttr<clang::PointerAttr>();
2525
}]
2626
>,
2727
InterfaceMethod<"", "bool", "hasInitPriorityAttr", (ins), [{}],
2828
/*defaultImplementation=*/ [{
29-
return $_attr.getAstDecl()->template hasAttr< clang::InitPriorityAttr >();
29+
return $_attr.getAstDecl()->template hasAttr<clang::InitPriorityAttr>();
3030
}]
3131
>
3232
];
@@ -93,14 +93,14 @@ let cppNamespace = "::mlir::cir" in {
9393
let methods = [
9494
InterfaceMethod<"", "bool", "isCopyAssignmentOperator", (ins), [{}],
9595
/*defaultImplementation=*/ [{
96-
if (auto decl = dyn_cast< clang::CXXMethodDecl >($_attr.getAstDecl()))
96+
if (auto decl = dyn_cast<clang::CXXMethodDecl>($_attr.getAstDecl()))
9797
return decl->isCopyAssignmentOperator();
9898
return false;
9999
}]
100100
>,
101101
InterfaceMethod<"", "bool", "isMoveAssignmentOperator", (ins), [{}],
102102
/*defaultImplementation=*/ [{
103-
if (auto decl = dyn_cast< clang::CXXMethodDecl >($_attr.getAstDecl()))
103+
if (auto decl = dyn_cast<clang::CXXMethodDecl>($_attr.getAstDecl()))
104104
return decl->isMoveAssignmentOperator();
105105
return false;
106106
}]
@@ -149,29 +149,20 @@ let cppNamespace = "::mlir::cir" in {
149149
];
150150
}
151151

152-
def ASTEnumDeclInterface : AttrInterface<"ASTEnumDeclInterface",
153-
[ASTTagDeclInterface]>;
154-
155152
def ASTRecordDeclInterface : AttrInterface<"ASTRecordDeclInterface",
156-
[ASTTagDeclInterface]>;
157-
158-
def ASTCXXRecordDeclInterface : AttrInterface<"ASTCXXRecordDeclInterface",
159-
[ASTRecordDeclInterface]> {
153+
[ASTTagDeclInterface]> {
160154
let methods = [
161155
InterfaceMethod<"", "bool", "isLambda", (ins), [{}],
162156
/*defaultImplementation=*/ [{
163-
return $_attr.getAstDecl()->isLambda();
157+
if (auto ast = clang::dyn_cast<clang::CXXRecordDecl>($_attr.getAstDecl()))
158+
return ast->isLambda();
159+
return false;
164160
}]
165-
>
166-
];
167-
}
168-
169-
def ASTClassTemplateSpecializationDeclInterface :
170-
AttrInterface<"ASTClassTemplateSpecializationDeclInterface",
171-
[ASTCXXRecordDeclInterface]> {
172-
let methods = [
173-
InterfaceMethod<"", "bool", "hasPromiseType", (ins), [{}],
161+
>,
162+
InterfaceMethod<"", "bool", "hasPromiseType", (ins), [{}],
174163
/*defaultImplementation=*/ [{
164+
if (!clang::isa<clang::ClassTemplateSpecializationDecl>($_attr.getAstDecl()))
165+
return false;
175166
for (const auto *sub : $_attr.getAstDecl()->decls()) {
176167
if (auto subRec = clang::dyn_cast<clang::CXXRecordDecl>(sub)) {
177168
if (subRec->getDeclName().isIdentifier() &&
@@ -186,11 +177,6 @@ let cppNamespace = "::mlir::cir" in {
186177
];
187178
}
188179

189-
def ASTClassTemplatePartialSpecializationDeclInterface :
190-
AttrInterface<"ASTClassTemplatePartialSpecializationDeclInterface",
191-
[ASTClassTemplateSpecializationDeclInterface]>;
192-
193-
194180
def AnyASTFunctionDeclAttr : Attr<
195181
CPred<"::mlir::isa<::mlir::cir::ASTFunctionDeclInterface>($_self)">,
196182
"AST Function attribute"> {

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
171171
structTy = getType<mlir::cir::StructType>(
172172
members, mlir::StringAttr::get(getContext()),
173173
/*body=*/true, packed, mlir::cir::StructType::Struct,
174-
/*ast=*/mlir::Attribute());
174+
/*ast=*/std::nullopt);
175175

176176
// Return zero or anonymous constant struct.
177177
if (isZero)
@@ -408,23 +408,16 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
408408
/// Get a CIR named struct type.
409409
mlir::cir::StructType getStructTy(llvm::ArrayRef<mlir::Type> members,
410410
llvm::StringRef name, bool body,
411-
bool packed, mlir::Attribute ast) {
411+
bool packed, const clang::RecordDecl *ast) {
412412
const auto nameAttr = getStringAttr(name);
413+
std::optional<mlir::cir::ASTRecordDeclAttr> astAttr = std::nullopt;
413414
auto kind = mlir::cir::StructType::RecordKind::Struct;
414-
if (ast)
415-
if (auto tagDecl = mlir::dyn_cast<mlir::cir::ASTTagDeclInterface>(ast))
416-
kind = getRecordKind(tagDecl.getTagKind());
415+
if (ast) {
416+
astAttr = getAttr<mlir::cir::ASTRecordDeclAttr>(ast);
417+
kind = getRecordKind(ast->getTagKind());
418+
}
417419
return mlir::cir::StructType::get(getContext(), members, nameAttr, body,
418-
packed, kind, ast);
419-
}
420-
421-
mlir::cir::StructType getStructTy(llvm::ArrayRef<mlir::Type> members,
422-
llvm::StringRef name, bool body,
423-
bool packed, const clang::RecordDecl *ast) {
424-
mlir::Attribute astAttr;
425-
if (ast)
426-
astAttr = mlir::cir::makeAstDeclAttr(ast, getContext());
427-
return getStructTy(members, name, body, packed, astAttr);
420+
packed, kind, astAttr);
428421
}
429422

430423
//

clang/lib/CIR/CodeGen/CIRGenTypes.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,9 +167,7 @@ mlir::Type CIRGenTypes::convertRecordDeclType(const clang::RecordDecl *RD) {
167167
// Handle forward decl / incomplete types.
168168
if (!entry) {
169169
auto name = getRecordTypeName(RD, "");
170-
entry =
171-
Builder.getStructTy({}, name, /*body=*/false, /*packed=*/false,
172-
mlir::cir::makeAstDeclAttr(RD, &getMLIRContext()));
170+
entry = Builder.getStructTy({}, name, /*body=*/false, /*packed=*/false, RD);
173171
recordDeclTypes[key] = entry;
174172
}
175173

clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -584,10 +584,8 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *D,
584584
builder.astRecordLayout.getSize()) {
585585
CIRRecordLowering baseBuilder(*this, D, /*Packed=*/builder.isPacked);
586586
auto baseIdentifier = getRecordTypeName(D, ".base");
587-
*BaseTy =
588-
Builder.getStructTy(baseBuilder.fieldTypes, baseIdentifier,
589-
/*body=*/true, /*packed=*/false,
590-
mlir::cir::makeAstDeclAttr(D, &getMLIRContext()));
587+
*BaseTy = Builder.getStructTy(baseBuilder.fieldTypes, baseIdentifier,
588+
/*body=*/true, /*packed=*/false, D);
591589
// TODO(cir): add something like addRecordTypeName
592590

593591
// BaseTy and Ty must agree on their packedness for getCIRFieldNo to work
@@ -601,8 +599,7 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *D,
601599
// signifies that the type is no longer opaque and record layout is complete,
602600
// but we may need to recursively layout D while laying D out as a base type.
603601
*Ty = Builder.getStructTy(builder.fieldTypes, getRecordTypeName(D, ""),
604-
/*body=*/true, /*packed=*/false,
605-
mlir::cir::makeAstDeclAttr(D, &getMLIRContext()));
602+
/*body=*/true, /*packed=*/false, D);
606603

607604
auto RL = std::make_unique<CIRGenRecordLayout>(
608605
Ty ? *Ty : mlir::cir::StructType{},

clang/lib/CIR/Dialect/IR/CIRAttrs.cpp

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,38 +49,6 @@ using namespace mlir::cir;
4949
namespace mlir {
5050
namespace cir {
5151

52-
mlir::Attribute makeAstDeclAttr(const clang::Decl *decl,
53-
mlir::MLIRContext *ctx) {
54-
if (auto ast = clang::dyn_cast<clang::CXXConstructorDecl>(decl))
55-
return ASTCXXConstructorDeclAttr::get(ctx, ast);
56-
if (auto ast = clang::dyn_cast<clang::CXXConversionDecl>(decl))
57-
return ASTCXXConversionDeclAttr::get(ctx, ast);
58-
if (auto ast = clang::dyn_cast<clang::CXXDestructorDecl>(decl))
59-
return ASTCXXDestructorDeclAttr::get(ctx, ast);
60-
if (auto ast = clang::dyn_cast<clang::CXXMethodDecl>(decl))
61-
return ASTCXXMethodDeclAttr::get(ctx, ast);
62-
if (auto ast = clang::dyn_cast<clang::FunctionDecl>(decl))
63-
return ASTFunctionDeclAttr::get(ctx, ast);
64-
if (auto ast =
65-
clang::dyn_cast<clang::ClassTemplatePartialSpecializationDecl>(decl))
66-
return ASTClassTemplatePartialSpecializationDeclAttr::get(ctx, ast);
67-
if (auto ast = clang::dyn_cast<clang::ClassTemplateSpecializationDecl>(decl))
68-
return ASTClassTemplateSpecializationDeclAttr::get(ctx, ast);
69-
if (auto ast = clang::dyn_cast<clang::CXXRecordDecl>(decl))
70-
return ASTCXXRecordDeclAttr::get(ctx, ast);
71-
if (auto ast = clang::dyn_cast<clang::RecordDecl>(decl))
72-
return ASTRecordDeclAttr::get(ctx, ast);
73-
if (auto ast = clang::dyn_cast<clang::EnumDecl>(decl))
74-
return ASTEnumDeclAttr::get(ctx, ast);
75-
if (auto ast = clang::dyn_cast<clang::TagDecl>(decl))
76-
return ASTTagDeclAttr::get(ctx, ast);
77-
if (auto ast = clang::dyn_cast<clang::TypeDecl>(decl))
78-
return ASTTypeDeclAttr::get(ctx, ast);
79-
if (auto ast = clang::dyn_cast<clang::VarDecl>(decl))
80-
return ASTVarDeclAttr::get(ctx, ast);
81-
return ASTDeclAttr::get(ctx, decl);
82-
};
83-
8452
mlir::Attribute makeFuncDeclAttr(const clang::Decl *decl,
8553
mlir::MLIRContext *ctx) {
8654
return llvm::TypeSwitch<const clang::Decl *, mlir::Attribute>(decl)

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ Type StructType::parse(mlir::AsmParser &parser) {
156156
return {};
157157

158158
return StructType::get(parser.getContext(), members, id, body, packed, kind,
159-
mlir::Attribute());
159+
std::nullopt);
160160
}
161161

162162
void StructType::print(mlir::AsmPrinter &printer) const {
@@ -187,9 +187,9 @@ void StructType::print(mlir::AsmPrinter &printer) const {
187187
printer << "}";
188188
}
189189

190-
if (getAst()) {
190+
if (getAst().has_value()) {
191191
printer << " ";
192-
printer.printAttribute(getAst());
192+
printer.printAttribute(getAst().value());
193193
}
194194

195195
printer << '>';

clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -890,12 +890,7 @@ void LifetimeCheckPass::checkIf(IfOp ifOp) {
890890
template <class T> bool isStructAndHasAttr(mlir::Type ty) {
891891
if (!ty.isa<mlir::cir::StructType>())
892892
return false;
893-
auto sTy = ty.cast<mlir::cir::StructType>();
894-
auto recordDecl = sTy.getAst();
895-
if (auto interface = dyn_cast<ASTDeclInterface>(recordDecl))
896-
if (hasAttr<T>(interface))
897-
return true;
898-
return false;
893+
return hasAttr<T>(*mlir::cast<mlir::cir::StructType>(ty).getAst());
899894
}
900895

901896
static bool isOwnerType(mlir::Type ty) {
@@ -1761,8 +1756,7 @@ bool LifetimeCheckPass::isLambdaType(mlir::Type ty) {
17611756
auto taskTy = ty.dyn_cast<mlir::cir::StructType>();
17621757
if (!taskTy)
17631758
return false;
1764-
if (auto recordDecl = dyn_cast<ASTCXXRecordDeclInterface>(taskTy.getAst()))
1765-
if (recordDecl.isLambda())
1759+
if (taskTy.getAst()->isLambda())
17661760
IsLambdaTyCache[ty] = true;
17671761

17681762
return IsLambdaTyCache[ty];
@@ -1777,11 +1771,7 @@ bool LifetimeCheckPass::isTaskType(mlir::Value taskVal) {
17771771
auto taskTy = taskVal.getType().dyn_cast<mlir::cir::StructType>();
17781772
if (!taskTy)
17791773
return false;
1780-
auto recordDecl = taskTy.getAst();
1781-
auto spec = dyn_cast<ASTClassTemplateSpecializationDeclInterface>(recordDecl);
1782-
if (!spec)
1783-
return false;
1784-
return spec.hasPromiseType();
1774+
return taskTy.getAst()->hasPromiseType();
17851775
} ();
17861776

17871777
IsTaskTyCache[ty] = result;

0 commit comments

Comments
 (0)