Skip to content

[MLIR][Transform] apply_registered_pass op's options as a dict #143159

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 11, 2025

Conversation

rolfmorel
Copy link
Contributor

@rolfmorel rolfmorel commented Jun 6, 2025

Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the addToPipeline-pass API.

@llvmbot
Copy link
Member

llvmbot commented Jun 6, 2025

@llvm/pr-subscribers-mlir

Author: Rolf Morel (rolfmorel)

Changes

In particular, use similar syntax for providing options as in the (pretty-)printed IR.


Full diff: https://github.com/llvm/llvm-project/pull/143159.diff

2 Files Affected:

  • (modified) mlir/python/mlir/dialects/transform/init.py (+35)
  • (modified) mlir/test/python/dialects/transform.py (+36)
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
         super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+    def __init__(
+        self,
+        result: Type,
+        pass_name: Union[str, StringAttr],
+        target: Value,
+        *,
+        options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+        loc=None,
+        ip=None,
+    ):
+        static_options = []
+        dynamic_options = []
+        for opt in options:
+            if isinstance(opt, str):
+                static_options.append(StringAttr.get(opt))
+            elif isinstance(opt, StringAttr):
+                static_options.append(opt)
+            elif isinstance(opt, Value):
+                static_options.append(UnitAttr.get())
+                dynamic_options.append(_get_op_result_or_value(opt))
+            else:
+                raise TypeError(f"Unsupported option type: {type(opt)}")
+        super().__init__(
+            result,
+            pass_name,
+            dynamic_options,
+            target=_get_op_result_or_value(target),
+            options=static_options,
+            loc=loc,
+            ip=ip,
+        )
+
+
 AnyOpTypeT = NewType("AnyOpType", AnyOpType)
 
 
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
     # CHECK: %[[FIRST:.+]] = pdl_match
     # CHECK: %[[SECOND:.+]] = pdl_match
     # CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+    sequence = transform.SequenceOp(
+        transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+    )
+    with InsertionPoint(sequence.body):
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+        )
+        mod = transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+        )
+        max_iter = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+        )
+        max_rewrites = transform.param_constant(
+            transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+        )
+        transform.ApplyRegisteredPassOp(
+            transform.AnyOpType.get(),
+            "canonicalize",
+            mod,
+            options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+        )
+        transform.YieldOp()
+    # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+    # CHECK: transform.sequence
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+    # CHECK:   %[[MAX_ITER:.+]] = transform.param.constant
+    # CHECK:   %[[MAX_REWRITE:.+]] = transform.param.constant
+    # CHECK:   %{{.*}} = apply_registered_pass "canonicalize"
+    # CHECK-SAME:    with options = "top-down=false" %[[MAX_ITER]]
+    # CHECK-SAME:   "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op

Copy link

github-actions bot commented Jun 6, 2025

✅ With the latest revision this PR passed the Python code formatter.

@rolfmorel
Copy link
Contributor Author

Context: #142683

transform.AnyOpType.get(),
"canonicalize",
mod,
options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather works toward a dictionary here that would make it Python-friendly, but I see the actual ops allows for "max-iterations=10" style of parameter... Though even for the op itself, it may be wise separating the pass parameter name (which is a literal) from the value it takes (which may be a constant/value or also a literal).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i have this "widget":

def add_pass(self, pass_name, **kwargs):
    kwargs = {
        k.replace("_", "-"): int(v) if isinstance(v, bool) else v
        for k, v in kwargs.items()
        if v is not None
    }
    if kwargs:
        args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())

string interpolation of python values does the right thing for the kinds of args seen in passes (ints, strings, lists, etc) except for bools True/False, which is handled by int(True)/int(False) -> 0/1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did a major overhaul. Most notably values can now be passed via params (without needing to be strings nor needing to include key=).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overhaul addresses this point.

Will now merge with the current mechanism for allowing a reference to a param from the options dict. Can iterate in-tree on that design if someone has a better suggestion.

@rolfmorel rolfmorel changed the title [MLIR][Transform] friendlier Python-bindings apply_registered_pass op [MLIR][Transform] apply_registered_pass op's options as a dict Jun 7, 2025
@rolfmorel
Copy link
Contributor Author

rolfmorel commented Jun 7, 2025

Hi @ftynse, @makslevental, partially based on your comments, I decided to go whole hog on this and improve the op on the C++ side as well. That is, the options are now provided in a dictionary and it is possible to pass the values for the options via params.

I re-purposed this PR (prev. was solely fixing up Python bindings for this op) for this more substantial change. Looking forward to your re-review!

@rolfmorel rolfmorel requested a review from rengolin June 7, 2025 20:38
Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 to dict, looks neat
Minor comments

@rolfmorel rolfmorel merged commit fe7bf4b into llvm:main Jun 11, 2025
7 checks passed
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…143159)

Improve ApplyRegisteredPassOp's support for taking options by taking
them as a dict (vs a list of string-valued key-value pairs).

Values of options are provided as either static attributes or as params
(which pass in attributes at interpreter runtime). In either case, the
keys and value attributes are converted to strings and a single
options-string, in the format used on the commandline, is constructed to
pass to the `addToPipeline`-pass API.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants