Skip to content

[mlir][transform] Fix handling of transitive include in interpreter. #67560

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conversation

ingomueller-net
Copy link
Contributor

@ingomueller-net ingomueller-net commented Sep 27, 2023

This is a new attempt at fixing transitive includes in the transform dialect interpreter (next to #67241) and a preparation for being able to load multiple transform library files (#67120).

Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.

This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of FunctionOpInterface) with compatible signatures, one is external, and the other one is public, then they are merged; (2) of one of them is private, that one is renamed; and (3) an error is raised otherwise.

TODO:

  • Enable running interpreter several times (see discussion below).
  • Fix several smaller issues marked with XXX.
  • Fix reproducer: since injection of the library is now not idempotent anymore, we should not print the filename argument in the repro call.
  • Revise CLI and function parameter documentation and description.

@ingomueller-net ingomueller-net force-pushed the transform-interpreter-transitive-include branch from 8156631 to 6e83ecd Compare September 27, 2023 14:57
@ingomueller-net
Copy link
Contributor Author

@ftynse: This is a new attempt at #67241 along the lines of your suggestions. Please take a look. It is worth mentioning that (i) I merged the stages (1) and (2) from your suggestion, and (ii) did not implement the optimization of only including needed symbols.

I have identified one more problem that may require a major change: it is currently not always possible to run the interpreter pass several times. The problem is the following: if the script contains a declaration @foo and the library contains a matching definition, then than definition is merge into the transform module upon its first invocation. In the second invocation, however, @foo is now a definition in the transform module, so merging in the library module leads to a doubly defined symbol error. This did not happen before because only the (directly) missing symbols were merged, and @foo is not missing after the first invocation anymore.

I see two ways out: (A) only merge in missing symbols (plus their dependencies) as previously and make all imported symbols private, or (B) separate the merging from the interpreter into a separate pass. I have the feeling that Approach (A) would make the merging again more special to the interpreter and possibly somewhat brittle and Approach (B) would be more generic and possibly more performant. (If the interpreter pass is really used several times, linking has to be done each time, while with a dedicated linking pass, it would only be done once.)

I imagine the linking pass from Approach (B) to be called something like load-transform-dialect-interpreter-module or similar, take the transform-library-file-name and transform-library-file-name arguments (as well as their extensions planned in #67120), and then only do the loading and merging/linking (but not interpret the result). The main interpreter pass would then simply run the interpreter on the input module and not bother with any loading or merging. Also, I think I'd implement that separation first and then rebase this change on top of it.

What do you think?

@ingomueller-net ingomueller-net force-pushed the transform-interpreter-transitive-include branch from 6e83ecd to 984af0f Compare September 27, 2023 15:19
@ftynse
Copy link
Member

ftynse commented Sep 28, 2023

Approach B makes sense to me. Generally, anything that keeps the interpreter simpler and more focused is likely a good idea. I suppose it should be sufficient to just move the symbols from other files,normally, it is allowed to have multiple declarations as long as there is a single definition + handle name conflicts between private symbols that may have repeated names in different modules. Then we can ran symbol-dce separately to remove unused symbols.

@ftynse
Copy link
Member

ftynse commented Sep 28, 2023

Implementation wise, we can consider having a "mergeModules" function that takes a list of modules and moves everything into the first one. It can then be exercised from the pass that loads transform library files, but is also reusable.

@ingomueller-net
Copy link
Contributor Author

Implementation wise, we can consider having a "mergeModules" function that takes a list of modules and moves everything into the first one. It can then be exercised from the pass that loads transform library files, but is also reusable.

Yeah, good point! Being able to merge several modules is on my TODO list. That would also be required if we wanted to write a generic merger that could replace most of SPIR-V's ModuleCombiner. But I'll leaver that to that after the remaining pending changes...

@ftynse
Copy link
Member

ftynse commented Sep 28, 2023

Well, merging several modules is a for loop around a function that merges two ;)

@ingomueller-net
Copy link
Contributor Author

Yes, but you can do (slightly) better by being careful when to (re)construct which symbol table(s)...

@ingomueller-net
Copy link
Contributor Author

I think I changed my mind about the question of whether or not loading should be part of the interpreter. At the very least, I see that it is absolutely consistent to think of the current pass as "inject dependencies and interpret result." In this case, it is clear that running that pass twice in a way that tries to inject the dependencies twice does not work. The PR thus currently removes that test. If we decide that injection should be separated from the interpreter or that injection should be done only if necessary (such that it is idempotent), I believe we can do that in some future PR.

Other than that, I have overhauled the initial commit and think that it is ready for review now.

@ingomueller-net ingomueller-net marked this pull request as ready for review September 29, 2023 10:09
@llvmbot llvmbot added the mlir label Sep 29, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2023

@llvm/pr-subscribers-mlir

Changes

This is a new attempt at fixing transitive includes in the transform dialect interpreter (next to #67241) and a preparation for being able to load multiple transform library files (#67120).

Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.

This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of FunctionOpInterface) with compatible signatures, one is external, and the other one is public, then they are merged; (2) of one of them is private, that one is renamed; and (3) an error is raised otherwise.

TODO:

  • Enable running interpreter several times (see discussion below).
  • Fix several smaller issues marked with XXX.

Patch is 27.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67560.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+252-60)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir (+6)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir (+19-4)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir (+49-12)
  • (added) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir (+14)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir (+42-1)
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..3dd172335160fb3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -50,6 +50,8 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
 constexpr static llvm::StringLiteral
     kTransformDialectTagTransformContainerValue = "transform_container";
 
+namespace {
+
 /// Utility to parse the content of a `transformFileName` MLIR file containing
 /// a transform dialect specification.
 static LogicalResult
@@ -302,80 +304,270 @@ static void performOptionalDebugActions(
     transform->removeAttr(kTransformDialectTagAttrName);
 }
 
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
-  MLIRContext &ctx = *definitions->getContext();
-  auto consumedName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
-  auto readOnlyName =
-      StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
-  for (Operation &op : llvm::make_early_inc_range(block)) {
-    LLVM_DEBUG(DBGS() << op << "\n");
-    auto symbol = dyn_cast<SymbolOpInterface>(op);
-    if (!symbol)
-      continue;
-    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
-      continue;
-
-    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
-                      << symbol.getNameAttr() << ":");
-    SymbolTable symbolTable(definitions);
-    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
-    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
-        externalSymbol->getRegion(0).empty()) {
-      LLVM_DEBUG(llvm::dbgs() << "not found\n");
-      continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+                             SymbolTable &symbolTable,
+                             SymbolTable &otherSymbolTable, int &uniqueId) {
+  assert(symbolTable.lookup(op.getNameAttr()) == op &&
+         "symbol table does not contain op");
+  assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+         "other symbol table does not contain other op");
+
+  // Determine new name that is unique in both symbol tables.
+  StringAttr oldName = op.getNameAttr();
+  StringAttr newName;
+  {
+    MLIRContext *context = op->getContext();
+    SmallString<64> prefix = oldName.getValue();
+    prefix.push_back('_');
+    while (true) {
+      newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+      if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+        break;
+      }
     }
+  }
+
+  // Apply renaming.
+  LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+  Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+  if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+    InFlightDiagnostic diag =
+        emitError(op->getLoc(),
+                  Twine("failed to rename symbol to @") + newName.getValue());
+    diag.attachNote(otherOp->getLoc())
+        << "attempted renaming due to collision with this op";
+    return diag;
+  }
+
+  // Change the symbol in the op itself and update the symbol table.
+  symbolTable.remove(op);
+  SymbolTable::setSymbolName(op, newName);
+  symbolTable.insert(op);
+
+  assert(symbolTable.lookup(newName) == op &&
+         "symbol table does not resolve to renamed op");
+  assert(symbolTable.lookup(oldName) == nullptr &&
+         "symbol table still resolves old name");
+
+  return success();
+}
 
-    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
-    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
-    if (!symbolFunc || !externalSymbolFunc) {
-      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+  assert(canMergeInto(func1, func2));
+  assert(func1->getParentOp() == func2->getParentOp() &&
+         "expected func1 and func2 to be in the same parent op");
+
+  MLIRContext *context = func1->getContext();
+  auto consumedName = StringAttr::get(
+      context, transform::TransformDialect::kArgConsumedAttrName);
+  auto readOnlyName = StringAttr::get(
+      context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+  // Check that function signatures match.
+  if (func1.getFunctionType() != func2.getFunctionType()) {
+    return func1.emitError()
+           << "external definition has a mismatching signature ("
+           << func2.getFunctionType() << ")";
+  }
+
+  // Check and merge argument attributes.
+  for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+    bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+    bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+    bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+    bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+    if (!isExternalConsumed && !isExternalReadonly) {
+      if (isConsumed)
+        func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+      else if (isReadonly)
+        func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
       continue;
     }
 
-    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
-    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
-      return symbolFunc.emitError()
-             << "external definition has a mismatching signature ("
-             << externalSymbolFunc.getFunctionType() << ")";
+    if ((isExternalConsumed && !isConsumed) ||
+        (isExternalReadonly && !isReadonly)) {
+      return func1.emitError()
+             << "external definition has mismatching consumption "
+                "annotations for argument #"
+             << i;
     }
+  }
+
+  // `func1` is the external one, so we can remove it.
+  assert(func1.isExternal());
+  func1->erase();
+
+  return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op preserved (or any one
+/// of the two if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+                                      OwningOpRef<Operation *> other) {
+  assert(target->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+  assert(other->hasTrait<OpTrait::SymbolTable>() &&
+         "requires target to implement the 'SymbolTable' trait");
+
+  SymbolTable targetSymbolTable(target);
+  SymbolTable otherSymbolTable(*other);
+
+  int uniqueId = 0;
+
+  // Step 1:
+  //
+  // Rename private symbols in both ops in order to resolve conflicts that can
+  // be resolved that way.
+  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+  for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+           SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+           SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
+    Operation *symbolTableOp = symbolTable->getOp();
+    for (Operation &op : symbolTableOp->getRegion(0).front()) {
+      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+      if (!symbolOp)
+        continue;
+      StringAttr name = symbolOp.getNameAttr();
+      LLVM_DEBUG(DBGS() << "  found @" << name.getValue() << "\n");
 
-    for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
-      bool isExternalConsumed =
-          externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isExternalReadonly =
-          externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
-      bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
-      if (!isExternalConsumed && !isExternalReadonly) {
-        if (isConsumed)
-          externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
-        else if (isReadonly)
-          externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+      // Check if there is a colliding op in the other module.
+      auto collidingOp =
+          cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+      if (!collidingOp)
         continue;
+
+      LLVM_DEBUG(DBGS() << "    collision found for @" << name.getValue());
+
+      // Collisions are fine if both opt are functions and can be merged.
+      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+          collidingFuncOp =
+              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+          funcOp && collidingFuncOp) {
+        if (canMergeInto(funcOp, collidingFuncOp) ||
+            canMergeInto(collidingFuncOp, funcOp)) {
+          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+                                     "will be merged\n");
+          continue;
+        }
+
+        // If they can't be merged, proceed like any other collision.
+        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
       }
 
-      if ((isExternalConsumed && !isConsumed) ||
-          (isExternalReadonly && !isReadonly)) {
-        return symbolFunc.emitError()
-               << "external definition has mismatching consumption annotations "
-                  "for argument #"
-               << i;
+      // Collision can be resolved if one of the ops is private.
+      if (symbolOp.isPrivate()) {
+        if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+                                  *otherSymbolTable, uniqueId)))
+          return failure();
+        continue;
+      }
+      if (collidingOp.isPrivate()) {
+        if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+                                  *symbolTable, uniqueId)))
+          return failure();
+        continue;
       }
+
+      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+      InFlightDiagnostic diag =
+          emitError(symbolOp->getLoc(),
+                    Twine("doubly defined symbol @") + name.getValue());
+      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+      return diag;
     }
+  }
 
-    OpBuilder builder(&op);
-    builder.setInsertionPoint(&op);
-    builder.clone(*externalSymbol);
-    symbol->erase();
+  for (auto *op : SmallVector<Operation *>{target, *other}) {
+    if (failed(mlir::verify(op)))
+      return emitError(op->getLoc(),
+                       "failed to verify input op after renaming");
   }
 
+  // Step 2:
+  //
+  // Move all ops from `other` into target and merge public symbols.
+  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+  {
+    SmallVector<SymbolOpInterface> opsToMove;
+    for (Operation &op : other->getRegion(0).front()) {
+      if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+        opsToMove.push_back(symbol);
+    }
+
+    for (SymbolOpInterface op : opsToMove) {
+      // Remember potentially colliding op in the target module.
+      auto collidingOp = cast_or_null<SymbolOpInterface>(
+          targetSymbolTable.lookup(op.getNameAttr()));
+
+      // Move op even if we get a collision.
+      LLVM_DEBUG(DBGS() << "  moving @" << op.getName());
+      op->moveAfter(&target->getRegion(0).front(),
+                    target->getRegion(0).front().begin());
+
+      // If there is no collision, we are done.
+      if (!collidingOp) {
+        LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+        continue;
+      }
+
+      // The two colliding ops must both be functions because we have already
+      // emitted errors otherwise earlier.
+      auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+      auto collidingFuncOp =
+          cast<FunctionOpInterface>(collidingOp.getOperation());
+
+      // Both ops are in the target module now and can be treated symmetrically,
+      // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+      if (!canMergeInto(funcOp, collidingFuncOp)) {
+        std::swap(funcOp, collidingFuncOp);
+      }
+      assert(canMergeInto(funcOp, collidingFuncOp));
+
+      LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+                              << collidingFuncOp.getLoc() << ":\n"
+                              << collidingFuncOp << "\n");
+
+      // Update symbol table. This works with or without the previous `swap`.
+      targetSymbolTable.remove(funcOp);
+      targetSymbolTable.insert(collidingFuncOp);
+      assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+      // Do the actual merging.
+      if (failed(mergeInto(funcOp, collidingFuncOp))) {
+        return failure();
+      }
+
+      assert(succeeded(mlir::verify(target)));
+    }
+  }
+
+  if (failed(mlir::verify(target)))
+    return emitError(target->getLoc(),
+                     "failed to verify target op after merging symbols");
+
+  LLVM_DEBUG(DBGS() << "done merging ops\n");
   return success();
 }
 
+} // namespace
+
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -438,8 +630,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       diag.attachNote(target->getLoc()) << "pass anchor op";
       return diag;
     }
-    if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     libraryModule->get())))
+    if (failed(
+            mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
+                             libraryModule->get()->clone())))
       return failure();
   }
 
@@ -499,8 +692,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     return success();
 
   if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+    if (failed(mergeSymbolsInto(module->get(), std::move(parsedLibrary))))
       return failure();
   } else {
     libraryModule =
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..dd8d141e994da0e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -11,4 +11,10 @@
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
 module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..7452deb39b6c18d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,16 +1,16 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file
 
-// The definition of the @foo named sequence is provided in another file. It
+// The definition of the @print_message named sequence is provided in another file. It
 // will be included because of the pass option.
 
 module attributes {transform.with_named_sequence} {
   // expected-error @below {{external definition has a mismatching signature}}
-  transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
+  transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.op<"builtin.module">):
-    include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+    include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
   }
 }
 
@@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} {
     include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
   }
 }
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{doubly defined symbol @print_message}}
+  transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..a9083fe3e70788a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,29 +4,66 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
-
-// The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Subsequent application
+// of the same pass works but only without the library file (since the first
+// application loads external symbols and loading them again woul make them
+// clash).
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
+// ...
[truncated]

}

// Apply renaming.
LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: looks like the debug line starts with a comma for no reason.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This terminated a debug line that was started earlier but that wasn't visible in this context. I believe that after the refactoring of renameToUnique, the context is now much smaller such that the intend should be understood. Let me know if you disagree.

LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition is unclear to me. Why can we merge an external function into another external function? Neither has a body... Why should the other function be public? Is this about merging declarations?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this logic allows merging func1 into func2 if the former is a declaration and the latter is a public definition or it is a declaration as well. I added a comment along those lines. Let me know if anything else is needed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more fitting name could be beneficial, but I can't come up with one immediately. canMergeDeclarationInto?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that could work. (As you probably also felt, it's not perfect: the name suggest that it has to be a declaration but it actually checks if it is a declaration.)

One argument in favor of keeping it as is is that this is a formulation that generalizes: In the SPIR-V module combiner, they "merge" global variables and constants, and those are "mergeable" if they define the same thing. This generalization could be formalized into a MergeableSymbol interface (which extends Symbol) with two methods canMergeInto and mergeInto. At least for these two examples (SPIR-V and TD), that concept seems to work...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of late to the party, but I agree that the meaning is difficult to grasp currently.
More specifically I believe the comment doesn't describe what mergeable means.

For me the confusing part is that merging implies that the result is bigger. So before reading the code I was wondering if we were doing some sort of function concatenation :).
<digression>LLVM has an optimization pass called merge globals that puts global variables in the same global structure to allow faster symbol resolution and that's the kind of thing this function name brings to mind for me.</digression>

Maybe canCoalesceSymbols?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am open to using a different verb or name. From what I understand from the dictionary, unite, combine, merge, and coalesce all roughly mean "take several things and make a single one of them."

I guess that that is undoubtably what we do with the outer ops (and the result is indeed larger usually). I'd argue that it is close to what we do with the symbols as well but maybe not exactly the same, at least not what we do with functions currently. Another angle to look at what we do with symbols is that we "deduplicate" them. Maybe that can yield a name?

@llvmbot llvmbot added the mlir:core MLIR Core Infrastructure label Oct 4, 2023
Copy link
Contributor Author

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot, @ftynse, for the detailed feedback! I just pushed a bunch of commits that should address your comments (modulo the questions I am submitting here).

}

// Apply renaming.
LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This terminated a debug line that was started earlier but that wasn't visible in this context. I believe that after the refactoring of renameToUnique, the context is now much smaller such that the intend should be understood. Let me know if you disagree.

LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this logic allows merging func1 into func2 if the former is a declaration and the latter is a public definition or it is a declaration as well. I added a comment along those lines. Let me know if anything else is needed.

@ingomueller-net ingomueller-net requested a review from ftynse October 4, 2023 12:05
LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more fitting name could be beneficial, but I can't come up with one immediately. canMergeDeclarationInto?

Copy link
Contributor Author

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! Now all comments should be addressed. I'll merge once CI has passed. Should some of the minor things still bother you/us, then I can still submit a fix-up.

LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that could work. (As you probably also felt, it's not perfect: the name suggest that it has to be a declaration but it actually checks if it is a declaration.)

One argument in favor of keeping it as is is that this is a formulation that generalizes: In the SPIR-V module combiner, they "merge" global variables and constants, and those are "mergeable" if they define the same thing. This generalization could be formalized into a MergeableSymbol interface (which extends Symbol) with two methods canMergeInto and mergeInto. At least for these two examples (SPIR-V and TD), that concept seems to work...

@ingomueller-net ingomueller-net force-pushed the transform-interpreter-transitive-include branch from fd5ff47 to e74aa59 Compare October 5, 2023 14:57
Until now, the interpreter would only load those symbols from the
provided library files that were declared in the main transform module.
However, sequences in the library may include other sequences on their
own. Until now, if such sequences were not *also* declared in the main
transform module, the interpreter would fail to resolve them. Forward
declaring all of them is undesirable as it defeats the purpose of
encapsulation into library modules.

This PR implements a kind of linker for transform scripts to solve this
problem. The linker merges all symbols of the library module into the
main module before interpreting the latter. Symbols whose names collide
are handled as follows: (1) if they are both functions (in the sense of
`FunctionOpInterface`) with compatible signatures, one is external, and
the other one is public, then they are merged; (2) of one of them is
private, that one is renamed; and (3) an error is raised otherwise.
* Move all private functions of the CPP file into anonymous namespace.
* Remove test with second interpreter pass that reloads the library. I
  think that this shouldn't be possible.
* Factor out `renameToUnique`, `canMergeInto`, and `mergeInto` into
  proper functions.
* Use a single symbol table per input op and update it correctly
  whenever symbols or ops change.
* Make `other` arg an `OwningOpRef` and clone the arguments where
  necessary.
* Improve comments.
* Use `moveBefore` instead of `moveAfter` in order to work on empty
  targets as well.
* Do not verify the target after moving each op, since the last op may
  use symbols of ops that still have to be moved.
@ingomueller-net ingomueller-net force-pushed the transform-interpreter-transitive-include branch from e74aa59 to 71b40b1 Compare October 6, 2023 07:39
@ingomueller-net ingomueller-net merged commit 7876899 into llvm:main Oct 6, 2023
@ingomueller-net ingomueller-net deleted the transform-interpreter-transitive-include branch October 6, 2023 08:57
Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I told @ingomueller-net I was going to look at this, so here we are :).

Sorry for the delay. LGTM too, couple of comments inlined.

LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of late to the party, but I agree that the meaning is difficult to grasp currently.
More specifically I believe the comment doesn't describe what mergeable means.

For me the confusing part is that merging implies that the result is bigger. So before reading the code I was wondering if we were doing some sort of function concatenation :).
<digression>LLVM has an optimization pass called merge globals that puts global variables in the same global structure to allow faster symbol resolution and that's the kind of thing this function name brings to mind for me.</digression>

Maybe canCoalesceSymbols?

}
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll add a comment about "metadata" merging rules (i.e., around consumed and readonly)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will or you would? I.e., who makes that commit?

/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment would be easier to read if we break it down into several sentences (or bullet points?).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, sounds good. Let's have the person do it that adds the comment you suggested above.

@@ -0,0 +1,14 @@
// RUN: mlir-opt %s
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a run line here?
If lit doesn't skip this file, maybe we could move all the "include" files in a sub-dir and tell lit to ignore these files.

Edit: I see that we have patterns of files like that already (e.g., mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir). I don't particularly like this kind of "fake" run line, but maybe we are actually testing something here?
@ftynse what do you think we should do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from other files and speculated that this is done to (1) make lit not complain and (2) maybe make a test fail if the syntax of this specific file gets out of date rather than relying a test of another file?

Copy link
Contributor Author

@ingomueller-net ingomueller-net left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @qcolombet! The suggestions about the comments makes sense. Who does them?

LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am open to using a different verb or name. From what I understand from the dictionary, unite, combine, merge, and coalesce all roughly mean "take several things and make a single one of them."

I guess that that is undoubtably what we do with the outer ops (and the result is indeed larger usually). I'd argue that it is close to what we do with the symbols as well but maybe not exactly the same, at least not what we do with functions currently. Another angle to look at what we do with symbols is that we "deduplicate" them. Maybe that can yield a name?

}
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You will or you would? I.e., who makes that commit?

/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, sounds good. Let's have the person do it that adds the comment you suggested above.

@@ -0,0 +1,14 @@
// RUN: mlir-opt %s
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from other files and speculated that this is done to (1) make lit not complain and (2) maybe make a test fail if the syntax of this specific file gets out of date rather than relying a test of another file?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants