Skip to content

Commit e6e9beb

Browse files
authored
[mlir][tools] Introduce tblgen-to-irdl tool (#66865)
RFC: https://discourse.llvm.org/t/rfc-tblgen-to-irdl-tool/73578
1 parent be383de commit e6e9beb

File tree

6 files changed

+260
-0
lines changed

6 files changed

+260
-0
lines changed

mlir/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ set(MLIR_TEST_DEPENDS
108108
mlir-tblgen
109109
mlir-translate
110110
tblgen-lsp-server
111+
tblgen-to-irdl
111112
)
112113

113114
# The native target may not be enabled, in this case we won't
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: tblgen-to-irdl %s -I=%S/../../include --gen-dialect-irdl-defs --dialect=cmath | FileCheck %s
2+
3+
include "mlir/IR/OpBase.td"
4+
include "mlir/IR/AttrTypeBase.td"
5+
6+
// CHECK-LABEL: irdl.dialect @cmath {
7+
def CMath_Dialect : Dialect {
8+
let name = "cmath";
9+
}
10+
11+
class CMath_Type<string name, string typeMnemonic, list<Trait> traits = []>
12+
: TypeDef<CMath_Dialect, name, traits> {
13+
let mnemonic = typeMnemonic;
14+
}
15+
16+
class CMath_Op<string mnemonic, list<Trait> traits = []>
17+
: Op<CMath_Dialect, mnemonic, traits>;
18+
19+
def f32Orf64Type : Or<[CPred<"::llvm::isa<::mlir::F32>">,
20+
CPred<"::llvm::isa<::mlir::F64>">]>;
21+
22+
def CMath_ComplexType : CMath_Type<"ComplexType", "complex"> {
23+
let parameters = (ins f32Orf64Type:$elementType);
24+
}
25+
26+
// CHECK: irdl.operation @identity {
27+
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
28+
// CHECK-NEXT: irdl.operands()
29+
// CHECK-NEXT: irdl.results(%0)
30+
// CHECK-NEXT: }
31+
def CMath_IdentityOp : CMath_Op<"identity"> {
32+
let results = (outs CMath_ComplexType:$out);
33+
}
34+
35+
// CHECK: irdl.operation @mul {
36+
// CHECK-NEXT: %0 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
37+
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
38+
// CHECK-NEXT: %2 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
39+
// CHECK-NEXT: irdl.operands(%0, %1)
40+
// CHECK-NEXT: irdl.results(%2)
41+
// CHECK-NEXT: }
42+
def CMath_MulOp : CMath_Op<"mul"> {
43+
let arguments = (ins CMath_ComplexType:$in1, CMath_ComplexType:$in2);
44+
let results = (outs CMath_ComplexType:$out);
45+
}
46+
47+
// CHECK: irdl.operation @norm {
48+
// CHECK-NEXT: %0 = irdl.c_pred "(true)"
49+
// CHECK-NEXT: %1 = irdl.c_pred "(::llvm::isa<cmath::ComplexTypeType>($_self))"
50+
// CHECK-NEXT: irdl.operands(%0)
51+
// CHECK-NEXT: irdl.results(%1)
52+
// CHECK-NEXT: }
53+
def CMath_NormOp : CMath_Op<"norm"> {
54+
let arguments = (ins AnyType:$in);
55+
let results = (outs CMath_ComplexType:$out);
56+
}

mlir/tools/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_subdirectory(mlir-spirv-cpu-runner)
88
add_subdirectory(mlir-translate)
99
add_subdirectory(mlir-vulkan-runner)
1010
add_subdirectory(tblgen-lsp-server)
11+
add_subdirectory(tblgen-to-irdl)
1112

1213
# mlir-cpu-runner requires ExecutionEngine.
1314
if(MLIR_ENABLE_EXECUTION_ENGINE)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
set(LLVM_LINK_COMPONENTS
2+
TableGen
3+
)
4+
5+
add_tablegen(tblgen-to-irdl MLIR
6+
DESTINATION "${MLIR_TOOLS_INSTALL_DIR}"
7+
EXPORT MLIR
8+
tblgen-to-irdl.cpp
9+
OpDefinitionsGen.cpp
10+
)
11+
12+
target_link_libraries(tblgen-to-irdl
13+
PRIVATE
14+
MLIRIR
15+
MLIRIRDL
16+
MLIRTblgenLib
17+
MLIRSupport
18+
)
19+
20+
mlir_check_all_link_libraries(tblgen-to-irdl)
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// OpDefinitionsGen uses the description of operations to generate IRDL
10+
// definitions for ops.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/IRDL/IR/IRDL.h"
15+
#include "mlir/IR/Attributes.h"
16+
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/BuiltinOps.h"
18+
#include "mlir/IR/Diagnostics.h"
19+
#include "mlir/IR/Dialect.h"
20+
#include "mlir/IR/MLIRContext.h"
21+
#include "mlir/TableGen/AttrOrTypeDef.h"
22+
#include "mlir/TableGen/GenInfo.h"
23+
#include "mlir/TableGen/GenNameParser.h"
24+
#include "mlir/TableGen/Interfaces.h"
25+
#include "mlir/TableGen/Operator.h"
26+
#include "llvm/Support/CommandLine.h"
27+
#include "llvm/Support/InitLLVM.h"
28+
#include "llvm/Support/raw_ostream.h"
29+
#include "llvm/TableGen/Main.h"
30+
#include "llvm/TableGen/Record.h"
31+
#include "llvm/TableGen/TableGenBackend.h"
32+
33+
using namespace llvm;
34+
using namespace mlir;
35+
using tblgen::NamedTypeConstraint;
36+
37+
static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect");
38+
llvm::cl::opt<std::string>
39+
selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"),
40+
llvm::cl::cat(dialectGenCat), llvm::cl::Required);
41+
42+
irdl::CPredOp createConstraint(OpBuilder &builder,
43+
NamedTypeConstraint namedConstraint) {
44+
MLIRContext *ctx = builder.getContext();
45+
// Build the constraint as a string.
46+
std::string constraint =
47+
namedConstraint.constraint.getPredicate().getCondition();
48+
// Build a CPredOp to match the C constraint built.
49+
irdl::CPredOp op = builder.create<irdl::CPredOp>(
50+
UnknownLoc::get(ctx), StringAttr::get(ctx, constraint));
51+
return op;
52+
}
53+
54+
/// Returns the name of the operation without the dialect prefix.
55+
static StringRef getOperatorName(tblgen::Operator &tblgenOp) {
56+
StringRef opName = tblgenOp.getDef().getValueAsString("opName");
57+
return opName;
58+
}
59+
60+
/// Extract an operation to IRDL.
61+
irdl::OperationOp createIRDLOperation(OpBuilder &builder,
62+
tblgen::Operator &tblgenOp) {
63+
MLIRContext *ctx = builder.getContext();
64+
StringRef opName = getOperatorName(tblgenOp);
65+
66+
irdl::OperationOp op = builder.create<irdl::OperationOp>(
67+
UnknownLoc::get(ctx), StringAttr::get(ctx, opName));
68+
69+
// Add the block in the region.
70+
Block &opBlock = op.getBody().emplaceBlock();
71+
OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock);
72+
73+
auto getValues = [&](tblgen::Operator::const_value_range namedCons) {
74+
SmallVector<Value> operands;
75+
SmallVector<irdl::VariadicityAttr> variadicity;
76+
for (const NamedTypeConstraint &namedCons : namedCons) {
77+
auto operand = createConstraint(consBuilder, namedCons);
78+
operands.push_back(operand);
79+
80+
irdl::VariadicityAttr var;
81+
if (namedCons.isOptional())
82+
var = consBuilder.getAttr<irdl::VariadicityAttr>(
83+
irdl::Variadicity::optional);
84+
else if (namedCons.isVariadic())
85+
var = consBuilder.getAttr<irdl::VariadicityAttr>(
86+
irdl::Variadicity::variadic);
87+
else
88+
var = consBuilder.getAttr<irdl::VariadicityAttr>(
89+
irdl::Variadicity::single);
90+
91+
variadicity.push_back(var);
92+
}
93+
return std::make_tuple(operands, variadicity);
94+
};
95+
96+
auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands());
97+
auto [results, resultVariadicity] = getValues(tblgenOp.getResults());
98+
99+
// Create the operands and results operations.
100+
consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands,
101+
operandVariadicity);
102+
consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results,
103+
resultVariadicity);
104+
105+
return op;
106+
}
107+
108+
static irdl::DialectOp createIRDLDialect(OpBuilder &builder) {
109+
MLIRContext *ctx = builder.getContext();
110+
return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx),
111+
StringAttr::get(ctx, selectedDialect));
112+
}
113+
114+
static std::vector<llvm::Record *>
115+
getOpDefinitions(const RecordKeeper &recordKeeper) {
116+
if (!recordKeeper.getClass("Op"))
117+
return {};
118+
return recordKeeper.getAllDerivedDefinitions("Op");
119+
}
120+
121+
static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper,
122+
raw_ostream &os) {
123+
// Initialize.
124+
MLIRContext ctx;
125+
ctx.getOrLoadDialect<irdl::IRDLDialect>();
126+
OpBuilder builder(&ctx);
127+
128+
// Create a module op and set it as the insertion point.
129+
ModuleOp module = builder.create<ModuleOp>(UnknownLoc::get(&ctx));
130+
builder = builder.atBlockBegin(module.getBody());
131+
// Create the dialect and insert it.
132+
irdl::DialectOp dialect = createIRDLDialect(builder);
133+
// Set insertion point to start of DialectOp.
134+
builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock());
135+
136+
std::vector<Record *> defs = getOpDefinitions(recordKeeper);
137+
for (auto *def : defs) {
138+
tblgen::Operator tblgenOp(def);
139+
if (tblgenOp.getDialectName() != selectedDialect)
140+
continue;
141+
142+
createIRDLOperation(builder, tblgenOp);
143+
}
144+
145+
// Print the module.
146+
module.print(os);
147+
148+
return false;
149+
}
150+
151+
static mlir::GenRegistration
152+
genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions",
153+
[](const RecordKeeper &records, raw_ostream &os) {
154+
return emitDialectIRDLDefs(records, os);
155+
});
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//===- mlir-tblgen.cpp - Top-Level TableGen implementation for MLIR -------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains the main function for MLIR's TableGen IRDL backend.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/TableGen/GenInfo.h"
14+
#include "mlir/Tools/mlir-tblgen/MlirTblgenMain.h"
15+
#include "llvm/TableGen/Record.h"
16+
17+
using namespace llvm;
18+
using namespace mlir;
19+
20+
// Generator that prints records.
21+
GenRegistration printRecords("print-records", "Print all records to stdout",
22+
[](const RecordKeeper &records, raw_ostream &os) {
23+
os << records;
24+
return false;
25+
});
26+
27+
int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }

0 commit comments

Comments
 (0)