Skip to content

Commit b409b36

Browse files
committed
Allow passing a dictionary of options, including params as values
1 parent a96f40d commit b409b36

File tree

10 files changed

+384
-136
lines changed

10 files changed

+384
-136
lines changed

mlir/include/mlir/Dialect/Transform/IR/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@ mlir_tablegen(TransformDialectEnums.h.inc -gen-enum-decls)
2020
mlir_tablegen(TransformDialectEnums.cpp.inc -gen-enum-defs)
2121
add_public_tablegen_target(MLIRTransformDialectEnumIncGen)
2222
add_dependencies(mlir-headers MLIRTransformDialectEnumIncGen)
23+
mlir_tablegen(TransformAttrs.h.inc -gen-attrdef-decls)
24+
mlir_tablegen(TransformAttrs.cpp.inc -gen-attrdef-defs)
25+
add_public_tablegen_target(MLIRTransformDialectAttributesIncGen)
26+
add_dependencies(mlir-headers MLIRTransformDialectAttributesIncGen)
2327

2428
add_mlir_dialect(TransformOps transform)
2529
add_mlir_doc(TransformOps TransformOps Dialects/ -gen-op-doc -dialect=transform)

mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,7 @@
1717

1818
#include "mlir/Dialect/Transform/IR/TransformDialectEnums.h.inc"
1919

20+
#define GET_ATTRDEF_CLASSES
21+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h.inc"
22+
2023
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS_H

mlir/include/mlir/Dialect/Transform/IR/TransformAttrs.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,14 @@
1010
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS
1111

1212
include "mlir/IR/EnumAttr.td"
13+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
14+
15+
class Transform_Attr<string name, string attrMnemonic,
16+
list<Trait> traits = [],
17+
string baseCppClass = "::mlir::Attribute">
18+
: AttrDef<Transform_Dialect, name, traits, baseCppClass> {
19+
let mnemonic = attrMnemonic;
20+
}
1321

1422
def PropagateFailuresCase : I32EnumAttrCase<"Propagate", 1, "propagate">;
1523
def SuppressFailuresCase : I32EnumAttrCase<"Suppress", 2, "suppress">;
@@ -33,4 +41,17 @@ def MatchCmpIPredicateAttr : I32EnumAttr<
3341
let cppNamespace = "::mlir::transform";
3442
}
3543

44+
def ParamOperandIndexAttr : Transform_Attr<"ParamOperandIndex",
45+
"param_operand_index" > {
46+
let mnemonic = "param_operand_index";
47+
let description = [{
48+
Used to refer to a specific param-operand (via its index) from within an
49+
attribute on a transform operation.
50+
}];
51+
let parameters = (ins
52+
"IntegerAttr":$index
53+
);
54+
let assemblyFormat = "`<` $index `>`";
55+
}
56+
3657
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMATTRS

mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def Transform_Dialect : Dialect {
1919
let cppNamespace = "::mlir::transform";
2020

2121
let hasOperationAttrVerify = 1;
22+
let useDefaultAttributePrinterParser = 1;
2223
let extraClassDeclaration = [{
2324
/// Symbol name for the default entry point "named sequence".
2425
constexpr const static ::llvm::StringLiteral

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -405,10 +405,23 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
405405
let description = [{
406406
This transform applies the specified pass or pass pipeline to the targeted
407407
ops. The name of the pass/pipeline is specified as a string attribute, as
408-
set during pass/pipeline registration. Optionally, pass options may be
409-
specified as (space-separated) string attributes with the option to pass
410-
these attributes via params. The pass options syntax is identical to the one
411-
used with "mlir-opt".
408+
set during pass/pipeline registration.
409+
410+
Optionally, pass options may be specified via a DictionaryAttr. This
411+
dictionary is converted to a string -- formatted `key=value ...` -- which
412+
is expected to be in the exact format used by the pass on the commandline.
413+
Values are either attributes or (SSA-values of) Transform Dialect params.
414+
For example:
415+
416+
```mlir
417+
transform.apply_registered_pass "canonicalize"
418+
with options = { "top-down" = false,
419+
"max-iterations" = %max_iter,
420+
"test-convergence" = true,
421+
"max-num-rewrites" = %max_rewrites }
422+
to %module
423+
: (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
424+
```
412425

413426
This op first looks for a pass pipeline with the specified name. If no such
414427
pipeline exists, it looks for a pass with the specified name. If no such
@@ -422,7 +435,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
422435
}];
423436

424437
let arguments = (ins StrAttr:$pass_name,
425-
DefaultValuedAttr<ArrayAttr, "{}">:$options,
438+
DefaultValuedAttr<DictionaryAttr, "{}">:$options,
426439
Variadic<TransformParamTypeInterface>:$dynamic_options,
427440
TransformHandleTypeInterface:$target);
428441
let results = (outs TransformHandleTypeInterface:$result);

mlir/lib/Dialect/Transform/IR/TransformDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88

99
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
1010
#include "mlir/Analysis/CallGraph.h"
11+
#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
1112
#include "mlir/Dialect/Transform/IR/TransformOps.h"
1213
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
1314
#include "mlir/Dialect/Transform/IR/Utils.h"
1415
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
1516
#include "mlir/IR/DialectImplementation.h"
1617
#include "llvm/ADT/SCCIterator.h"
18+
#include "llvm/ADT/TypeSwitch.h"
1719

1820
using namespace mlir;
1921

2022
#include "mlir/Dialect/Transform/IR/TransformDialect.cpp.inc"
2123

24+
#define GET_ATTRDEF_CLASSES
25+
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
26+
2227
#ifndef NDEBUG
2328
void transform::detail::checkImplementsTransformOpInterface(
2429
StringRef name, MLIRContext *context) {
@@ -66,6 +71,10 @@ void transform::TransformDialect::initialize() {
6671
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
6772
>();
6873
initializeTypes();
74+
addAttributes<
75+
#define GET_ATTRDEF_LIST
76+
#include "mlir/Dialect/Transform/IR/TransformAttrs.cpp.inc"
77+
>();
6978
initializeLibraryModule();
7079
}
7180

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 130 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,11 @@
5454
using namespace mlir;
5555

5656
static ParseResult parseApplyRegisteredPassOptions(
57-
OpAsmParser &parser, ArrayAttr &options,
57+
OpAsmParser &parser, DictionaryAttr &options,
5858
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
5959
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
60-
Operation *op, ArrayAttr options,
60+
Operation *op,
61+
DictionaryAttr options,
6162
ValueRange dynamicOptions);
6263
static ParseResult parseSequenceOpOperands(
6364
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
@@ -784,41 +785,50 @@ DiagnosedSilenceableFailure
784785
transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
785786
transform::TransformResults &results,
786787
transform::TransformState &state) {
787-
// Obtain a single options-string from options passed statically as
788-
// string attributes as well as "dynamically" through params.
788+
// Obtain a single options-string to pass to the pass(-pipeline) from options
789+
// passed in as a dictionary of keys mapping to values which are either
790+
// attributes or param-operands pointing to attributes.
791+
789792
std::string options;
793+
llvm::raw_string_ostream optionsStream(options); // For "printing" attrs.
794+
790795
OperandRange dynamicOptions = getDynamicOptions();
791-
size_t dynamicOptionsIdx = 0;
792-
for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) {
796+
for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) {
793797
if (idx > 0)
794-
options += " "; // Interleave options seperator.
795-
796-
if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) {
797-
options += strAttr.getValue();
798-
} else if (isa<UnitAttr>(optionAttr)) {
799-
assert(dynamicOptionsIdx < dynamicOptions.size() &&
798+
optionsStream << " "; // Interleave options separator.
799+
optionsStream << namedAttribute.getName().str(); // Append the key.
800+
optionsStream << "="; // And the key-value separator.
801+
802+
Attribute valueAttrToAppend;
803+
if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>(
804+
namedAttribute.getValue())) {
805+
// The corresponding value attribute is passed in via a param.
806+
// Obtain the param-operand via its specified index.
807+
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
808+
assert(dynamicOptionIdx < dynamicOptions.size() &&
800809
"number of dynamic option markers (UnitAttr) in options ArrayAttr "
801810
"should be the same as the number of options passed as params");
802811
ArrayRef<Attribute> dynamicOption =
803-
state.getParams(dynamicOptions[dynamicOptionsIdx++]);
812+
state.getParams(dynamicOptions[dynamicOptionIdx]);
804813
if (dynamicOption.size() != 1)
805-
return emitSilenceableError() << "options passed as a param must have "
806-
"a single value associated, param "
807-
<< dynamicOptionsIdx - 1 << " associates "
808-
<< dynamicOption.size();
809-
810-
if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) {
811-
options += dynamicOptionStr.getValue();
812-
} else {
813814
return emitSilenceableError()
814-
<< "options passed as a param must be a string, got "
815-
<< dynamicOption[0];
816-
}
815+
<< "options passed as a param must have "
816+
"a single value associated, param "
817+
<< dynamicOptionIdx << " associates " << dynamicOption.size();
818+
valueAttrToAppend = dynamicOption[0];
819+
} else {
820+
// Value is a static attribute.
821+
valueAttrToAppend = namedAttribute.getValue();
822+
}
823+
824+
// Append string representation of value attribute.
825+
if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) {
826+
optionsStream << strAttr.getValue().str();
817827
} else {
818-
llvm_unreachable(
819-
"expected options element to be either StringAttr or UnitAttr");
828+
valueAttrToAppend.print(optionsStream, /*elideType=*/true);
820829
}
821830
}
831+
optionsStream.flush();
822832

823833
// Get pass or pass pipeline from registry.
824834
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -864,84 +874,116 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
864874
}
865875

866876
static ParseResult parseApplyRegisteredPassOptions(
867-
OpAsmParser &parser, ArrayAttr &options,
877+
OpAsmParser &parser, DictionaryAttr &options,
868878
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
869-
auto dynamicOptionMarker = UnitAttr::get(parser.getContext());
870-
SmallVector<Attribute> optionsArray;
871-
872-
auto parseOperandOrString = [&]() -> OptionalParseResult {
873-
OpAsmParser::UnresolvedOperand operand;
874-
OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand);
875-
if (parsedOperand.has_value()) {
876-
if (failed(parsedOperand.value()))
877-
return failure();
878-
879-
dynamicOptions.push_back(operand);
880-
optionsArray.push_back(
881-
dynamicOptionMarker); // Placeholder for knowing where to
882-
// inject the dynamic option-as-param.
883-
return success();
884-
}
879+
// Construct the options DictionaryAttr per a `{ key = value, ... }` syntax.
880+
SmallVector<NamedAttribute> keyValuePairs;
885881

886-
StringAttr stringAttr;
887-
OptionalParseResult parsedStringAttr =
888-
parser.parseOptionalAttribute(stringAttr);
889-
if (parsedStringAttr.has_value()) {
890-
if (failed(parsedStringAttr.value()))
891-
return failure();
892-
optionsArray.push_back(stringAttr);
893-
return success();
894-
}
882+
size_t dynamicOptionsIdx = 0;
883+
auto parseKeyValuePair = [&]() -> ParseResult {
884+
// Parse items of the form `key = value` where `key` is a bare identifier or
885+
// a string and `value` is either an attribute or an operand.
886+
887+
std::string key;
888+
Attribute valueAttr;
889+
if (parser.parseOptionalKeywordOrString(&key))
890+
return parser.emitError(parser.getCurrentLocation())
891+
<< "expected key to either be an identifier or a string";
892+
if (key.empty())
893+
return failure();
895894

896-
return std::nullopt;
895+
if (parser.parseEqual())
896+
return parser.emitError(parser.getCurrentLocation())
897+
<< "expected '=' after key in key-value pair";
898+
899+
// Parse the value, which can be either an attribute or an operand.
900+
OptionalParseResult parsedValueAttr =
901+
parser.parseOptionalAttribute(valueAttr);
902+
if (!parsedValueAttr.has_value()) {
903+
OpAsmParser::UnresolvedOperand operand;
904+
ParseResult parsedOperand = parser.parseOperand(operand);
905+
if (failed(parsedOperand))
906+
return parser.emitError(parser.getCurrentLocation())
907+
<< "expected a valid attribute or operand as value associated "
908+
<< "to key '" << key << "'";
909+
dynamicOptions.push_back(operand);
910+
auto wrappedIndex = IntegerAttr::get(
911+
IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++);
912+
valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(),
913+
wrappedIndex);
914+
} else if (failed(parsedValueAttr.value())) {
915+
return failure(); // NB: Attempted parse should have output error message.
916+
} else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) {
917+
return parser.emitError(parser.getCurrentLocation())
918+
<< "the param_operand_index attribute is a marker reserved for "
919+
<< "indicating a value will be passed via params and is only used "
920+
<< "in the generic print format";
921+
}
922+
923+
keyValuePairs.push_back(NamedAttribute(key, valueAttr));
924+
return success();
897925
};
898926

899-
OptionalParseResult parsedOptionsElement = parseOperandOrString();
900-
while (parsedOptionsElement.has_value()) {
901-
if (failed(parsedOptionsElement.value()))
902-
return failure();
903-
parsedOptionsElement = parseOperandOrString();
904-
}
927+
if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces,
928+
parseKeyValuePair,
929+
" in options dictionary"))
930+
return failure(); // NB: Attempted parse should have output error message.
905931

906-
if (optionsArray.empty()) {
932+
if (DictionaryAttr::findDuplicate(
933+
keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs.
934+
.has_value())
907935
return parser.emitError(parser.getCurrentLocation())
908-
<< "expected at least one option (either a string or a param)";
909-
}
910-
options = parser.getBuilder().getArrayAttr(optionsArray);
936+
<< "duplicate keys found in options dictionary";
937+
938+
options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs);
939+
911940
return success();
912941
}
913942

914943
static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
915-
Operation *op, ArrayAttr options,
944+
Operation *op,
945+
DictionaryAttr options,
916946
ValueRange dynamicOptions) {
917-
size_t currentDynamicOptionIdx = 0;
918-
for (auto [idx, optionAttr] : llvm::enumerate(options)) {
919-
if (idx > 0)
920-
printer << " "; // Interleave options separator.
947+
if (options.empty())
948+
return;
921949

922-
if (isa<UnitAttr>(optionAttr))
923-
printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]);
924-
else if (auto strAttr = dyn_cast<StringAttr>(optionAttr))
925-
printer.printAttribute(strAttr);
926-
else
927-
llvm_unreachable("each option should be either a StringAttr or UnitAttr");
928-
}
950+
printer << "{";
951+
llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) {
952+
printer << namedAttribute.getName() << " = ";
953+
Attribute value = namedAttribute.getValue();
954+
if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) {
955+
printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]);
956+
} else {
957+
printer.printAttribute(value);
958+
}
959+
});
960+
printer << "}";
929961
}
930962

931963
LogicalResult transform::ApplyRegisteredPassOp::verify() {
932-
size_t numUnitsInOptions = 0;
933-
for (Attribute optionsElement : getOptions()) {
934-
if (isa<UnitAttr>(optionsElement))
935-
numUnitsInOptions++;
936-
else if (!isa<StringAttr>(optionsElement))
937-
return emitOpError() << "expected each option to be either a StringAttr "
938-
<< "or a UnitAttr, got " << optionsElement;
939-
}
940-
941-
if (getDynamicOptions().size() != numUnitsInOptions)
942-
return emitOpError()
943-
<< "expected the same number of options passed as params as "
944-
<< "UnitAttr elements in options ArrayAttr";
964+
// Check that there is a one-to-one correspondence between param operands
965+
// and references to dynamic options in the options dictionary.
966+
967+
auto dynamicOptions = SmallVector<Value>(getDynamicOptions());
968+
for (NamedAttribute namedAttr : getOptions())
969+
if (auto paramOperandIndex =
970+
dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) {
971+
size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt();
972+
if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size())
973+
return emitOpError()
974+
<< "dynamic option index " << dynamicOptionIdx
975+
<< " is out of bounds for the number of dynamic options: "
976+
<< dynamicOptions.size();
977+
if (dynamicOptions[dynamicOptionIdx] == nullptr)
978+
return emitOpError() << "dynamic option index " << dynamicOptionIdx
979+
<< " is already used in options";
980+
dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used.
981+
}
982+
983+
for (Value dynamicOption : dynamicOptions)
984+
if (dynamicOption)
985+
return emitOpError() << "a param operand does not have a corresponding "
986+
<< "param_operand_index attr in the options dict";
945987

946988
return success();
947989
}

0 commit comments

Comments
 (0)