|
| 1 | +//===- OpDefinitionsGen.cpp - IRDL op definitions generator ---------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// OpDefinitionsGen uses the description of operations to generate IRDL |
| 10 | +// definitions for ops. |
| 11 | +// |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | + |
| 14 | +#include "mlir/Dialect/IRDL/IR/IRDL.h" |
| 15 | +#include "mlir/IR/Attributes.h" |
| 16 | +#include "mlir/IR/Builders.h" |
| 17 | +#include "mlir/IR/BuiltinOps.h" |
| 18 | +#include "mlir/IR/Diagnostics.h" |
| 19 | +#include "mlir/IR/Dialect.h" |
| 20 | +#include "mlir/IR/MLIRContext.h" |
| 21 | +#include "mlir/TableGen/AttrOrTypeDef.h" |
| 22 | +#include "mlir/TableGen/GenInfo.h" |
| 23 | +#include "mlir/TableGen/GenNameParser.h" |
| 24 | +#include "mlir/TableGen/Interfaces.h" |
| 25 | +#include "mlir/TableGen/Operator.h" |
| 26 | +#include "llvm/Support/CommandLine.h" |
| 27 | +#include "llvm/Support/InitLLVM.h" |
| 28 | +#include "llvm/Support/raw_ostream.h" |
| 29 | +#include "llvm/TableGen/Main.h" |
| 30 | +#include "llvm/TableGen/Record.h" |
| 31 | +#include "llvm/TableGen/TableGenBackend.h" |
| 32 | + |
| 33 | +using namespace llvm; |
| 34 | +using namespace mlir; |
| 35 | +using tblgen::NamedTypeConstraint; |
| 36 | + |
| 37 | +static llvm::cl::OptionCategory dialectGenCat("Options for -gen-irdl-dialect"); |
| 38 | +llvm::cl::opt<std::string> |
| 39 | + selectedDialect("dialect", llvm::cl::desc("The dialect to gen for"), |
| 40 | + llvm::cl::cat(dialectGenCat), llvm::cl::Required); |
| 41 | + |
| 42 | +irdl::CPredOp createConstraint(OpBuilder &builder, |
| 43 | + NamedTypeConstraint namedConstraint) { |
| 44 | + MLIRContext *ctx = builder.getContext(); |
| 45 | + // Build the constraint as a string. |
| 46 | + std::string constraint = |
| 47 | + namedConstraint.constraint.getPredicate().getCondition(); |
| 48 | + // Build a CPredOp to match the C constraint built. |
| 49 | + irdl::CPredOp op = builder.create<irdl::CPredOp>( |
| 50 | + UnknownLoc::get(ctx), StringAttr::get(ctx, constraint)); |
| 51 | + return op; |
| 52 | +} |
| 53 | + |
| 54 | +/// Returns the name of the operation without the dialect prefix. |
| 55 | +static StringRef getOperatorName(tblgen::Operator &tblgenOp) { |
| 56 | + StringRef opName = tblgenOp.getDef().getValueAsString("opName"); |
| 57 | + return opName; |
| 58 | +} |
| 59 | + |
| 60 | +/// Extract an operation to IRDL. |
| 61 | +irdl::OperationOp createIRDLOperation(OpBuilder &builder, |
| 62 | + tblgen::Operator &tblgenOp) { |
| 63 | + MLIRContext *ctx = builder.getContext(); |
| 64 | + StringRef opName = getOperatorName(tblgenOp); |
| 65 | + |
| 66 | + irdl::OperationOp op = builder.create<irdl::OperationOp>( |
| 67 | + UnknownLoc::get(ctx), StringAttr::get(ctx, opName)); |
| 68 | + |
| 69 | + // Add the block in the region. |
| 70 | + Block &opBlock = op.getBody().emplaceBlock(); |
| 71 | + OpBuilder consBuilder = OpBuilder::atBlockBegin(&opBlock); |
| 72 | + |
| 73 | + auto getValues = [&](tblgen::Operator::const_value_range namedCons) { |
| 74 | + SmallVector<Value> operands; |
| 75 | + SmallVector<irdl::VariadicityAttr> variadicity; |
| 76 | + for (const NamedTypeConstraint &namedCons : namedCons) { |
| 77 | + auto operand = createConstraint(consBuilder, namedCons); |
| 78 | + operands.push_back(operand); |
| 79 | + |
| 80 | + irdl::VariadicityAttr var; |
| 81 | + if (namedCons.isOptional()) |
| 82 | + var = consBuilder.getAttr<irdl::VariadicityAttr>( |
| 83 | + irdl::Variadicity::optional); |
| 84 | + else if (namedCons.isVariadic()) |
| 85 | + var = consBuilder.getAttr<irdl::VariadicityAttr>( |
| 86 | + irdl::Variadicity::variadic); |
| 87 | + else |
| 88 | + var = consBuilder.getAttr<irdl::VariadicityAttr>( |
| 89 | + irdl::Variadicity::single); |
| 90 | + |
| 91 | + variadicity.push_back(var); |
| 92 | + } |
| 93 | + return std::make_tuple(operands, variadicity); |
| 94 | + }; |
| 95 | + |
| 96 | + auto [operands, operandVariadicity] = getValues(tblgenOp.getOperands()); |
| 97 | + auto [results, resultVariadicity] = getValues(tblgenOp.getResults()); |
| 98 | + |
| 99 | + // Create the operands and results operations. |
| 100 | + consBuilder.create<irdl::OperandsOp>(UnknownLoc::get(ctx), operands, |
| 101 | + operandVariadicity); |
| 102 | + consBuilder.create<irdl::ResultsOp>(UnknownLoc::get(ctx), results, |
| 103 | + resultVariadicity); |
| 104 | + |
| 105 | + return op; |
| 106 | +} |
| 107 | + |
| 108 | +static irdl::DialectOp createIRDLDialect(OpBuilder &builder) { |
| 109 | + MLIRContext *ctx = builder.getContext(); |
| 110 | + return builder.create<irdl::DialectOp>(UnknownLoc::get(ctx), |
| 111 | + StringAttr::get(ctx, selectedDialect)); |
| 112 | +} |
| 113 | + |
| 114 | +static std::vector<llvm::Record *> |
| 115 | +getOpDefinitions(const RecordKeeper &recordKeeper) { |
| 116 | + if (!recordKeeper.getClass("Op")) |
| 117 | + return {}; |
| 118 | + return recordKeeper.getAllDerivedDefinitions("Op"); |
| 119 | +} |
| 120 | + |
| 121 | +static bool emitDialectIRDLDefs(const RecordKeeper &recordKeeper, |
| 122 | + raw_ostream &os) { |
| 123 | + // Initialize. |
| 124 | + MLIRContext ctx; |
| 125 | + ctx.getOrLoadDialect<irdl::IRDLDialect>(); |
| 126 | + OpBuilder builder(&ctx); |
| 127 | + |
| 128 | + // Create a module op and set it as the insertion point. |
| 129 | + ModuleOp module = builder.create<ModuleOp>(UnknownLoc::get(&ctx)); |
| 130 | + builder = builder.atBlockBegin(module.getBody()); |
| 131 | + // Create the dialect and insert it. |
| 132 | + irdl::DialectOp dialect = createIRDLDialect(builder); |
| 133 | + // Set insertion point to start of DialectOp. |
| 134 | + builder = builder.atBlockBegin(&dialect.getBody().emplaceBlock()); |
| 135 | + |
| 136 | + std::vector<Record *> defs = getOpDefinitions(recordKeeper); |
| 137 | + for (auto *def : defs) { |
| 138 | + tblgen::Operator tblgenOp(def); |
| 139 | + if (tblgenOp.getDialectName() != selectedDialect) |
| 140 | + continue; |
| 141 | + |
| 142 | + createIRDLOperation(builder, tblgenOp); |
| 143 | + } |
| 144 | + |
| 145 | + // Print the module. |
| 146 | + module.print(os); |
| 147 | + |
| 148 | + return false; |
| 149 | +} |
| 150 | + |
| 151 | +static mlir::GenRegistration |
| 152 | + genOpDefs("gen-dialect-irdl-defs", "Generate IRDL dialect definitions", |
| 153 | + [](const RecordKeeper &records, raw_ostream &os) { |
| 154 | + return emitDialectIRDLDefs(records, os); |
| 155 | + }); |
0 commit comments