|
54 | 54 | using namespace mlir;
|
55 | 55 |
|
56 | 56 | static ParseResult parseApplyRegisteredPassOptions(
|
57 |
| - OpAsmParser &parser, ArrayAttr &options, |
| 57 | + OpAsmParser &parser, DictionaryAttr &options, |
58 | 58 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions);
|
59 | 59 | static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
|
60 |
| - Operation *op, ArrayAttr options, |
| 60 | + Operation *op, |
| 61 | + DictionaryAttr options, |
61 | 62 | ValueRange dynamicOptions);
|
62 | 63 | static ParseResult parseSequenceOpOperands(
|
63 | 64 | OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
|
@@ -784,41 +785,50 @@ DiagnosedSilenceableFailure
|
784 | 785 | transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
|
785 | 786 | transform::TransformResults &results,
|
786 | 787 | transform::TransformState &state) {
|
787 |
| - // Obtain a single options-string from options passed statically as |
788 |
| - // string attributes as well as "dynamically" through params. |
| 788 | + // Obtain a single options-string to pass to the pass(-pipeline) from options |
| 789 | + // passed in as a dictionary of keys mapping to values which are either |
| 790 | + // attributes or param-operands pointing to attributes. |
| 791 | + |
789 | 792 | std::string options;
|
| 793 | + llvm::raw_string_ostream optionsStream(options); // For "printing" attrs. |
| 794 | + |
790 | 795 | OperandRange dynamicOptions = getDynamicOptions();
|
791 |
| - size_t dynamicOptionsIdx = 0; |
792 |
| - for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) { |
| 796 | + for (auto [idx, namedAttribute] : llvm::enumerate(getOptions())) { |
793 | 797 | if (idx > 0)
|
794 |
| - options += " "; // Interleave options seperator. |
795 |
| - |
796 |
| - if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) { |
797 |
| - options += strAttr.getValue(); |
798 |
| - } else if (isa<UnitAttr>(optionAttr)) { |
799 |
| - assert(dynamicOptionsIdx < dynamicOptions.size() && |
| 798 | + optionsStream << " "; // Interleave options separator. |
| 799 | + optionsStream << namedAttribute.getName().str(); // Append the key. |
| 800 | + optionsStream << "="; // And the key-value separator. |
| 801 | + |
| 802 | + Attribute valueAttrToAppend; |
| 803 | + if (auto paramOperandIndex = dyn_cast<transform::ParamOperandIndexAttr>( |
| 804 | + namedAttribute.getValue())) { |
| 805 | + // The corresponding value attribute is passed in via a param. |
| 806 | + // Obtain the param-operand via its specified index. |
| 807 | + size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt(); |
| 808 | + assert(dynamicOptionIdx < dynamicOptions.size() && |
800 | 809 | "number of dynamic option markers (UnitAttr) in options ArrayAttr "
|
801 | 810 | "should be the same as the number of options passed as params");
|
802 | 811 | ArrayRef<Attribute> dynamicOption =
|
803 |
| - state.getParams(dynamicOptions[dynamicOptionsIdx++]); |
| 812 | + state.getParams(dynamicOptions[dynamicOptionIdx]); |
804 | 813 | if (dynamicOption.size() != 1)
|
805 |
| - return emitSilenceableError() << "options passed as a param must have " |
806 |
| - "a single value associated, param " |
807 |
| - << dynamicOptionsIdx - 1 << " associates " |
808 |
| - << dynamicOption.size(); |
809 |
| - |
810 |
| - if (auto dynamicOptionStr = dyn_cast<StringAttr>(dynamicOption[0])) { |
811 |
| - options += dynamicOptionStr.getValue(); |
812 |
| - } else { |
813 | 814 | return emitSilenceableError()
|
814 |
| - << "options passed as a param must be a string, got " |
815 |
| - << dynamicOption[0]; |
816 |
| - } |
| 815 | + << "options passed as a param must have " |
| 816 | + "a single value associated, param " |
| 817 | + << dynamicOptionIdx << " associates " << dynamicOption.size(); |
| 818 | + valueAttrToAppend = dynamicOption[0]; |
| 819 | + } else { |
| 820 | + // Value is a static attribute. |
| 821 | + valueAttrToAppend = namedAttribute.getValue(); |
| 822 | + } |
| 823 | + |
| 824 | + // Append string representation of value attribute. |
| 825 | + if (auto strAttr = dyn_cast<StringAttr>(valueAttrToAppend)) { |
| 826 | + optionsStream << strAttr.getValue().str(); |
817 | 827 | } else {
|
818 |
| - llvm_unreachable( |
819 |
| - "expected options element to be either StringAttr or UnitAttr"); |
| 828 | + valueAttrToAppend.print(optionsStream, /*elideType=*/true); |
820 | 829 | }
|
821 | 830 | }
|
| 831 | + optionsStream.flush(); |
822 | 832 |
|
823 | 833 | // Get pass or pass pipeline from registry.
|
824 | 834 | const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
|
@@ -864,84 +874,116 @@ transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
|
864 | 874 | }
|
865 | 875 |
|
866 | 876 | static ParseResult parseApplyRegisteredPassOptions(
|
867 |
| - OpAsmParser &parser, ArrayAttr &options, |
| 877 | + OpAsmParser &parser, DictionaryAttr &options, |
868 | 878 | SmallVectorImpl<OpAsmParser::UnresolvedOperand> &dynamicOptions) {
|
869 |
| - auto dynamicOptionMarker = UnitAttr::get(parser.getContext()); |
870 |
| - SmallVector<Attribute> optionsArray; |
871 |
| - |
872 |
| - auto parseOperandOrString = [&]() -> OptionalParseResult { |
873 |
| - OpAsmParser::UnresolvedOperand operand; |
874 |
| - OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand); |
875 |
| - if (parsedOperand.has_value()) { |
876 |
| - if (failed(parsedOperand.value())) |
877 |
| - return failure(); |
878 |
| - |
879 |
| - dynamicOptions.push_back(operand); |
880 |
| - optionsArray.push_back( |
881 |
| - dynamicOptionMarker); // Placeholder for knowing where to |
882 |
| - // inject the dynamic option-as-param. |
883 |
| - return success(); |
884 |
| - } |
| 879 | + // Construct the options DictionaryAttr per a `{ key = value, ... }` syntax. |
| 880 | + SmallVector<NamedAttribute> keyValuePairs; |
885 | 881 |
|
886 |
| - StringAttr stringAttr; |
887 |
| - OptionalParseResult parsedStringAttr = |
888 |
| - parser.parseOptionalAttribute(stringAttr); |
889 |
| - if (parsedStringAttr.has_value()) { |
890 |
| - if (failed(parsedStringAttr.value())) |
891 |
| - return failure(); |
892 |
| - optionsArray.push_back(stringAttr); |
893 |
| - return success(); |
894 |
| - } |
| 882 | + size_t dynamicOptionsIdx = 0; |
| 883 | + auto parseKeyValuePair = [&]() -> ParseResult { |
| 884 | + // Parse items of the form `key = value` where `key` is a bare identifier or |
| 885 | + // a string and `value` is either an attribute or an operand. |
| 886 | + |
| 887 | + std::string key; |
| 888 | + Attribute valueAttr; |
| 889 | + if (parser.parseOptionalKeywordOrString(&key)) |
| 890 | + return parser.emitError(parser.getCurrentLocation()) |
| 891 | + << "expected key to either be an identifier or a string"; |
| 892 | + if (key.empty()) |
| 893 | + return failure(); |
895 | 894 |
|
896 |
| - return std::nullopt; |
| 895 | + if (parser.parseEqual()) |
| 896 | + return parser.emitError(parser.getCurrentLocation()) |
| 897 | + << "expected '=' after key in key-value pair"; |
| 898 | + |
| 899 | + // Parse the value, which can be either an attribute or an operand. |
| 900 | + OptionalParseResult parsedValueAttr = |
| 901 | + parser.parseOptionalAttribute(valueAttr); |
| 902 | + if (!parsedValueAttr.has_value()) { |
| 903 | + OpAsmParser::UnresolvedOperand operand; |
| 904 | + ParseResult parsedOperand = parser.parseOperand(operand); |
| 905 | + if (failed(parsedOperand)) |
| 906 | + return parser.emitError(parser.getCurrentLocation()) |
| 907 | + << "expected a valid attribute or operand as value associated " |
| 908 | + << "to key '" << key << "'"; |
| 909 | + dynamicOptions.push_back(operand); |
| 910 | + auto wrappedIndex = IntegerAttr::get( |
| 911 | + IntegerType::get(parser.getContext(), 64), dynamicOptionsIdx++); |
| 912 | + valueAttr = transform::ParamOperandIndexAttr::get(parser.getContext(), |
| 913 | + wrappedIndex); |
| 914 | + } else if (failed(parsedValueAttr.value())) { |
| 915 | + return failure(); // NB: Attempted parse should have output error message. |
| 916 | + } else if (isa<transform::ParamOperandIndexAttr>(valueAttr)) { |
| 917 | + return parser.emitError(parser.getCurrentLocation()) |
| 918 | + << "the param_operand_index attribute is a marker reserved for " |
| 919 | + << "indicating a value will be passed via params and is only used " |
| 920 | + << "in the generic print format"; |
| 921 | + } |
| 922 | + |
| 923 | + keyValuePairs.push_back(NamedAttribute(key, valueAttr)); |
| 924 | + return success(); |
897 | 925 | };
|
898 | 926 |
|
899 |
| - OptionalParseResult parsedOptionsElement = parseOperandOrString(); |
900 |
| - while (parsedOptionsElement.has_value()) { |
901 |
| - if (failed(parsedOptionsElement.value())) |
902 |
| - return failure(); |
903 |
| - parsedOptionsElement = parseOperandOrString(); |
904 |
| - } |
| 927 | + if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Braces, |
| 928 | + parseKeyValuePair, |
| 929 | + " in options dictionary")) |
| 930 | + return failure(); // NB: Attempted parse should have output error message. |
905 | 931 |
|
906 |
| - if (optionsArray.empty()) { |
| 932 | + if (DictionaryAttr::findDuplicate( |
| 933 | + keyValuePairs, /*isSorted=*/false) // Also sorts the keyValuePairs. |
| 934 | + .has_value()) |
907 | 935 | return parser.emitError(parser.getCurrentLocation())
|
908 |
| - << "expected at least one option (either a string or a param)"; |
909 |
| - } |
910 |
| - options = parser.getBuilder().getArrayAttr(optionsArray); |
| 936 | + << "duplicate keys found in options dictionary"; |
| 937 | + |
| 938 | + options = DictionaryAttr::getWithSorted(parser.getContext(), keyValuePairs); |
| 939 | + |
911 | 940 | return success();
|
912 | 941 | }
|
913 | 942 |
|
914 | 943 | static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
|
915 |
| - Operation *op, ArrayAttr options, |
| 944 | + Operation *op, |
| 945 | + DictionaryAttr options, |
916 | 946 | ValueRange dynamicOptions) {
|
917 |
| - size_t currentDynamicOptionIdx = 0; |
918 |
| - for (auto [idx, optionAttr] : llvm::enumerate(options)) { |
919 |
| - if (idx > 0) |
920 |
| - printer << " "; // Interleave options separator. |
| 947 | + if (options.empty()) |
| 948 | + return; |
921 | 949 |
|
922 |
| - if (isa<UnitAttr>(optionAttr)) |
923 |
| - printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]); |
924 |
| - else if (auto strAttr = dyn_cast<StringAttr>(optionAttr)) |
925 |
| - printer.printAttribute(strAttr); |
926 |
| - else |
927 |
| - llvm_unreachable("each option should be either a StringAttr or UnitAttr"); |
928 |
| - } |
| 950 | + printer << "{"; |
| 951 | + llvm::interleaveComma(options, printer, [&](NamedAttribute namedAttribute) { |
| 952 | + printer << namedAttribute.getName() << " = "; |
| 953 | + Attribute value = namedAttribute.getValue(); |
| 954 | + if (auto indexAttr = dyn_cast<transform::ParamOperandIndexAttr>(value)) { |
| 955 | + printer.printOperand(dynamicOptions[indexAttr.getIndex().getInt()]); |
| 956 | + } else { |
| 957 | + printer.printAttribute(value); |
| 958 | + } |
| 959 | + }); |
| 960 | + printer << "}"; |
929 | 961 | }
|
930 | 962 |
|
931 | 963 | LogicalResult transform::ApplyRegisteredPassOp::verify() {
|
932 |
| - size_t numUnitsInOptions = 0; |
933 |
| - for (Attribute optionsElement : getOptions()) { |
934 |
| - if (isa<UnitAttr>(optionsElement)) |
935 |
| - numUnitsInOptions++; |
936 |
| - else if (!isa<StringAttr>(optionsElement)) |
937 |
| - return emitOpError() << "expected each option to be either a StringAttr " |
938 |
| - << "or a UnitAttr, got " << optionsElement; |
939 |
| - } |
940 |
| - |
941 |
| - if (getDynamicOptions().size() != numUnitsInOptions) |
942 |
| - return emitOpError() |
943 |
| - << "expected the same number of options passed as params as " |
944 |
| - << "UnitAttr elements in options ArrayAttr"; |
| 964 | + // Check that there is a one-to-one correspondence between param operands |
| 965 | + // and references to dynamic options in the options dictionary. |
| 966 | + |
| 967 | + auto dynamicOptions = SmallVector<Value>(getDynamicOptions()); |
| 968 | + for (NamedAttribute namedAttr : getOptions()) |
| 969 | + if (auto paramOperandIndex = |
| 970 | + dyn_cast<transform::ParamOperandIndexAttr>(namedAttr.getValue())) { |
| 971 | + size_t dynamicOptionIdx = paramOperandIndex.getIndex().getInt(); |
| 972 | + if (dynamicOptionIdx < 0 || dynamicOptionIdx >= dynamicOptions.size()) |
| 973 | + return emitOpError() |
| 974 | + << "dynamic option index " << dynamicOptionIdx |
| 975 | + << " is out of bounds for the number of dynamic options: " |
| 976 | + << dynamicOptions.size(); |
| 977 | + if (dynamicOptions[dynamicOptionIdx] == nullptr) |
| 978 | + return emitOpError() << "dynamic option index " << dynamicOptionIdx |
| 979 | + << " is already used in options"; |
| 980 | + dynamicOptions[dynamicOptionIdx] = nullptr; // Mark this option as used. |
| 981 | + } |
| 982 | + |
| 983 | + for (Value dynamicOption : dynamicOptions) |
| 984 | + if (dynamicOption) |
| 985 | + return emitOpError() << "a param operand does not have a corresponding " |
| 986 | + << "param_operand_index attr in the options dict"; |
945 | 987 |
|
946 | 988 | return success();
|
947 | 989 | }
|
|
0 commit comments