diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index e4eb67c8e14ce..e864a65f8ceac 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -399,15 +399,16 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm", } def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", - [TransformOpInterface, TransformEachOpTrait, - FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> { + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { let summary = "Applies the specified registered pass or pass pipeline"; let description = [{ This transform applies the specified pass or pass pipeline to the targeted ops. The name of the pass/pipeline is specified as a string attribute, as set during pass/pipeline registration. Optionally, pass options may be - specified as a string attribute. The pass options syntax is identical to the - one used with "mlir-opt". + specified as (space-separated) string attributes with the option to pass + these attributes via params. The pass options syntax is identical to the one + used with "mlir-opt". This op first looks for a pass pipeline with the specified name. If no such pipeline exists, it looks for a pass with the specified name. If no such @@ -420,21 +421,17 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", of targeted ops. }]; - let arguments = (ins TransformHandleTypeInterface:$target, - StrAttr:$pass_name, - DefaultValuedAttr:$options); + let arguments = (ins StrAttr:$pass_name, + DefaultValuedAttr:$options, + Variadic:$dynamic_options, + TransformHandleTypeInterface:$target); let results = (outs TransformHandleTypeInterface:$result); let assemblyFormat = [{ - $pass_name `to` $target attr-dict `:` functional-type(operands, results) - }]; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::transform::TransformRewriter &rewriter, - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &results, - ::mlir::transform::TransformState &state); + $pass_name (`with` `options` `=` + custom($options, $dynamic_options)^)? + `to` $target attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } def CastOp : TransformDialectOp<"cast", diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 673743f22249a..a0f9518e3d12f 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -53,6 +53,12 @@ using namespace mlir; +static ParseResult parseApplyRegisteredPassOptions( + OpAsmParser &parser, ArrayAttr &options, + SmallVectorImpl &dynamicOptions); +static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, + Operation *op, ArrayAttr options, + ValueRange dynamicOptions); static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, @@ -766,17 +772,53 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects( // ApplyRegisteredPassOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( - transform::TransformRewriter &rewriter, Operation *target, - ApplyToEachResultList &results, transform::TransformState &state) { - // Make sure that this transform is not applied to itself. Modifying the - // transform IR while it is being interpreted is generally dangerous. Even - // more so when applying passes because they may perform a wide range of IR - // modifications. - DiagnosedSilenceableFailure payloadCheck = - ensurePayloadIsSeparateFromTransform(*this, target); - if (!payloadCheck.succeeded()) - return payloadCheck; +void transform::ApplyRegisteredPassOp::getEffects( + SmallVectorImpl &effects) { + consumesHandle(getTargetMutable(), effects); + onlyReadsHandle(getDynamicOptionsMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); +} + +DiagnosedSilenceableFailure +transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + // Obtain a single options-string from options passed statically as + // string attributes as well as "dynamically" through params. + std::string options; + OperandRange dynamicOptions = getDynamicOptions(); + size_t dynamicOptionsIdx = 0; + for (auto [idx, optionAttr] : llvm::enumerate(getOptions())) { + if (idx > 0) + options += " "; // Interleave options seperator. + + if (auto strAttr = dyn_cast(optionAttr)) { + options += strAttr.getValue(); + } else if (isa(optionAttr)) { + assert(dynamicOptionsIdx < dynamicOptions.size() && + "number of dynamic option markers (UnitAttr) in options ArrayAttr " + "should be the same as the number of options passed as params"); + ArrayRef dynamicOption = + state.getParams(dynamicOptions[dynamicOptionsIdx++]); + if (dynamicOption.size() != 1) + return emitSilenceableError() << "options passed as a param must have " + "a single value associated, param " + << dynamicOptionsIdx - 1 << " associates " + << dynamicOption.size(); + + if (auto dynamicOptionStr = dyn_cast(dynamicOption[0])) { + options += dynamicOptionStr.getValue(); + } else { + return emitSilenceableError() + << "options passed as a param must be a string, got " + << dynamicOption[0]; + } + } else { + llvm_unreachable( + "expected options element to be either StringAttr or UnitAttr"); + } + } // Get pass or pass pipeline from registry. const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName()); @@ -786,9 +828,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( return emitDefiniteFailure() << "unknown pass or pass pipeline: " << getPassName(); - // Create pass manager and run the pass or pass pipeline. + // Create pass manager and add the pass or pass pipeline. PassManager pm(getContext()); - if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) { + if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) { emitError(msg); return failure(); }))) { @@ -796,16 +838,114 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( << "failed to add pass or pass pipeline to pipeline: " << getPassName(); } - if (failed(pm.run(target))) { - auto diag = emitSilenceableError() << "pass pipeline failed"; - diag.attachNote(target->getLoc()) << "target op"; - return diag; + + auto targets = SmallVector(state.getPayloadOps(getTarget())); + for (Operation *target : targets) { + // Make sure that this transform is not applied to itself. Modifying the + // transform IR while it is being interpreted is generally dangerous. Even + // more so when applying passes because they may perform a wide range of IR + // modifications. + DiagnosedSilenceableFailure payloadCheck = + ensurePayloadIsSeparateFromTransform(*this, target); + if (!payloadCheck.succeeded()) + return payloadCheck; + + // Run the pass or pass pipeline on the current target operation. + if (failed(pm.run(target))) { + auto diag = emitSilenceableError() << "pass pipeline failed"; + diag.attachNote(target->getLoc()) << "target op"; + return diag; + } } - results.push_back(target); + // The applied pass will have directly modified the payload IR(s). + results.set(llvm::cast(getResult()), targets); return DiagnosedSilenceableFailure::success(); } +static ParseResult parseApplyRegisteredPassOptions( + OpAsmParser &parser, ArrayAttr &options, + SmallVectorImpl &dynamicOptions) { + auto dynamicOptionMarker = UnitAttr::get(parser.getContext()); + SmallVector optionsArray; + + auto parseOperandOrString = [&]() -> OptionalParseResult { + OpAsmParser::UnresolvedOperand operand; + OptionalParseResult parsedOperand = parser.parseOptionalOperand(operand); + if (parsedOperand.has_value()) { + if (failed(parsedOperand.value())) + return failure(); + + dynamicOptions.push_back(operand); + optionsArray.push_back( + dynamicOptionMarker); // Placeholder for knowing where to + // inject the dynamic option-as-param. + return success(); + } + + StringAttr stringAttr; + OptionalParseResult parsedStringAttr = + parser.parseOptionalAttribute(stringAttr); + if (parsedStringAttr.has_value()) { + if (failed(parsedStringAttr.value())) + return failure(); + optionsArray.push_back(stringAttr); + return success(); + } + + return std::nullopt; + }; + + OptionalParseResult parsedOptionsElement = parseOperandOrString(); + while (parsedOptionsElement.has_value()) { + if (failed(parsedOptionsElement.value())) + return failure(); + parsedOptionsElement = parseOperandOrString(); + } + + if (optionsArray.empty()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected at least one option (either a string or a param)"; + } + options = parser.getBuilder().getArrayAttr(optionsArray); + return success(); +} + +static void printApplyRegisteredPassOptions(OpAsmPrinter &printer, + Operation *op, ArrayAttr options, + ValueRange dynamicOptions) { + size_t currentDynamicOptionIdx = 0; + for (auto [idx, optionAttr] : llvm::enumerate(options)) { + if (idx > 0) + printer << " "; // Interleave options separator. + + if (isa(optionAttr)) + printer.printOperand(dynamicOptions[currentDynamicOptionIdx++]); + else if (auto strAttr = dyn_cast(optionAttr)) + printer.printAttribute(strAttr); + else + llvm_unreachable("each option should be either a StringAttr or UnitAttr"); + } +} + +LogicalResult transform::ApplyRegisteredPassOp::verify() { + size_t numUnitsInOptions = 0; + for (Attribute optionsElement : getOptions()) { + if (isa(optionsElement)) + numUnitsInOptions++; + else if (!isa(optionsElement)) + return emitOpError() << "expected each option to be either a StringAttr " + << "or a UnitAttr, got " << optionsElement; + } + + if (getDynamicOptions().size() != numUnitsInOptions) + return emitOpError() + << "expected the same number of options passed as params as " + << "UnitAttr elements in options ArrayAttr"; + + return success(); +} + //===----------------------------------------------------------------------===// // CastOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir index 3a40b462b8270..463fd98afa65c 100644 --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -79,7 +79,9 @@ module attributes {transform.with_named_sequence} { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}} // expected-error @below {{: no such option invalid-option}} - transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" + with options = "invalid-option=1" to %1 + : (!transform.any_op) -> !transform.any_op transform.yield } } @@ -94,7 +96,136 @@ func.func @valid_pass_option() { module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg1: !transform.any_op) { %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" + with options = "top-down=false" to %1 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @valid_pass_options() +func.func @valid_pass_options() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + //transform.apply_registered_pass "canonicalize" with options = "top-down=false,max-iterations=10" to %1 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" + with options = "top-down=false test-convergence=true" to %1 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @valid_pass_options_as_list() +func.func @valid_pass_options_as_list() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "canonicalize" + with options = "top-down=false" "max-iterations=0" to %1 + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +// CHECK-LABEL: func @valid_dynamic_pass_options() +func.func @valid_dynamic_pass_options() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param + %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param + %2 = transform.apply_registered_pass "canonicalize" + with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_dynamic_options_as_array() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param + // expected-error @+2 {{expected at least one option (either a string or a param)}} + %2 = transform.apply_registered_pass "canonicalize" + with options = ["top-down=false" %max_iter] to %1 + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_options_as_pairs() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + // expected-error @+2 {{expected 'to'}} + %2 = transform.apply_registered_pass "canonicalize" + with options = "top-down=" false to %1 + : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @invalid_pass_option_param() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %pass_options = transform.param.constant 42 -> !transform.any_param + // expected-error @below {{options passed as a param must be a string, got 42}} + transform.apply_registered_pass "canonicalize" + with options = %pass_options to %1 + : (!transform.any_param, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @too_many_pass_option_params() { + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op) { + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %x = transform.param.constant "x" -> !transform.any_param + %y = transform.param.constant "y" -> !transform.any_param + %pass_options = transform.merge_handles %x, %y : !transform.any_param + // expected-error @below {{options passed as a param must have a single value associated, param 0 associates 2}} + transform.apply_registered_pass "canonicalize" + with options = %pass_options to %1 + : (!transform.any_param, !transform.any_op) -> !transform.any_op transform.yield } }