-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Transform] Allow ApplyRegisteredPassOp to take options as a param #142683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Makes it possible to pass around the options to a pass inside a schedule. The refactoring also makes it so that the pass manager and pass are only constructed once per apply of the transform op versus for each target payload given to the op.
@llvm/pr-subscribers-mlir Author: Rolf Morel (rolfmorel) ChangesMakes it possible to pass around the options to a pass inside a schedule. The refactoring also makes it so that the pass manager and pass are only Full diff: https://github.com/llvm/llvm-project/pull/142683.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index e4eb67c8e14ce..b042f5e436185 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -399,15 +399,15 @@ def ApplyLoopInvariantCodeMotionOp : TransformDialectOp<"apply_licm",
}
def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
- [TransformOpInterface, TransformEachOpTrait,
- FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface]> {
+ [DeclareOpInterfaceMethods<TransformOpInterface>,
+ DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
let summary = "Applies the specified registered pass or pass pipeline";
let description = [{
This transform applies the specified pass or pass pipeline to the targeted
ops. The name of the pass/pipeline is specified as a string attribute, as
set during pass/pipeline registration. Optionally, pass options may be
- specified as a string attribute. The pass options syntax is identical to the
- one used with "mlir-opt".
+ specified as a string attribute with the option to pass the attribute as a
+ param. The pass options syntax is identical to the one used with "mlir-opt".
This op first looks for a pass pipeline with the specified name. If no such
pipeline exists, it looks for a pass with the specified name. If no such
@@ -420,20 +420,15 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass",
of targeted ops.
}];
- let arguments = (ins TransformHandleTypeInterface:$target,
+ let arguments = (ins Optional<TransformParamTypeInterface>:$dynamic_options,
+ TransformHandleTypeInterface:$target,
StrAttr:$pass_name,
- DefaultValuedAttr<StrAttr, "\"\"">:$options);
+ DefaultValuedAttr<StrAttr, "\"\"">:$static_options);
let results = (outs TransformHandleTypeInterface:$result);
let assemblyFormat = [{
- $pass_name `to` $target attr-dict `:` functional-type(operands, results)
- }];
-
- let extraClassDeclaration = [{
- ::mlir::DiagnosedSilenceableFailure applyToOne(
- ::mlir::transform::TransformRewriter &rewriter,
- ::mlir::Operation *target,
- ::mlir::transform::ApplyToEachResultList &results,
- ::mlir::transform::TransformState &state);
+ $pass_name (`with` `options` `=`
+ custom<ApplyRegisteredPassOptions>($dynamic_options, $static_options)^)?
+ `to` $target attr-dict `:` functional-type(operands, results)
}];
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 673743f22249a..536c3e14fe5c0 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -53,6 +53,13 @@
using namespace mlir;
+static ParseResult parseApplyRegisteredPassOptions(
+ OpAsmParser &parser,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+ StringAttr &staticOptions);
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+ Operation *op, Value dynamicOptions,
+ StringAttr staticOptions);
static ParseResult parseSequenceOpOperands(
OpAsmParser &parser, std::optional<OpAsmParser::UnresolvedOperand> &root,
Type &rootType,
@@ -766,17 +773,38 @@ void transform::ApplyLoopInvariantCodeMotionOp::getEffects(
// ApplyRegisteredPassOp
//===----------------------------------------------------------------------===//
-DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
- transform::TransformRewriter &rewriter, Operation *target,
- ApplyToEachResultList &results, transform::TransformState &state) {
- // Make sure that this transform is not applied to itself. Modifying the
- // transform IR while it is being interpreted is generally dangerous. Even
- // more so when applying passes because they may perform a wide range of IR
- // modifications.
- DiagnosedSilenceableFailure payloadCheck =
- ensurePayloadIsSeparateFromTransform(*this, target);
- if (!payloadCheck.succeeded())
- return payloadCheck;
+void transform::ApplyRegisteredPassOp::getEffects(
+ SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+ consumesHandle(getTargetMutable(), effects);
+ onlyReadsHandle(getDynamicOptionsMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
+}
+
+DiagnosedSilenceableFailure
+transform::ApplyRegisteredPassOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ // Check whether pass options are specified, either as a dynamic param or
+ // a static attribute. In either case, options are passed as a single string.
+ StringRef options;
+ if (auto dynamicOptions = getDynamicOptions()) {
+ ArrayRef<Attribute> dynamicOptionsParam = state.getParams(dynamicOptions);
+ if (dynamicOptionsParam.size() != 1) {
+ return emitSilenceableError()
+ << "options passed as a param must be a single value, got "
+ << dynamicOptionsParam.size();
+ }
+ if (auto optionsStrAttr = dyn_cast<StringAttr>(dynamicOptionsParam[0])) {
+ options = optionsStrAttr.getValue();
+ } else {
+ return emitSilenceableError()
+ << "options passed as a param must be a string, got "
+ << dynamicOptionsParam[0];
+ }
+ } else {
+ options = getStaticOptions();
+ }
// Get pass or pass pipeline from registry.
const PassRegistryEntry *info = PassPipelineInfo::lookup(getPassName());
@@ -786,9 +814,9 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
return emitDefiniteFailure()
<< "unknown pass or pass pipeline: " << getPassName();
- // Create pass manager and run the pass or pass pipeline.
+ // Create pass manager and add the pass or pass pipeline.
PassManager pm(getContext());
- if (failed(info->addToPipeline(pm, getOptions(), [&](const Twine &msg) {
+ if (failed(info->addToPipeline(pm, options, [&](const Twine &msg) {
emitError(msg);
return failure();
}))) {
@@ -796,16 +824,69 @@ DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne(
<< "failed to add pass or pass pipeline to pipeline: "
<< getPassName();
}
- if (failed(pm.run(target))) {
- auto diag = emitSilenceableError() << "pass pipeline failed";
- diag.attachNote(target->getLoc()) << "target op";
- return diag;
+
+ auto targets = SmallVector<Operation *>(state.getPayloadOps(getTarget()));
+ for (Operation *target : targets) {
+ // Make sure that this transform is not applied to itself. Modifying the
+ // transform IR while it is being interpreted is generally dangerous. Even
+ // more so when applying passes because they may perform a wide range of IR
+ // modifications.
+ DiagnosedSilenceableFailure payloadCheck =
+ ensurePayloadIsSeparateFromTransform(*this, target);
+ if (!payloadCheck.succeeded())
+ return payloadCheck;
+
+ // Run the pass or pass pipeline on the current target operation.
+ if (failed(pm.run(target))) {
+ auto diag = emitSilenceableError() << "pass pipeline failed";
+ diag.attachNote(target->getLoc()) << "target op";
+ return diag;
+ }
}
- results.push_back(target);
+ // The applied pass will have directly modified the payload IR(s).
+ results.set(llvm::cast<OpResult>(getResult()), targets);
return DiagnosedSilenceableFailure::success();
}
+static ParseResult parseApplyRegisteredPassOptions(
+ OpAsmParser &parser,
+ std::optional<OpAsmParser::UnresolvedOperand> &dynamicOptions,
+ StringAttr &staticOptions) {
+ dynamicOptions = std::nullopt;
+ OpAsmParser::UnresolvedOperand dynamicOptionsOperand;
+ OptionalParseResult hasDynamicOptions =
+ parser.parseOptionalOperand(dynamicOptionsOperand);
+
+ if (hasDynamicOptions.has_value()) {
+ if (failed(hasDynamicOptions.value()))
+ return failure();
+
+ dynamicOptions = dynamicOptionsOperand;
+ return success();
+ }
+
+ OptionalParseResult hasStaticOptions =
+ parser.parseOptionalAttribute(staticOptions);
+ if (hasStaticOptions.has_value()) {
+ if (failed(hasStaticOptions.value()))
+ return failure();
+ return success();
+ }
+
+ return success();
+}
+
+static void printApplyRegisteredPassOptions(OpAsmPrinter &printer,
+ Operation *op, Value dynamicOptions,
+ StringAttr staticOptions) {
+ if (dynamicOptions) {
+ printer.printOperand(dynamicOptions);
+ } else if (!staticOptions.getValue().empty()) {
+ printer.printAttribute(staticOptions);
+ }
+}
+
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir
index 3a40b462b8270..e8e0f63b28096 100644
--- a/mlir/test/Dialect/Transform/test-pass-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pass-application.mlir
@@ -79,7 +79,7 @@ module attributes {transform.with_named_sequence} {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// expected-error @below {{failed to add pass or pass pipeline to pipeline: canonicalize}}
// expected-error @below {{<Pass-Options-Parser>: no such option invalid-option}}
- transform.apply_registered_pass "canonicalize" to %1 {options = "invalid-option=1"} : (!transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
transform.yield
}
}
@@ -94,7 +94,56 @@ func.func @valid_pass_option() {
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
%1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
- transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "canonicalize" with options = "top-down=false" to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func @valid_dynamic_pass_option()
+func.func @valid_dynamic_pass_option() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %pass_options = transform.param.constant "top-down=false" -> !transform.any_param
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+// -----
+
+func.func @invalid_pass_option_param() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %pass_options = transform.param.constant 42 -> !transform.any_param
+ // expected-error @below {{options passed as a param must be a string, got 42}}
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
+ transform.apply_registered_pass "canonicalize" with options = "invalid-option=1" to %1 : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @too_many_pass_option_params() {
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %x = transform.param.constant "x" -> !transform.any_param
+ %pass_options = transform.merge_handles %x, %x : !transform.any_param
+ // expected-error @below {{options passed as a param must be a single value, got 2}}
+ transform.apply_registered_pass "canonicalize" with options = %pass_options to %1 : (!transform.any_param, !transform.any_op) -> !transform.any_op
transform.yield
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for the concept
Overall LGTM but let's wait for another opinion
I like the possibility to have be able to provide options as SSA values. It seems like one might want to mix static and dynamic options. Would it make sense to provide the options as a list rather than a string?
|
Indeed, being able to mix-and-match static arguments with those passed in dynamically - or being able to combine multiple orthogonal dynamic arguments - would be nice! The suggested syntax of a list/ Switching to this syntax does break the (documented) property that the I will have a go at updating the PR. Thanks @fschlimb! |
Updated the PR so that the following syntax is accepted (no brackets as it matches the cmdline options more closely - as suggested by @fschlimb offline): %max_iter = transform.param.constant "max-iterations=10" -> !transform.any_param
%max_rewrites = transform.param.constant "max-num-rewrites=1" -> !transform.any_param
%2 = transform.apply_registered_pass "canonicalize" with options = "top-down=false" %max_iter "test-convergence=true" %max_rewrites to %1 : (!transform.any_param, !transform.any_param, !transform.any_op) -> !transform.any_op The PR is ready to be re-reviewed. |
9acadec
to
9529ea4
Compare
…ram (llvm#142683) Makes it possible to pass around the options to a pass inside a schedule. The refactoring also makes it so that the pass manager and pass are only constructed once per `apply()` of the transform op versus for each target payload given to the op's `apply()`.
…ram (llvm#142683) Makes it possible to pass around the options to a pass inside a schedule. The refactoring also makes it so that the pass manager and pass are only constructed once per `apply()` of the transform op versus for each target payload given to the op's `apply()`.
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
Makes it possible to pass around the options to a pass inside a schedule.
The refactoring also makes it so that the pass manager and pass are only
constructed once per
apply()
of the transform op versus for each targetpayload given to the op's
apply()
.