Skip to content

Commit

Permalink
[CIR] Refactor AST record interface and make it more type safe.
Browse files Browse the repository at this point in the history
  • Loading branch information
xlauko committed Sep 18, 2023
1 parent 562e716 commit 6ede214
Show file tree
Hide file tree
Showing 21 changed files with 56 additions and 142 deletions.
15 changes: 0 additions & 15 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -444,24 +444,9 @@ def ASTTypeDeclAttr: ASTDecl<"TypeDecl", "type.decl",
def ASTTagDeclAttr : ASTDecl<"TagDecl", "tag.decl",
[ASTTagDeclInterface]>;

def ASTEnumDeclAttr : ASTDecl<"EnumDecl", "enum.decl",
[ASTEnumDeclInterface]>;

def ASTRecordDeclAttr : ASTDecl<"RecordDecl", "record.decl",
[ASTRecordDeclInterface]>;

def ASTCXXRecordDeclAttr : ASTDecl<"CXXRecordDecl", "cxxrecord.decl",
[ASTCXXRecordDeclInterface]>;

def ASTClassTemplateSpecializationDeclAttr :
ASTDecl<"ClassTemplateSpecializationDecl", "class.template.spec.decl",
[ASTClassTemplateSpecializationDeclInterface]>;

def ASTClassTemplatePartialSpecializationDeclAttr :
ASTDecl<"ClassTemplatePartialSpecializationDecl",
"class.template.partial.spec.decl",
[ASTClassTemplatePartialSpecializationDeclInterface]>;

//===----------------------------------------------------------------------===//
// ExtraFuncAttr
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct",
"bool":$body,
"bool":$packed,
"mlir::cir::StructType::RecordKind":$kind,
OptionalParameter<"Attribute">:$ast
OptionalParameter<"std::optional<ASTRecordDeclInterface>">:$ast
);

let hasCustomAssemblyFormat = 1;
Expand Down Expand Up @@ -164,7 +164,7 @@ def CIR_StructType : CIR_Type<"Struct", "struct",

let extraClassDefinition = [{
void $cppClass::dropAst() {
getImpl()->ast = Attribute();
getImpl()->ast = std::nullopt;
}
}];
}
Expand Down
3 changes: 0 additions & 3 deletions clang/include/clang/CIR/Interfaces/ASTAttrInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
namespace mlir {
namespace cir {

mlir::Attribute makeAstDeclAttr(const clang::Decl *decl,
mlir::MLIRContext *ctx);

mlir::Attribute makeFuncDeclAttr(const clang::Decl *decl,
mlir::MLIRContext *ctx);

Expand Down
40 changes: 13 additions & 27 deletions clang/include/clang/CIR/Interfaces/ASTAttrInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ let cppNamespace = "::mlir::cir" in {
let methods = [
InterfaceMethod<"", "bool", "hasOwnerAttr", (ins), [{}],
/*defaultImplementation=*/ [{
return $_attr.getAstDecl()->template hasAttr< clang::OwnerAttr >();
return $_attr.getAstDecl()->template hasAttr<clang::OwnerAttr>();
}]
>,
InterfaceMethod<"", "bool", "hasPointerAttr", (ins), [{}],
/*defaultImplementation=*/ [{
return $_attr.getAstDecl()->template hasAttr< clang::PointerAttr >();
return $_attr.getAstDecl()->template hasAttr<clang::PointerAttr>();
}]
>,
InterfaceMethod<"", "bool", "hasInitPriorityAttr", (ins), [{}],
/*defaultImplementation=*/ [{
return $_attr.getAstDecl()->template hasAttr< clang::InitPriorityAttr >();
return $_attr.getAstDecl()->template hasAttr<clang::InitPriorityAttr>();
}]
>
];
Expand Down Expand Up @@ -93,14 +93,14 @@ let cppNamespace = "::mlir::cir" in {
let methods = [
InterfaceMethod<"", "bool", "isCopyAssignmentOperator", (ins), [{}],
/*defaultImplementation=*/ [{
if (auto decl = dyn_cast< clang::CXXMethodDecl >($_attr.getAstDecl()))
if (auto decl = dyn_cast<clang::CXXMethodDecl>($_attr.getAstDecl()))
return decl->isCopyAssignmentOperator();
return false;
}]
>,
InterfaceMethod<"", "bool", "isMoveAssignmentOperator", (ins), [{}],
/*defaultImplementation=*/ [{
if (auto decl = dyn_cast< clang::CXXMethodDecl >($_attr.getAstDecl()))
if (auto decl = dyn_cast<clang::CXXMethodDecl>($_attr.getAstDecl()))
return decl->isMoveAssignmentOperator();
return false;
}]
Expand Down Expand Up @@ -149,29 +149,20 @@ let cppNamespace = "::mlir::cir" in {
];
}

def ASTEnumDeclInterface : AttrInterface<"ASTEnumDeclInterface",
[ASTTagDeclInterface]>;

def ASTRecordDeclInterface : AttrInterface<"ASTRecordDeclInterface",
[ASTTagDeclInterface]>;

def ASTCXXRecordDeclInterface : AttrInterface<"ASTCXXRecordDeclInterface",
[ASTRecordDeclInterface]> {
[ASTTagDeclInterface]> {
let methods = [
InterfaceMethod<"", "bool", "isLambda", (ins), [{}],
/*defaultImplementation=*/ [{
return $_attr.getAstDecl()->isLambda();
if (auto ast = clang::dyn_cast<clang::CXXRecordDecl>($_attr.getAstDecl()))
return ast->isLambda();
return false;
}]
>
];
}

def ASTClassTemplateSpecializationDeclInterface :
AttrInterface<"ASTClassTemplateSpecializationDeclInterface",
[ASTCXXRecordDeclInterface]> {
let methods = [
InterfaceMethod<"", "bool", "hasPromiseType", (ins), [{}],
>,
InterfaceMethod<"", "bool", "hasPromiseType", (ins), [{}],
/*defaultImplementation=*/ [{
if (!clang::isa<clang::ClassTemplateSpecializationDecl>($_attr.getAstDecl()))
return false;
for (const auto *sub : $_attr.getAstDecl()->decls()) {
if (auto subRec = clang::dyn_cast<clang::CXXRecordDecl>(sub)) {
if (subRec->getDeclName().isIdentifier() &&
Expand All @@ -186,11 +177,6 @@ let cppNamespace = "::mlir::cir" in {
];
}

def ASTClassTemplatePartialSpecializationDeclInterface :
AttrInterface<"ASTClassTemplatePartialSpecializationDeclInterface",
[ASTClassTemplateSpecializationDeclInterface]>;


def AnyASTFunctionDeclAttr : Attr<
CPred<"::mlir::isa<::mlir::cir::ASTFunctionDeclInterface>($_self)">,
"AST Function attribute"> {
Expand Down
23 changes: 8 additions & 15 deletions clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
structTy = getType<mlir::cir::StructType>(
members, mlir::StringAttr::get(getContext()),
/*body=*/true, packed, mlir::cir::StructType::Struct,
/*ast=*/mlir::Attribute());
/*ast=*/std::nullopt);

// Return zero or anonymous constant struct.
if (isZero)
Expand Down Expand Up @@ -413,23 +413,16 @@ class CIRGenBuilderTy : public mlir::OpBuilder {
/// Get a CIR named struct type.
mlir::cir::StructType getStructTy(llvm::ArrayRef<mlir::Type> members,
llvm::StringRef name, bool body,
bool packed, mlir::Attribute ast) {
bool packed, const clang::RecordDecl *ast) {
const auto nameAttr = getStringAttr(name);
std::optional<mlir::cir::ASTRecordDeclAttr> astAttr = std::nullopt;
auto kind = mlir::cir::StructType::RecordKind::Struct;
if (ast)
if (auto tagDecl = mlir::dyn_cast<mlir::cir::ASTTagDeclInterface>(ast))
kind = getRecordKind(tagDecl.getTagKind());
if (ast) {
astAttr = getAttr<mlir::cir::ASTRecordDeclAttr>(ast);
kind = getRecordKind(ast->getTagKind());
}
return mlir::cir::StructType::get(getContext(), members, nameAttr, body,
packed, kind, ast);
}

mlir::cir::StructType getStructTy(llvm::ArrayRef<mlir::Type> members,
llvm::StringRef name, bool body,
bool packed, const clang::RecordDecl *ast) {
mlir::Attribute astAttr;
if (ast)
astAttr = mlir::cir::makeAstDeclAttr(ast, getContext());
return getStructTy(members, name, body, packed, astAttr);
packed, kind, astAttr);
}

//
Expand Down
4 changes: 1 addition & 3 deletions clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,7 @@ mlir::Type CIRGenTypes::convertRecordDeclType(const clang::RecordDecl *RD) {
// Handle forward decl / incomplete types.
if (!entry) {
auto name = getRecordTypeName(RD, "");
entry =
Builder.getStructTy({}, name, /*body=*/false, /*packed=*/false,
mlir::cir::makeAstDeclAttr(RD, &getMLIRContext()));
entry = Builder.getStructTy({}, name, /*body=*/false, /*packed=*/false, RD);
recordDeclTypes[key] = entry;
}

Expand Down
9 changes: 3 additions & 6 deletions clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,10 +605,8 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *D,
builder.astRecordLayout.getSize()) {
CIRRecordLowering baseBuilder(*this, D, /*Packed=*/builder.isPacked);
auto baseIdentifier = getRecordTypeName(D, ".base");
*BaseTy =
Builder.getStructTy(baseBuilder.fieldTypes, baseIdentifier,
/*body=*/true, /*packed=*/false,
mlir::cir::makeAstDeclAttr(D, &getMLIRContext()));
*BaseTy = Builder.getStructTy(baseBuilder.fieldTypes, baseIdentifier,
/*body=*/true, /*packed=*/false, D);
// TODO(cir): add something like addRecordTypeName

// BaseTy and Ty must agree on their packedness for getCIRFieldNo to work
Expand All @@ -622,8 +620,7 @@ CIRGenTypes::computeRecordLayout(const RecordDecl *D,
// signifies that the type is no longer opaque and record layout is complete,
// but we may need to recursively layout D while laying D out as a base type.
*Ty = Builder.getStructTy(builder.fieldTypes, getRecordTypeName(D, ""),
/*body=*/true, /*packed=*/false,
mlir::cir::makeAstDeclAttr(D, &getMLIRContext()));
/*body=*/true, /*packed=*/false, D);

auto RL = std::make_unique<CIRGenRecordLayout>(
Ty ? *Ty : mlir::cir::StructType{},
Expand Down
32 changes: 0 additions & 32 deletions clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,38 +49,6 @@ using namespace mlir::cir;
namespace mlir {
namespace cir {

mlir::Attribute makeAstDeclAttr(const clang::Decl *decl,
mlir::MLIRContext *ctx) {
if (auto ast = clang::dyn_cast<clang::CXXConstructorDecl>(decl))
return ASTCXXConstructorDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::CXXConversionDecl>(decl))
return ASTCXXConversionDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::CXXDestructorDecl>(decl))
return ASTCXXDestructorDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::CXXMethodDecl>(decl))
return ASTCXXMethodDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::FunctionDecl>(decl))
return ASTFunctionDeclAttr::get(ctx, ast);
if (auto ast =
clang::dyn_cast<clang::ClassTemplatePartialSpecializationDecl>(decl))
return ASTClassTemplatePartialSpecializationDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::ClassTemplateSpecializationDecl>(decl))
return ASTClassTemplateSpecializationDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::CXXRecordDecl>(decl))
return ASTCXXRecordDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::RecordDecl>(decl))
return ASTRecordDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::EnumDecl>(decl))
return ASTEnumDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::TagDecl>(decl))
return ASTTagDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::TypeDecl>(decl))
return ASTTypeDeclAttr::get(ctx, ast);
if (auto ast = clang::dyn_cast<clang::VarDecl>(decl))
return ASTVarDeclAttr::get(ctx, ast);
return ASTDeclAttr::get(ctx, decl);
};

mlir::Attribute makeFuncDeclAttr(const clang::Decl *decl,
mlir::MLIRContext *ctx) {
return llvm::TypeSwitch<const clang::Decl *, mlir::Attribute>(decl)
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Type StructType::parse(mlir::AsmParser &parser) {
return {};

return StructType::get(parser.getContext(), members, id, body, packed, kind,
mlir::Attribute());
std::nullopt);
}

void StructType::print(mlir::AsmPrinter &printer) const {
Expand Down Expand Up @@ -187,9 +187,9 @@ void StructType::print(mlir::AsmPrinter &printer) const {
printer << "}";
}

if (getAst()) {
if (getAst().has_value()) {
printer << " ";
printer.printAttribute(getAst());
printer.printAttribute(getAst().value());
}

printer << '>';
Expand Down
18 changes: 4 additions & 14 deletions clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,12 +891,7 @@ void LifetimeCheckPass::checkIf(IfOp ifOp) {
template <class T> bool isStructAndHasAttr(mlir::Type ty) {
if (!ty.isa<mlir::cir::StructType>())
return false;
auto sTy = ty.cast<mlir::cir::StructType>();
auto recordDecl = sTy.getAst();
if (auto interface = dyn_cast<ASTDeclInterface>(recordDecl))
if (hasAttr<T>(interface))
return true;
return false;
return hasAttr<T>(*mlir::cast<mlir::cir::StructType>(ty).getAst());
}

static bool isOwnerType(mlir::Type ty) {
Expand Down Expand Up @@ -1762,8 +1757,7 @@ bool LifetimeCheckPass::isLambdaType(mlir::Type ty) {
auto taskTy = ty.dyn_cast<mlir::cir::StructType>();
if (!taskTy)
return false;
if (auto recordDecl = dyn_cast<ASTCXXRecordDeclInterface>(taskTy.getAst()))
if (recordDecl.isLambda())
if (taskTy.getAst()->isLambda())
IsLambdaTyCache[ty] = true;

return IsLambdaTyCache[ty];
Expand All @@ -1778,12 +1772,8 @@ bool LifetimeCheckPass::isTaskType(mlir::Value taskVal) {
auto taskTy = taskVal.getType().dyn_cast<mlir::cir::StructType>();
if (!taskTy)
return false;
auto recordDecl = taskTy.getAst();
auto spec = dyn_cast<ASTClassTemplateSpecializationDeclInterface>(recordDecl);
if (!spec)
return false;
return spec.hasPromiseType();
} ();
return taskTy.getAst()->hasPromiseType();
}();

IsTaskTyCache[ty] = result;
return result;
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/bitfields.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ void m() {
__long l;
}

// CHECK: !ty_22anon22 = !cir.struct<struct "anon" {!u32i} #cir.cxxrecord.decl.ast>
// CHECK: !ty_22anon22 = !cir.struct<struct "anon" {!u32i} #cir.record.decl.ast>
// CHECK: !ty_22__long22 = !cir.struct<struct "__long" {!ty_22anon22, !u32i, !cir.ptr<!u32i>}>
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/dtors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class B : public A
};

// Class A
// CHECK: ![[ClassA:ty_.*]] = !cir.struct<class "A" {!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>} #cir.cxxrecord.decl.ast>
// CHECK: ![[ClassA:ty_.*]] = !cir.struct<class "A" {!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>} #cir.record.decl.ast>

// Class B
// CHECK: ![[ClassB:ty_.*]] = !cir.struct<class "B" {![[ClassA]]}>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void yoyo(incomplete *i) {}
// CHECK-DAG: !ty_22Bar22 = !cir.struct<struct "Bar" {!s32i, !s8i}>

// CHECK-DAG: !ty_22Foo22 = !cir.struct<struct "Foo" {!s32i, !s8i, !ty_22Bar22}>
// CHECK-DAG: !ty_22Mandalore22 = !cir.struct<struct "Mandalore" {!u32i, !cir.ptr<!void>, !s32i} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22Mandalore22 = !cir.struct<struct "Mandalore" {!u32i, !cir.ptr<!void>, !s32i} #cir.record.decl.ast>
// CHECK-DAG: !ty_22Adv22 = !cir.struct<class "Adv" {!ty_22Mandalore22}>
// CHECK-DAG: !ty_22Entry22 = !cir.struct<struct "Entry" {!cir.ptr<!cir.func<!u32i (!s32i, !cir.ptr<!s8i>, !cir.ptr<!void>)>>}>

Expand Down
12 changes: 6 additions & 6 deletions clang/test/CIR/CodeGen/union.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ typedef union { yolo y; struct { int lifecnt; }; } yolm;
typedef union { yolo y; struct { int *lifecnt; int genpad; }; } yolm2;
typedef union { yolo y; struct { bool life; int genpad; }; } yolm3;

// CHECK-DAG: !ty_22U23A3ADummy22 = !cir.struct<struct "U2::Dummy" {!s16i, f32} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22anon221 = !cir.struct<struct "anon" {!cir.bool, !s32i} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22yolo22 = !cir.struct<struct "yolo" {!s32i} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22anon222 = !cir.struct<struct "anon" {!cir.ptr<!s32i>, !s32i} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22U23A3ADummy22 = !cir.struct<struct "U2::Dummy" {!s16i, f32} #cir.record.decl.ast>
// CHECK-DAG: !ty_22anon221 = !cir.struct<struct "anon" {!cir.bool, !s32i} #cir.record.decl.ast>
// CHECK-DAG: !ty_22yolo22 = !cir.struct<struct "yolo" {!s32i} #cir.record.decl.ast>
// CHECK-DAG: !ty_22anon222 = !cir.struct<struct "anon" {!cir.ptr<!s32i>, !s32i} #cir.record.decl.ast>

// CHECK-DAG: !ty_22yolm22 = !cir.struct<union "yolm" {!ty_22yolo22, !ty_22anon22}>
// CHECK-DAG: !ty_22yolm322 = !cir.struct<union "yolm3" {!ty_22yolo22, !ty_22anon221}>
Expand All @@ -33,14 +33,14 @@ union U2 {
float f;
} s;
} u2;
// CHECK-DAG: !cir.struct<union "U2" {!cir.bool, !ty_22U23A3ADummy22} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !cir.struct<union "U2" {!cir.bool, !ty_22U23A3ADummy22} #cir.record.decl.ast>

// Should genereate unions without padding.
union U3 {
short b;
U u;
} u3;
// CHECK-DAG: !ty_22U322 = !cir.struct<union "U3" {!s16i, !ty_22U22} #cir.cxxrecord.decl.ast>
// CHECK-DAG: !ty_22U322 = !cir.struct<union "U3" {!s16i, !ty_22U22} #cir.record.decl.ast>

void m() {
yolm q;
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/vtable-rtti.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class B : public A
// CHECK: ![[VTableTypeA:ty_.*]] = !cir.struct<struct "" {!cir.array<!cir.ptr<!u8i> x 5>}>

// Class A
// CHECK: ![[ClassA:ty_.*]] = !cir.struct<class "A" {!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>} #cir.cxxrecord.decl.ast>
// CHECK: ![[ClassA:ty_.*]] = !cir.struct<class "A" {!cir.ptr<!cir.ptr<!cir.func<!u32i ()>>>} #cir.record.decl.ast>

// Class B
// CHECK: ![[ClassB:ty_.*]] = !cir.struct<class "B" {![[ClassA]]}>
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/IR/global.cir
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
!s8i = !cir.int<s, 8>
!s32i = !cir.int<s, 32>
!s64i = !cir.int<s, 64>
!ty_22Init22 = !cir.struct<class "Init" {!s8i} #cir.cxxrecord.decl.ast>
!ty_22Init22 = !cir.struct<class "Init" {!s8i} #cir.record.decl.ast>
module {
cir.global external @a = #cir.int<3> : !s32i
cir.global external @rgb = #cir.const_array<[#cir.int<0> : !s8i, #cir.int<-23> : !s8i, #cir.int<33> : !s8i] : !cir.array<!s8i x 3>>
Expand Down
Loading

0 comments on commit 6ede214

Please sign in to comment.