[mlir][transform] Don't modify the target in interpreter when loading library. #67686

Draft: wants to merge 1 commit into base: main

Conversation

ingomueller-net
Contributor

@ingomueller-net ingomueller-net commented Sep 28, 2023

Until now, if the transform script was embedded into the input IR, the transform dialect interpreter injected the externally resolved symbols into that IR, which then became part of the output. This is not always desirable.

This PR is a first step towards separating the logic of loading/resolution/injection from the interpreter. The modification consists of cloning the IR that contains the main transform script if necessary (i.e., if we actually need to load a library and the script is part of the input op of the pass). The next step will be to introduce a dedicated pass for loading and injecting the transform script and/or library.

The PR also improves some variable names and related if-conditions. These changes currently form an independent NFC commit and could be factored out into a dedicated PR.
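
A minimal sketch of the clone-before-resolve idea described above, written as a self-contained C++ toy rather than against the MLIR API. `ToyModule`, `resolveDeclarations`, and `resolveOnClone` are hypothetical names introduced for illustration; the map from symbol names to bodies is an assumption standing in for IR with named sequences, and the copy plays the role of `target->clone()` in the diff.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy stand-in for an IR module (assumption, not an MLIR type):
// symbol name -> body. An empty body marks a declaration without definition.
struct ToyModule {
  std::map<std::string, std::string> symbols;
};

// Fills in every declaration in `container` for which `library` has a
// definition. Returns the number of declarations that were resolved.
int resolveDeclarations(ToyModule &container, const ToyModule &library) {
  int resolved = 0;
  for (auto &entry : container.symbols) {
    if (!entry.second.empty())
      continue; // already defined, nothing to resolve
    auto def = library.symbols.find(entry.first);
    if (def != library.symbols.end() && !def->second.empty()) {
      entry.second = def->second;
      ++resolved;
    }
  }
  return resolved;
}

// Hypothetical driver: resolves against a private copy so that `target`
// (the pass input, which is also what the pass outputs) stays untouched.
ToyModule resolveOnClone(const ToyModule &target, const ToyModule &library) {
  ToyModule clone = target; // deep copy of the toy module
  resolveDeclarations(clone, library);
  // Resolution only fills in bodies; it never adds or removes symbols.
  assert(clone.symbols.size() == target.symbols.size());
  return clone;
}
```

In this toy, the interpreter would run on the returned copy, while the original module still contains only the declaration and is what ends up in the output.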

@ingomueller-net ingomueller-net changed the title Transform loading separation [mlir][transform] Don't modify the target in interpreter when loading library. Sep 28, 2023
@ingomueller-net ingomueller-net force-pushed the transform-loading-separation branch from 54f083a to 0afdb77 Compare September 28, 2023 14:29
@ingomueller-net
Contributor Author

Two questions about this PR:

  • Should the loading/injection pass be part of this PR or can I do it in an independent one?

  • Let us also reconsider the alternative of keeping loading and interpretation together. The issue that led to this PR could also be solved by no longer expecting that one can run the interpreter twice and supply the same library in the second invocation. If the semantics are "load and interpret", then doing "load and interpret" twice implies doing "load" twice, which obviously leads to doubly defined symbols. Similarly, if the semantics are "interpret but don't load", then running the interpreter once with a library name and then once without (as one of the tests did) will break with the current PR.

    All in all, as I see it, the decision is between having more precise control (with two different passes) or having fewer passes. What's more desirable?
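
The difference between the two semantics over two consecutive interpreter runs can be sketched with a toy model. This is an illustration under stated assumptions, not the MLIR implementation: `ToySymbols`, `resolveAll`, `runInPlace`, `runOnClone`, and `demoTwoRuns` are hypothetical names, and an empty string stands in for a declaration without a body.

```cpp
#include <cassert>
#include <map>
#include <string>

// Toy module (assumption): symbol name -> body; empty body = declaration.
typedef std::map<std::string, std::string> ToySymbols;

// Resolves declarations in `container` from `library` (which may be null).
// Returns true if no declaration is left unresolved afterwards.
bool resolveAll(ToySymbols &container, const ToySymbols *library) {
  bool allResolved = true;
  for (auto &entry : container) {
    if (!entry.second.empty())
      continue; // already has a definition
    if (library) {
      auto def = library->find(entry.first);
      if (def != library->end() && !def->second.empty()) {
        entry.second = def->second;
        continue;
      }
    }
    allResolved = false;
  }
  return allResolved;
}

// In-place semantics (simplified model of the behavior before this PR):
// definitions are injected into the target and survive into the output.
bool runInPlace(ToySymbols &target, const ToySymbols *library) {
  return resolveAll(target, library);
}

// Clone semantics (simplified model of this PR): the target is not modified.
bool runOnClone(const ToySymbols &target, const ToySymbols *library) {
  ToySymbols clone = target;
  return resolveAll(clone, library);
}

// Usage: pipelines of consecutive runs with and without the library.
void demoTwoRuns() {
  ToySymbols library;
  library["foo"] = "print_remark";

  // In place: run 1 injects the definition, so run 2 needs no library.
  ToySymbols inPlaceTarget;
  inPlaceTarget["foo"] = "";
  bool first = runInPlace(inPlaceTarget, &library);
  bool second = runInPlace(inPlaceTarget, nullptr);
  assert(first && second);

  // Clone: the target keeps its declaration, so a later run without the
  // library cannot resolve it, while a run with the library still can.
  ToySymbols cloneTarget;
  cloneTarget["foo"] = "";
  first = runOnClone(cloneTarget, &library);
  second = runOnClone(cloneTarget, nullptr);
  bool third = runOnClone(cloneTarget, &library);
  assert(first && !second && third);
  (void)first; (void)second; (void)third;
}
```

In this model, the with-then-without pipeline only works when the first run modifies its input, which is the trade-off raised in the second question.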

@ingomueller-net ingomueller-net marked this pull request as ready for review September 28, 2023 14:39
@llvmbot llvmbot added the mlir label Sep 28, 2023
@llvmbot
Member

llvmbot commented Sep 28, 2023

@llvm/pr-subscribers-mlir

Changes


Full diff: https://github.com/llvm/llvm-project/pull/67686.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h (+2-2)
  • (modified) mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp (+51-24)
  • (modified) mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir (+7-9)
  • (modified) mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp (+2-2)
diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
index 91903e254b0d5b3..9c67f3af61cc12d 100644
--- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
+++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
@@ -64,7 +64,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     will be interpreted.
 ///   - transformLibraryFileName: if non-empty, the name of the file containing
 ///     definitions of external symbols referenced in the transform script.
-///     These definitions will be used to replace declarations.
+///     These definitions will be used to resolve declarations.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -85,7 +85,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 /// as template arguments. They are *not* expected to to implement `initialize`
 /// or `runOnOperation`. They *are* expected to call the copy constructor of
 /// this class in their copy constructors, short of which the file-based
-/// transform dialect script injection facility will become nonoperational.
+/// transform dialect script resolution facility will become non-operational.
 ///
 /// Concrete passes may implement the `runBeforeInterpreter` and
 /// `runAfterInterpreter` to customize the behavior of the pass.
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..9c245fda7f567e7 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -379,7 +379,7 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
-    const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
@@ -387,6 +387,16 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
+  bool hasSharedTransformModule =
+      sharedTransformModule && *sharedTransformModule;
+  bool hasTransformLibraryModule =
+      transformLibraryModule && *transformLibraryModule;
+  assert((!hasSharedTransformModule || !hasTransformLibraryModule) &&
+         "at most one of shared or library transform module can be set");
+
+  // Step 0
+  // ------
+  // If debugPayloadRootTag or debugTransformRootTag was passed, then we are in)
 
   // Step 1
   // ------
@@ -407,9 +417,24 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // transform is embedded in the payload IR. If debugTransformRootTag was
   // passed, then we are in user-specified selection of the transforming IR.
   // This corresponds to REPL debug mode.
-  bool sharedTransform = (sharedTransformModule && *sharedTransformModule);
-  Operation *transformContainer =
-      sharedTransform ? sharedTransformModule->get() : target;
+
+  OwningOpRef<Operation *> transformContainerClone;
+  Operation *transformContainer;
+  if (hasTransformLibraryModule) {
+    // If we have a library module, then the transform script is embedded in the
+    // target, which we don't want to modify when loading the library. We thus
+    // clone the target and use that as transform container.
+    assert(!hasSharedTransformModule);
+    transformContainerClone = target->clone();
+    transformContainer = transformContainerClone.get();
+  } else {
+    // If we have a shared library, which is private to us, we can modify it
+    // when loading the library, so we use that. Otherwise, we don't have any
+    // library to load, so we can use the target and won't modify it.
+    transformContainer =
+        hasSharedTransformModule ? sharedTransformModule->get() : target;
+  }
+
   Operation *transformRoot =
       debugTransformRootTag.empty()
           ? findTopLevelTransform(transformContainer,
@@ -430,8 +455,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
   // Copy external defintions for symbols if provided. Be aware of potential
   // concurrent execution (normally, the error shouldn't be triggered unless the
   // transform IR modifies itself in a pass, which is also forbidden elsewhere).
-  if (!sharedTransform && libraryModule && *libraryModule) {
-    if (!target->isProperAncestor(transformRoot)) {
+  if (hasTransformLibraryModule) {
+    if (!transformContainer->isProperAncestor(transformRoot)) {
       InFlightDiagnostic diag =
           transformRoot->emitError()
           << "cannot inject transform definitions next to pass anchor op";
@@ -439,7 +464,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
       return diag;
     }
     if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
-                                     libraryModule->get())))
+                                     transformLibraryModule->get())))
       return failure();
   }
 
@@ -461,25 +486,27 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
     function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
         moduleBuilder) {
-  OwningOpRef<ModuleOp> parsed;
-  if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
+  OwningOpRef<ModuleOp> parsedTransformModule;
+  if (failed(parseTransformModuleFromFile(context, transformFileName,
+                                          parsedTransformModule)))
     return failure();
-  if (parsed && failed(mlir::verify(*parsed)))
+  if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
     return failure();
 
-  OwningOpRef<ModuleOp> parsedLibrary;
+  OwningOpRef<ModuleOp> parsedLibraryModule;
   if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
-                                          parsedLibrary)))
+                                          parsedLibraryModule)))
     return failure();
-  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+  if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
     return failure();
 
-  if (parsed) {
-    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (parsedTransformModule) {
+    sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        std::move(parsedTransformModule));
   } else if (moduleBuilder) {
     // TODO: better location story.
     auto location = UnknownLoc::get(context);
@@ -491,20 +518,20 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
       if (failed(*result))
         return failure();
-      module = std::move(localModule);
+      sharedTransformModule = std::move(localModule);
     }
   }
 
-  if (!parsedLibrary || !*parsedLibrary)
+  if (!parsedLibraryModule || !*parsedLibraryModule)
     return success();
 
-  if (module && *module) {
-    if (failed(defineDeclaredSymbols(*module->get().getBody(),
-                                     parsedLibrary.get())))
+  if (sharedTransformModule && *sharedTransformModule) {
+    if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(),
+                                     parsedLibraryModule.get())))
       return failure();
   } else {
-    libraryModule =
-        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+    transformLibraryModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..076a2171094808a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -1,27 +1,25 @@
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
-// RUN:             --verify-diagnostics --split-input-file | FileCheck %s
-
 // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
 // RUN:             --verify-diagnostics --split-input-file | FileCheck %s
 
 // The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// will be available because of the pass option but not included in the output.
+// Repeated application of the same pass works, but only if the library is
+// provided in both.
 // Note that the same diagnostic produced twice at the same location only
 // needs to be matched once.
 
 // expected-remark @below {{message}}
 // expected-remark @below {{unannotated}}
 module attributes {transform.with_named_sequence} {
-  // CHECK: transform.named_sequence @foo
-  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
+  // CHECK: transform.named_sequence private @foo
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @foo(!transform.any_op {transform.readonly})
 
-  // CHECK: transform.named_sequence @unannotated
-  // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated"
+  // CHECK: transform.named_sequence private @unannotated
+  // CHECK-NOT: test_print_remark_at_operand
   transform.named_sequence private @unannotated(!transform.any_op {transform.readonly})
 
   transform.sequence failures(propagate) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
index f73deef9d5fd48c..675b5ecd50346fa 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
@@ -219,8 +219,8 @@ class TestTransformDialectInterpreterPass
   Option<std::string> transformLibraryFileName{
       *this, "transform-library-file-name", llvm::cl::init(""),
       llvm::cl::desc(
-          "Optional name of the file containing transform dialect symbol "
-          "definitions to be injected into the transform module.")};
+          "Optional name of the file providing transform dialect definitions "
+          "from which declarations in the transform module can be resolved.")};
 
   Option<bool> testModuleGeneration{
       *this, "test-module-generation", llvm::cl::init(false),

@ingomueller-net
Contributor Author

I am currently inclined to put this PR on hold until we need this functionality. (But I factored out the clean-up into #67800, which I think would be good to have anyways.)

@ftynse ftynse marked this pull request as draft October 3, 2023 07:36
Until now, if the transform script was embedded into the input IR, the
transform dialect interpreter injected the externally resolved symbols
into that IR, which then became part of the output. This is not always
desirable.

This commit is a first step to separate the logic of loading/resolution/
injection from the interpreter. The modification consists of cloning the
IR that contains the main transform script if necessary (i.e., if we
actually need to load it and it is part of the input op of the pass).
The next step will be to introduce a dedicated pass for loading and
injecting the transform script and/or library.
@ingomueller-net ingomueller-net force-pushed the transform-loading-separation branch from 2040139 to 3c814b0 Compare October 9, 2023 08:18