Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transform] Add a transform.match.operation_empty op to allow s… #68319

Merged

Conversation

nicolasvasilache
Copy link
Contributor

…pecifying negative conditions

In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.

@llvmbot llvmbot added the mlir label Oct 5, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 5, 2023

@llvm/pr-subscribers-mlir

Changes

…pecifying negative conditions

In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.


Full diff: https://github.com/llvm/llvm-project/pull/68319.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h (+99-39)
  • (modified) mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td (+19-2)
  • (modified) mlir/include/mlir/Dialect/Transform/IR/TransformOps.td (+24-3)
  • (modified) mlir/lib/Dialect/Transform/IR/TransformOps.cpp (+20)
  • (modified) mlir/test/Dialect/Transform/test-interpreter.mlir (+21)
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
index c8888f294f6ca1d..2cf008a911bd644 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.h
@@ -11,14 +11,46 @@
 
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/OpDefinition.h"
+#include "llvm/ADT/STLExtras.h"
+#include <functional>
+#include <optional>
+#include <type_traits>
 
 namespace mlir {
 namespace transform {
 class MatchOpInterface;
 
+namespace detail {
 template <typename OpTy>
-class SingleOpMatcherOpTrait
-    : public OpTrait::TraitBase<OpTy, SingleOpMatcherOpTrait> {
+DiagnosedSilenceableFailure matchOptionalOperationImpl(
+    OpTy op, TransformResults &results, TransformState &state, std::false_type) {
+  return op.matchOperation(std::nullopt, results, state);
+}
+
+template <typename OpTy>
+DiagnosedSilenceableFailure
+matchOptionalOperationImpl(OpTy op, TransformResults &results,
+                           TransformState &state, std::true_type) {
+  return op.matchOperation(nullptr, results, state);
+}
+
+/// Dispatch `matchOperation` based on Operation* or std::optional<Operation*>
+/// first operand.
+template <typename OpTy, typename... Args>
+DiagnosedSilenceableFailure
+matchOptionalOperation(OpTy op, TransformResults &results,
+                           TransformState &state) {
+  using uses_operation_ptr_t =
+      typename std::is_same <
+        typename llvm::function_traits<decltype(&OpTy::matchOperation)>::template arg_t<0>,
+        Operation*>;
+  return matchOptionalOperationImpl(op, results, state, uses_operation_ptr_t{});
+}
+} // namespace detail
+
+template <typename OpTy>
+class AtMostOneOpMatcherOpTrait
+    : public OpTrait::TraitBase<OpTy, AtMostOneOpMatcherOpTrait> {
   template <typename T>
   using has_get_operand_handle =
       decltype(std::declval<T &>().getOperandHandle());
@@ -30,20 +62,22 @@ class SingleOpMatcherOpTrait
 public:
   static LogicalResult verifyTrait(Operation *op) {
     static_assert(llvm::is_detected<has_get_operand_handle, OpTy>::value,
-                  "SingleOpMatcherOpTrait expects operation type to have the "
+                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expects operation type to have the "
                   "getOperandHandle() method");
     static_assert(llvm::is_detected<has_match_operation, OpTy>::value,
-                  "SingleOpMatcherOpTrait expected operation type to have the "
+                  "AtMostOneOpMatcherOpTrait/SingleOpMatcherOpTrait expected "
+                  "operation type to have the "
                   "matchOperation(Operation *, TransformResults &, "
                   "TransformState &) method");
 
     // This must be a dynamic assert because interface registration is dynamic.
     assert(isa<MatchOpInterface>(op) &&
-           "SingleOpMatchOpTrait is only available on operations with "
+           "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait is only available on "
+           "operations with "
            "MatchOpInterface");
     Value operandHandle = cast<OpTy>(op).getOperandHandle();
     if (!isa<TransformHandleTypeInterface>(operandHandle.getType())) {
-      return op->emitError() << "SingleOpMatchOpTrait requires the op handle "
+      return op->emitError() << "AtMostOneOpMatcherOpTrait/SingleOpMatchOpTrait requires the op handle "
                                 "to be of TransformHandleTypeInterface";
     }
 
@@ -55,12 +89,16 @@ class SingleOpMatcherOpTrait
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
     auto payload = state.getPayloadOps(operandHandle);
-    if (!llvm::hasSingleElement(payload)) {
+    if (!payload.empty() && !llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleOpMatchOpTrait requires the operand handle to point to "
-                "a single payload op";
+             << "AtMostOneOpMatcherOpTrait requires the operand handle to "
+                "point to "
+                "at most one payload op";
+    }
+    if (payload.empty()) {
+      return detail::matchOptionalOperation(cast<OpTy>(this->getOperation()), results,
+                                    state);
     }
-
     return cast<OpTy>(this->getOperation())
         .matchOperation(*payload.begin(), results, state);
   }
@@ -73,46 +111,68 @@ class SingleOpMatcherOpTrait
 };
 
 template <typename OpTy>
-class SingleValueMatcherOpTrait
-    : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
-public:
-  static LogicalResult verifyTrait(Operation *op) {
-    // This must be a dynamic assert because interface registration is dynamic.
-    assert(isa<MatchOpInterface>(op) &&
-           "SingleValueMatchOpTrait is only available on operations with "
-           "MatchOpInterface");
-
-    Value operandHandle = cast<OpTy>(op).getOperandHandle();
-    if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
-      return op->emitError() << "SingleValueMatchOpTrait requires an operand "
-                                "of TransformValueHandleTypeInterface";
-    }
-
-    return success();
-  }
+class SingleOpMatcherOpTrait
+    : public AtMostOneOpMatcherOpTrait<OpTy> {
 
+  public:
   DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
                                     TransformResults &results,
                                     TransformState &state) {
     Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
-    auto payload = state.getPayloadValues(operandHandle);
+    auto payload = state.getPayloadOps(operandHandle);
     if (!llvm::hasSingleElement(payload)) {
       return emitDefiniteFailure(this->getOperation()->getLoc())
-             << "SingleValueMatchOpTrait requires the value handle to point to "
-                "a single payload value";
+             << "SingleOpMatchOpTrait requires the operand handle to point to "
+                "a single payload op";
     }
-
-    return cast<OpTy>(this->getOperation())
-        .matchValue(*payload.begin(), results, state);
-  }
-
-  void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
-    onlyReadsHandle(this->getOperation()->getOperands(), effects);
-    producesHandle(this->getOperation()->getResults(), effects);
-    onlyReadsPayload(effects);
+    return static_cast<AtMostOneOpMatcherOpTrait<OpTy> *>(this)->apply(
+        rewriter, results, state);
   }
 };
 
+template <typename OpTy>
+  class SingleValueMatcherOpTrait
+      : public OpTrait::TraitBase<OpTy, SingleValueMatcherOpTrait> {
+  public:
+    static LogicalResult verifyTrait(Operation *op) {
+      // This must be a dynamic assert because interface registration is
+      // dynamic.
+      assert(isa<MatchOpInterface>(op) &&
+             "SingleValueMatchOpTrait is only available on operations with "
+             "MatchOpInterface");
+
+      Value operandHandle = cast<OpTy>(op).getOperandHandle();
+      if (!isa<TransformValueHandleTypeInterface>(operandHandle.getType())) {
+        return op->emitError() << "SingleValueMatchOpTrait requires an operand "
+                                  "of TransformValueHandleTypeInterface";
+      }
+
+      return success();
+    }
+
+    DiagnosedSilenceableFailure apply(TransformRewriter &rewriter,
+                                      TransformResults &results,
+                                      TransformState &state) {
+      Value operandHandle = cast<OpTy>(this->getOperation()).getOperandHandle();
+      auto payload = state.getPayloadValues(operandHandle);
+      if (!llvm::hasSingleElement(payload)) {
+        return emitDefiniteFailure(this->getOperation()->getLoc())
+               << "SingleValueMatchOpTrait requires the value handle to point "
+                  "to "
+                  "a single payload value";
+      }
+
+      return cast<OpTy>(this->getOperation())
+          .matchValue(*payload.begin(), results, state);
+    }
+
+    void getEffects(SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+      onlyReadsHandle(this->getOperation()->getOperands(), effects);
+      producesHandle(this->getOperation()->getResults(), effects);
+      onlyReadsPayload(effects);
+    }
+  };
+
 } // namespace transform
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
index 1f81fd5252eb45b..be92e4d91b42b32 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/MatchInterfaces.td
@@ -14,11 +14,28 @@ def MatchOpInterface
   let cppNamespace = "::mlir::transform";
 }
 
+// Trait for "matcher" transform operations that apply to an operation handle
+// associated with at most one payload operation. Checks that it is indeed
+// the case and produces a definite failure when it is not. The matching logic
+// is implemented in the `matchOperation` function instead of `apply`. The op
+// with this trait must provide a `Value getOperandHandle()` function that
+// returns the handle to be used for matching.
+def AtMostOneOpMatcher : NativeOpTrait<"AtMostOneOpMatcherOpTrait"> {
+  let cppNamespace = "::mlir::transform";
+
+  string extraDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure matchOperation(
+          ::std::optional<::mlir::Operation *> maybeCurrent,
+          ::mlir::transform::TransformResults &results,
+          ::mlir::transform::TransformState &state);
+  }];
+}
+
 // Trait for "matcher" transform operations that apply to an operation handle
 // associated with exactly one payload operation. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchOperation` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
@@ -35,7 +52,7 @@ def SingleOpMatcher : NativeOpTrait<"SingleOpMatcherOpTrait"> {
 // associated with exactly one payload value. Checks that it is indeed
 // the case and produces a definite failure when it is not. The matching logic
 // is implemented in the `matchValue` function instead of `apply`. The op
-// with this trait must provide a `Value getOperandHandle()` function that 
+// with this trait must provide a `Value getOperandHandle()` function that
 // returns the handle to be used for matching.
 def SingleValueMatcher : NativeOpTrait<"SingleValueMatcherOpTrait"> {
   let cppNamespace = "::mlir::transform";
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index ca5c915ef8c2caa..2c6917236d34ddf 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -595,8 +595,9 @@ def GetDefiningOp : TransformDialectOp<"get_defining_op",
 
 def GetParentOp : TransformDialectOp<"get_parent_op",
     [DeclareOpInterfaceMethods<TransformOpInterface>,
+     MatchOpInterface,
      NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
-  let summary = "Gets handles to the closest isolated-from-above parents";
+  let summary = "Gets handles to the closest parent ops";
   let description = [{
     The handle defined by this Transform op corresponds to the parents of the
     targeted payload ops (in the same order).
@@ -605,6 +606,9 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     that case for each target op, the closest parent op that fulfills all
     requirements, is returned.
     - `isolated_from_above`: the parent op must be isolated from above
+    - `allow_empty_results`: get_parent_op is allowed to return an empty list and
+      still succeeds. In such a case, if get_parent_op fails for any operation
+      in the list, the entire transform returns an empty handle.
     - `op_name`: the parent op must have the specified name
 
     If `deduplicate` is set, the result handle does not contain any duplicate
@@ -614,12 +618,14 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
     is applied, e.g., "B" may itself be a parent of "A". This may have an impact
     on the further transformation applied to the handle produced here.
 
-    If any of the given Payload IR ops has no such suitable parent, the
-    transformation fails silently.
+    If any of the given Payload IR ops has no such suitable parent, then:
+      - if `allow_empty_results` is set, the result handle is empty
+      - otherwise, the transformation fails silently.
   }];
 
   let arguments = (ins TransformHandleTypeInterface:$target,
                        UnitAttr:$isolated_from_above,
+                       UnitAttr:$allow_empty_results,
                        OptionalAttr<StrAttr>:$op_name,
                        UnitAttr:$deduplicate);
   let results = (outs TransformHandleTypeInterface:$parent);
@@ -739,6 +745,21 @@ def IncludeOp : TransformDialectOp<"include",
   }];
 }
 
+def MatchOperationEmptyOp : Op<Transform_Dialect, "match.operation_empty", [
+    AtMostOneOpMatcher,
+    MatchOpInterface,
+    MemoryEffectsOpInterface]> {
+  let summary =
+    "Matches if the handle is not associated to any op";
+  let description = [{
+    Succeeds if the handle is not associated to any op.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$operand_handle);
+  let assemblyFormat =
+      "$operand_handle attr-dict `:` type($operand_handle)";
+  let extraClassDeclaration = AtMostOneOpMatcher.extraDeclaration;
+}
+
 def MatchOperationNameOp : TransformDialectOp<"match.operation_name",
     [SingleOpMatcher,
      MatchOpInterface,
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 44626260e2f9ef3..2dff1bf3d0a80ef 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/Verifier.h"
+#include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/FunctionImplementation.h"
 #include "mlir/Pass/Pass.h"
@@ -1244,6 +1245,10 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
       parent = parent->getParentOp();
     }
     if (!parent) {
+      if (getAllowEmptyResults()) {
+        results.set(llvm::cast<OpResult>(getResult()), parents);
+        return DiagnosedSilenceableFailure::success();
+      }
       DiagnosedSilenceableFailure diag =
           emitSilenceableError()
           << "could not find a parent op that matches all requirements";
@@ -1545,6 +1550,21 @@ transform::IncludeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
       .checkAndReport();
 }
 
+//===----------------------------------------------------------------------===//
+// MatchOperationEmptyOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation(
+    ::std::optional<::mlir::Operation *> maybeCurrent, 
+    transform::TransformResults &results, transform::TransformState &state) {
+  if (!maybeCurrent.has_value()) {
+    DBGS_MATCHER() << "MatchOperationEmptyOp success\n";
+    return DiagnosedSilenceableFailure::success();
+  }
+  DBGS_MATCHER() << "MatchOperationEmptyOp failure\n";
+  return emitSilenceableError() << "operation is not empty";
+}
+
 //===----------------------------------------------------------------------===//
 // MatchOperationNameOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index daa179cb15408b4..b641b21e876cc42 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -2037,3 +2037,24 @@ transform.sequence failures(propagate) {
   // expected-remark @below{{0}}
   test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
 }
+
+
+// -----
+
+func.func @no_constant_under_loop(%lb: index, %ub: index, %step: index) {
+  scf.for %i= %lb to %ub step %step {
+    // expected-remark @below{{found arith.constant}}
+    arith.constant 0 : index
+  }
+  return
+}
+
+// Match `func.func`s that are not nested under a `func.func` and ensure there are none in the program
+transform.named_sequence @match_func_for_dispatch(%root: !transform.any_op {transform.readonly}) 
+  -> !transform.any_op {
+  transform.match.operation_name %root ["arith.constant"] : !transform.any_op
+  %variant = transform.get_parent_op %root { op_name = "func.func", allow_empty_results }
+    : (!transform.any_op) -> (!transform.any_op)
+  transform.match.operation_empty %variant : !transform.any_op
+  transform.yield %root : !transform.any_op
+}

@github-actions
Copy link

github-actions bot commented Oct 5, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@nicolasvasilache nicolasvasilache force-pushed the flush-simplify-transforms branch 4 times, most recently from 213537f to 2f01dd8 Compare October 5, 2023 15:29
…pecifying negative conditions

In the process, get_parent_op gains an attribute to allow it to return empty handles explicitly and still succeed.
@nicolasvasilache nicolasvasilache force-pushed the flush-simplify-transforms branch from 2f01dd8 to 711d360 Compare October 6, 2023 07:15
@nicolasvasilache nicolasvasilache merged commit 98341df into llvm:main Oct 6, 2023
2 checks passed
@nicolasvasilache nicolasvasilache deleted the flush-simplify-transforms branch October 6, 2023 07:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants