Skip to content

[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param #142683

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 6, 2025

Conversation

rolfmorel
Copy link
Contributor

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply() of the transform op versus for each target
payload given to the op's apply().

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply of the transform op versus for each target
payload given to the op.
@llvmbot
Copy link
Member

llvmbot commented Jun 3, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per apply() of the transform op versus for each target
payload given to the op's apply().


Full diff: https://github.com/llvm/llvm-project/pull/142683.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+10-15)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+99-18)
  • (modified) mlir/test/Dialect/Transform/test-pass-application.mlir (+51-2)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
 }
 
 def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
-    [TransformOpInterface, TransformEachOpTrait,
-     FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+    [DeclareOpInterfaceMethods<TransformOpInterface>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
   let summary = "Applies the specified registered pass or pass pipeline";
   let description = [{
     This transform applies the specified pass or pass pipeline to the targeted
     ops. The name of the pass/pipeline is specified as a string attribute, as
     set during pass/pipeline registration. Optionally, pass options may be
-    specified as a string attribute. The pass options syntax is identical to the
-    one used with "mlir-opt".
+    specified as a string attribute with the option to pass the attribute as a
+    param. The pass options syntax is identical to the one used with "mlir-opt".
 
     This op first looks for a pass pipeline with the specified name. If no such
     pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
     of targeted ops.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$target,
+  let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+                       TransformHandleTypeInterface:$target,
                        StrAttr:$pass_name,
-                       DefaultValuedAttr<StrAttr, "\"\"">:$options);
+                       DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
   let results = (outs TransformHandleTypeInterface:$result);
   let assemblyFormat = [{
-    $pass_name `to` $target attr-dict `:` functional-type(operands, results)
-  }];
-
-  let extraClassDeclaration = [{
-    ::mlir::DiagnosedSilenceableFailure applyToOne(
-      ::mlir::transform::TransformRewriter &rewriter,
-      ::mlir::Operation *target,
-      ::mlir::transform::ApplyToEachResultList &results,
-      ::mlir::transform::TransformState &state);
+    $pass_name (`with` `options` `=`
+      custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+      `to` $target attr-dict `:` functional-type(operands, results)
   }];
 }
 
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
 
 using namespace mlir;
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions);
 static ParseResult parseSequenceOpOperands(
     OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
     Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
 // ApplyRegisteredPassOp
 //===----------------------------------------------------------------------===//
 
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
-    transform::TransformRewriter &rewriter, Operation *target,
-    ApplyToEachResultList &results, transform::TransformState &state) {
-  // Make sure that this transform is not applied to itself. Modifying the
-  // transform IR while it is being interpreted is generally dangerous. Even
-  // more so when applying passes because they may perform a wide range of IR
-  // modifications.
-  DiagnosedSilenceableFailure payloadCheck =
-      ensurePayloadIsSeparateFromTransform(*this, target);
-  if (!payloadCheck.succeeded())
-    return payloadCheck;
+void transform::ApplyRegisteredPassOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicOptionsMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+                                        transform::TransformResults &results,
+                                        transform::TransformState &state) {
+  // Check whether pass options are specified, either as a dynamic param or
+  // a static attribute. In either case, options are passed as a single string.
+  StringRef options;
+  if (auto dynamicOptions = getDynamicOptions()) {
+    ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+    if (dynamicOptionsParam.size() != 1) {
+      return emitSilenceableError()
+             << "options passed as a param must be a single value, got "
+             << dynamicOptionsParam.size();
+    }
+    if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
+      options = optionsStrAttr.getValue();
+    } else {
+      return emitSilenceableError()
+             << "options passed as a param must be a string, got "
+             << dynamicOptionsParam[0];
+    }
+  } else {
+    options = getStaticOptions();
+  }
 
   // Get pass or pass pipeline from registry.
   const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
     return emitDefiniteFailure()
            << "unknown pass or pass pipeline: " << getPassName();
 
-  // Create pass manager and run the pass or pass pipeline.
+  // Create pass manager and add the pass or pass pipeline.
   PassManager pm(getContext());
-  if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+  if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
         emitError(msg);
         return failure();
       }))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
            << "failed to add pass or pass pipeline to pipeline: "
            << getPassName();
   }
-  if (failed(pm.run(target))) {
-    auto diag = emitSilenceableError() << "pass pipeline failed";
-    diag.attachNote(target->getLoc()) << "target op";
-    return diag;
+
+  auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+  for (Operation *target : targets) {
+    // Make sure that this transform is not applied to itself. Modifying the
+    // transform IR while it is being interpreted is generally dangerous. Even
+    // more so when applying passes because they may perform a wide range of IR
+    // modifications.
+    DiagnosedSilenceableFailure payloadCheck =
+        ensurePayloadIsSeparateFromTransform(*this, target);
+    if (!payloadCheck.succeeded())
+      return payloadCheck;
+
+    // Run the pass or pass pipeline on the current target operation.
+    if (failed(pm.run(target))) {
+      auto diag = emitSilenceableError() << "pass pipeline failed";
+      diag.attachNote(target->getLoc()) << "target op";
+      return diag;
+    }
   }
 
-  results.push_back(target);
+  // The applied pass will have directly modified the payload IR(s).
+  results.set(llvm::cast<OpResult>(getResult()), targets);
   return DiagnosedSilenceableFailure::success();
 }
 
+static ParseResult parseApplyRegisteredPassOptions(
+    OpAsmParser &parser,
+    std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+    StringAttr &staticOptions) {
+  dynamicOptions = std::nullopt;
+  OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+  OptionalParseResult hasDynamicOptions =
+      parser.parseOptionalOperand(dynamicOptionsOperand);
+
+  if (hasDynamicOptions.has_value()) {
+    if (failed(hasDynamicOptions.value()))
+      return failure();
+
+    dynamicOptions = dynamicOptionsOperand;
+    return success();
+  }
+
+  OptionalParseResult hasStaticOptions =
+      parser.parseOptionalAttribute(staticOptions);
+  if (hasStaticOptions.has_value()) {
+    if (failed(hasStaticOptions.value()))
+      return failure();
+    return success();
+  }
+
+  return success();
+}
+
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+                                            Operation *op, Value dynamicOptions,
+                                            StringAttr staticOptions) {
+  if (dynamicOptions) {
+    printer.printOperand(dynamicOptions);
+  } else if (!staticOptions.getValue().empty()) {
+    printer.printAttribute(staticOptions);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
     // expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
     // expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
-    transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
     %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
-    transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+// -----
+
+func.func @invalid_pass_option_param() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %pass_options = transform.param.constant 42 -> !transform.any_param
+    // expected-error @below {{options passed as a param must be a string, got 42}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+    transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+    %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %x = transform.param.constant "x" -> !transform.any_param
+    %pass_options = transform.merge_handles %x, %x : !transform.any_param
+    // expected-error @below {{options passed as a param must be a single value, got 2}}
+    transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
     transform.yield
   }
 }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for the concept
Overall LGTM but let's wait for another opinion

@fschlimb
Copy link
Contributor

fschlimb commented Jun 4, 2025

I like the possibility to have be able to provide options as SSA values.

It seems like one might want to mix static and dynamic options. Would it make sense to provide the options as a list rather than a string?

%option1 =...
transform.apply_registered_pass "canonicalize" with options = [%option1, "option2"] to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op

@rolfmorel
Copy link
Contributor Author

rolfmorel commented Jun 4, 2025

Indeed, being able to mix-and-match static arguments with those passed in dynamically - or being able to combine multiple orthogonal dynamic arguments - would be nice!

The suggested syntax of a list/ArrayAttr where elements are either strings, i.e. "option=value" pairs, or can be transform params makes sense to me. With the interpretation that these elements need to be joined by commas spaces to have a single options string to pass to the pass.

Switching to this syntax does break the (documented) property that the options argument is just the string one would pass to the pass on the commandline. On the other hand, as this string had to be statically provided anyway, you could always do the transformation to an array manually. We could keep this option available though: either the options argument is a StringAttr (maybe even coming in as a param) or it is an array (of StringAttr or SSA-values) which will be commaspace-joined.

I will have a go at updating the PR. Thanks @fschlimb!

@rolfmorel
Copy link
Contributor Author

Updated the PR so that the following syntax is accepted (no brackets as it matches the cmdline options more closely - as suggested by @fschlimb offline):

    %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
    %max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
    %2 = transform.apply_registered_pass "canonicalize" with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

The PR is ready to be re-reviewed.

@rolfmorel rolfmorel force-pushed the transform-pass-param branch from 9acadec to 9529ea4 Compare June 6, 2025 09:44
@rolfmorel rolfmorel merged commit 4eeee41 into llvm:main Jun 6, 2025
6 of 7 checks passed
rorth pushed a commit to rorth/llvm-project that referenced this pull request Jun 11, 2025
…ram (llvm#142683)

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per `apply()` of the transform op versus for each target
payload given to the op's `apply()`.
DhruvSrivastavaX pushed a commit to DhruvSrivastavaX/lldb-for-aix that referenced this pull request Jun 12, 2025
…ram (llvm#142683)

Makes it possible to pass around the options to a pass inside a schedule.

The refactoring also makes it so that the pass manager and pass are only
constructed once per `apply()` of the transform op versus for each target
payload given to the op's `apply()`.
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants