-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Transform] apply_registered_pass op's options as a dict #143159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesIn particular, use similar syntax for providing options as in the (pretty-)printed IR. Full diff: https://github.com/llvm/llvm-project/pull/143159.diff 2 Files Affected:
diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py
index 5b158ec6b65fd..cdcdeadd54cd3 100644
--- a/mlir/python/mlir/dialects/transform/__init__.py
+++ b/mlir/python/mlir/dialects/transform/__init__.py
@@ -214,6 +214,41 @@ def __init__(
super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip)
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ApplyRegisteredPassOp(ApplyRegisteredPassOp):
+ def __init__(
+ self,
+ result: Type,
+ pass_name: Union[str, StringAttr],
+ target: Value,
+ *,
+ options: Sequence[Union[str, StringAttr, Value, Operation]] = [],
+ loc=None,
+ ip=None,
+ ):
+ static_options = []
+ dynamic_options = []
+ for opt in options:
+ if isinstance(opt, str):
+ static_options.append(StringAttr.get(opt))
+ elif isinstance(opt, StringAttr):
+ static_options.append(opt)
+ elif isinstance(opt, Value):
+ static_options.append(UnitAttr.get())
+ dynamic_options.append(_get_op_result_or_value(opt))
+ else:
+ raise TypeError(f"Unsupported option type: {type(opt)}")
+ super().__init__(
+ result,
+ pass_name,
+ dynamic_options,
+ target=_get_op_result_or_value(target),
+ options=static_options,
+ loc=loc,
+ ip=ip,
+ )
+
+
AnyOpTypeT = NewType("AnyOpType", AnyOpType)
diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py
index 6ed4818fc9d2f..dc0987e769a09 100644
--- a/mlir/test/python/dialects/transform.py
+++ b/mlir/test/python/dialects/transform.py
@@ -254,3 +254,39 @@ def testReplicateOp(module: Module):
# CHECK: %[[FIRST:.+]] = pdl_match
# CHECK: %[[SECOND:.+]] = pdl_match
# CHECK: %{{.*}} = replicate num(%[[FIRST]]) %[[SECOND]]
+
+
+@run
+def testApplyRegisteredPassOp(module: Module):
+ sequence = transform.SequenceOp(
+ transform.FailurePropagationMode.Propagate, [], transform.AnyOpType.get()
+ )
+ with InsertionPoint(sequence.body):
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", sequence.bodyTarget
+ )
+ mod = transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(), "canonicalize", mod, options=("top-down=false",)
+ )
+ max_iter = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-iterations=10")
+ )
+ max_rewrites = transform.param_constant(
+ transform.AnyParamType.get(), StringAttr.get("max-num-rewrites=1")
+ )
+ transform.ApplyRegisteredPassOp(
+ transform.AnyOpType.get(),
+ "canonicalize",
+ mod,
+ options=("top-down=false", max_iter, "test-convergence=true", max_rewrites),
+ )
+ transform.YieldOp()
+ # CHECK-LABEL: TEST: testApplyRegisteredPassOp
+ # CHECK: transform.sequence
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize" with options = "top-down=false" to {{.*}} : (!transform.any_op) -> !transform.any_op
+ # CHECK: %[[MAX_ITER:.+]] = transform.param.constant
+ # CHECK: %[[MAX_REWRITE:.+]] = transform.param.constant
+ # CHECK: %{{.*}} = apply_registered_pass "canonicalize"
+ # CHECK-SAME: with options = "top-down=false" %[[MAX_ITER]]
+ # CHECK-SAME: "test-convergence=true" %[[MAX_REWRITE]] to %{{.*}} : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op
|
✅ With the latest revision this PR passed the Python code formatter. |
Context: #142683 |
transform.AnyOpType.get(), | ||
"canonicalize", | ||
mod, | ||
options=("top-down=false", max_iter, "test-convergence=true", max_rewrites), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather works toward a dictionary here that would make it Python-friendly, but I see the actual ops allows for "max-iterations=10" style of parameter... Though even for the op itself, it may be wise separating the pass parameter name (which is a literal) from the value it takes (which may be a constant/value or also a literal).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i have this "widget":
def add_pass(self, pass_name, **kwargs):
kwargs = {
k.replace("_", "-"): int(v) if isinstance(v, bool) else v
for k, v in kwargs.items()
if v is not None
}
if kwargs:
args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())
string interpolation of python values does the right thing for the kinds of args seen in passes (ints, strings, lists, etc) except for bools True/False
, which is handled by int(True)/int(False) -> 0/1
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did a major overhaul. Most notably values can now be passed via params (without needing to be strings nor needing to include key=
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The overhaul addresses this point.
Will now merge with the current mechanism for allowing a reference to a param from the options dict. Can iterate in-tree on that design if someone has a better suggestion.
Hi @ftynse, @makslevental, partially based on your comments, I decided to go whole hog on this and improve the op on the C++ side as well. That is, the options are now provided in a dictionary and it is possible to pass the values for the options via params. I re-purposed this PR (prev. was solely fixing up Python bindings for this op) for this more substantial change. Looking forward to your re-review! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 to dict, looks neat
Minor comments
In particular, use similar syntax for providing options as in the (pretty-)printed IR.
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
…143159) Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs). Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the `addToPipeline`-pass API.
Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs).
Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the
addToPipeline
-pass API.