[mlir][transform] Fix handling of transitive include in interpreter. #67560
8156631 to 6e83ecd
@ftynse: This is a new attempt at #67241 along the lines of your suggestions. Please take a look. It is worth mentioning that (i) I merged the stages (1) and (2) from your suggestion, and (ii) did not implement the optimization of only including needed symbols. I have identified one more problem that may require a major change: it is currently not always possible to run the interpreter pass several times. The problem is the following: if the script contains a declaration I see two ways out: (A) only merge in missing symbols (plus their dependencies) as previously and make all imported symbols private, or (B) separate the merging from the interpreter into a separate pass. I have the feeling that Approach (A) would make the merging again more special to the interpreter and possibly somewhat brittle and Approach (B) would be more generic and possibly more performant. (If the interpreter pass is really used several times, linking has to be done each time, while with a dedicated linking pass, it would only be done once.) I imagine the linking pass from Approach (B) to be called something like What do you think?
6e83ecd to 984af0f
Approach B makes sense to me. Generally, anything that keeps the interpreter simpler and more focused is likely a good idea. I suppose it should be sufficient to just move the symbols from other files (normally, it is allowed to have multiple declarations as long as there is a single definition) and handle name conflicts between private symbols that may have repeated names in different modules. Then we can run symbol-dce separately to remove unused symbols.
Implementation-wise, we can consider having a "mergeModules" function that takes a list of modules and moves everything into the first one. It can then be exercised from the pass that loads transform library files, but is also reusable.
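To make the suggestion concrete, here is a minimal sketch of what such a "mergeModules" helper could look like. It is not MLIR code: `ToySym` and `ToyModule` are hypothetical stand-ins for symbol ops and modules, and name collisions are deliberately left to a later step.

```cpp
#include <cassert>
#include <cstddef>
#include <iterator>
#include <string>
#include <vector>

// Hypothetical stand-ins for MLIR symbol ops and modules; they model only
// what this sketch needs.
struct ToySym {
  std::string name;
  bool isPrivate = false;
  bool hasBody = true; // false models an external declaration
};

struct ToyModule {
  std::vector<ToySym> symbols;
};

// Moves every symbol of modules[1..] into modules[0] and leaves the donor
// modules empty. Collision handling is intentionally out of scope here.
inline void mergeModules(std::vector<ToyModule> &modules) {
  if (modules.empty())
    return;
  ToyModule &target = modules.front();
  for (std::size_t i = 1; i < modules.size(); ++i) {
    std::vector<ToySym> &donor = modules[i].symbols;
    target.symbols.insert(target.symbols.end(),
                          std::make_move_iterator(donor.begin()),
                          std::make_move_iterator(donor.end()));
    donor.clear();
  }
}
```

Keeping the helper free of interpreter-specific logic is what would make it reusable from a dedicated loading pass, as proposed in the comment.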
Yeah, good point! Being able to merge several modules is on my TODO list. That would also be required if we wanted to write a generic merger that could replace most of SPIR-V's
Well, merging several modules is a
Yes, but you can do (slightly) better by being careful when to (re)construct which symbol table(s)...
984af0f to 5daa3f7
I think I changed my mind about the question of whether or not loading should be part of the interpreter. At the very least, I see that it is absolutely consistent to think of the current pass as "inject dependencies and interpret result." In this case, it is clear that running that pass twice in a way that tries to inject the dependencies twice does not work. The PR thus currently removes that test. If we decide that injection should be separated from the interpreter or that injection should be done only if necessary (such that it is idempotent), I believe we can do that in some future PR. Other than that, I have overhauled the initial commit and think that it is ready for review now.
@llvm/pr-subscribers-mlir
Changes
This is a new attempt at fixing transitive includes in the transform dialect interpreter (next to #67241) and a preparation for being able to load multiple transform library files (#67120). Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of TODO:
Patch is 27.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/67560.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..3dd172335160fb3 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -50,6 +50,8 @@ constexpr static llvm::StringLiteral kTransformDialectTagPayloadRootValue =
constexpr static llvm::StringLiteral
kTransformDialectTagTransformContainerValue = "transform_container";
+namespace {
+
/// Utility to parse the content of a `transformFileName` MLIR file containing
/// a transform dialect specification.
static LogicalResult
@@ -302,80 +304,270 @@ static void performOptionalDebugActions(
transform->removeAttr(kTransformDialectTagAttrName);
}
-/// Replaces external symbols in `block` with their (non-external) definitions
-/// from the given module.
-static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
- MLIRContext &ctx = *definitions->getContext();
- auto consumedName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName);
- auto readOnlyName =
- StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
-
- for (Operation &op : llvm::make_early_inc_range(block)) {
- LLVM_DEBUG(DBGS() << op << "\n");
- auto symbol = dyn_cast<SymbolOpInterface>(op);
- if (!symbol)
- continue;
- if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
- continue;
-
- LLVM_DEBUG(DBGS() << "looking for definition of symbol "
- << symbol.getNameAttr() << ":");
- SymbolTable symbolTable(definitions);
- Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
- if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
- externalSymbol->getRegion(0).empty()) {
- LLVM_DEBUG(llvm::dbgs() << "not found\n");
- continue;
+/// Rename `op` to avoid a collision with `otherOp`. `symbolTable` and
+/// `otherSymbolTable` are the symbol tables of the two ops, respectively.
+/// `uniqueId` is used to generate a unique name in the context of the caller.
+LogicalResult renameToUnique(SymbolOpInterface op, SymbolOpInterface otherOp,
+ SymbolTable &symbolTable,
+ SymbolTable &otherSymbolTable, int &uniqueId) {
+ assert(symbolTable.lookup(op.getNameAttr()) == op &&
+ "symbol table does not contain op");
+ assert(otherSymbolTable.lookup(otherOp.getNameAttr()) == otherOp &&
+ "other symbol table does not contain other op");
+
+ // Determine new name that is unique in both symbol tables.
+ StringAttr oldName = op.getNameAttr();
+ StringAttr newName;
+ {
+ MLIRContext *context = op->getContext();
+ SmallString<64> prefix = oldName.getValue();
+ prefix.push_back('_');
+ while (true) {
+ newName = StringAttr::get(context, prefix + Twine(uniqueId++));
+ if (!symbolTable.lookup(newName) && !otherSymbolTable.lookup(newName)) {
+ break;
+ }
}
+ }
+
+ // Apply renaming.
+ LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
+ Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(op);
+ if (failed(SymbolTable::replaceAllSymbolUses(op, newName, symbolTableOp))) {
+ InFlightDiagnostic diag =
+ emitError(op->getLoc(),
+ Twine("failed to rename symbol to @") + newName.getValue());
+ diag.attachNote(otherOp->getLoc())
+ << "attempted renaming due to collision with this op";
+ return diag;
+ }
+
+ // Change the symbol in the op itself and update the symbol table.
+ symbolTable.remove(op);
+ SymbolTable::setSymbolName(op, newName);
+ symbolTable.insert(op);
+
+ assert(symbolTable.lookup(newName) == op &&
+ "symbol table does not resolve to renamed op");
+ assert(symbolTable.lookup(oldName) == nullptr &&
+ "symbol table still resolves old name");
+
+ return success();
+}
- auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
- auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
- if (!symbolFunc || !externalSymbolFunc) {
- LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+/// Return whether `func1` can be merged into `func2`.
+bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ return func1.isExternal() && (func2.isPublic() || func2.isExternal());
+}
+
+/// Merge `func1` into `func2`. The two ops must be inside the same parent op
+/// and mergeable according to `canMergeInto`. The function erases `func1` such
+/// that only `func2` exists when the function returns.
+LogicalResult mergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
+ assert(canMergeInto(func1, func2));
+ assert(func1->getParentOp() == func2->getParentOp() &&
+ "expected func1 and func2 to be in the same parent op");
+
+ MLIRContext *context = func1->getContext();
+ auto consumedName = StringAttr::get(
+ context, transform::TransformDialect::kArgConsumedAttrName);
+ auto readOnlyName = StringAttr::get(
+ context, transform::TransformDialect::kArgReadOnlyAttrName);
+
+ // Check that function signatures match.
+ if (func1.getFunctionType() != func2.getFunctionType()) {
+ return func1.emitError()
+ << "external definition has a mismatching signature ("
+ << func2.getFunctionType() << ")";
+ }
+
+ // Check and merge argument attributes.
+ for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) {
+ bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr;
+ bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr;
+ bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr;
+ bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr;
+ if (!isExternalConsumed && !isExternalReadonly) {
+ if (isConsumed)
+ func2.setArgAttr(i, consumedName, UnitAttr::get(context));
+ else if (isReadonly)
+ func2.setArgAttr(i, readOnlyName, UnitAttr::get(context));
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
- if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
- return symbolFunc.emitError()
- << "external definition has a mismatching signature ("
- << externalSymbolFunc.getFunctionType() << ")";
+ if ((isExternalConsumed && !isConsumed) ||
+ (isExternalReadonly && !isReadonly)) {
+ return func1.emitError()
+ << "external definition has mismatching consumption "
+ "annotations for argument #"
+ << i;
}
+ }
+
+ // `func1` is the external one, so we can remove it.
+ assert(func1.isExternal());
+ func1->erase();
+
+ return success();
+}
+
+/// Merge all symbols from `other` into `target`. Both ops need to implement the
+/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
+/// modified by this function and might not verify after the function returns.
+/// Upon merging, private symbols may be renamed in order to avoid collisions in
+/// the result. Public symbols may not collide, with the exception of
+/// instances of `SymbolOpInterface`, where collisions are allowed if at least
+/// one of the two is external, in which case the other op is preserved (or any one
+/// of the two if both are external).
+static LogicalResult mergeSymbolsInto(Operation *target,
+ OwningOpRef<Operation *> other) {
+ assert(target->hasTrait<OpTrait::SymbolTable>() &&
+ "requires target to implement the 'SymbolTable' trait");
+ assert(other->hasTrait<OpTrait::SymbolTable>() &&
+ "requires other to implement the 'SymbolTable' trait");
+
+ SymbolTable targetSymbolTable(target);
+ SymbolTable otherSymbolTable(*other);
+
+ int uniqueId = 0;
+
+ // Step 1:
+ //
+ // Rename private symbols in both ops in order to resolve conflicts that can
+ // be resolved that way.
+ LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
+ for (auto [symbolTable, otherSymbolTable] : llvm::zip(
+ SmallVector<SymbolTable *>{&targetSymbolTable, &otherSymbolTable},
+ SmallVector<SymbolTable *>{&otherSymbolTable, &targetSymbolTable})) {
+ Operation *symbolTableOp = symbolTable->getOp();
+ for (Operation &op : symbolTableOp->getRegion(0).front()) {
+ auto symbolOp = dyn_cast<SymbolOpInterface>(op);
+ if (!symbolOp)
+ continue;
+ StringAttr name = symbolOp.getNameAttr();
+ LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");
- for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) {
- bool isExternalConsumed =
- externalSymbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isExternalReadonly =
- externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr;
- bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr;
- if (!isExternalConsumed && !isExternalReadonly) {
- if (isConsumed)
- externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx));
- else if (isReadonly)
- externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx));
+ // Check if there is a colliding op in the other module.
+ auto collidingOp =
+ cast_or_null<SymbolOpInterface>(otherSymbolTable->lookup(name));
+ if (!collidingOp)
continue;
+
+ LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());
+
+ // Collisions are fine if both ops are functions and can be merged.
+ if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
+ collidingFuncOp =
+ dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
+ funcOp && collidingFuncOp) {
+ if (canMergeInto(funcOp, collidingFuncOp) ||
+ canMergeInto(collidingFuncOp, funcOp)) {
+ LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
+ "will be merged\n");
+ continue;
+ }
+
+ // If they can't be merged, proceed like any other collision.
+ LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
}
- if ((isExternalConsumed && !isConsumed) ||
- (isExternalReadonly && !isReadonly)) {
- return symbolFunc.emitError()
- << "external definition has mismatching consumption annotations "
- "for argument #"
- << i;
+ // Collision can be resolved if one of the ops is private.
+ if (symbolOp.isPrivate()) {
+ if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
+ *otherSymbolTable, uniqueId)))
+ return failure();
+ continue;
+ }
+ if (collidingOp.isPrivate()) {
+ if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
+ *symbolTable, uniqueId)))
+ return failure();
+ continue;
}
+
+ LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
+ InFlightDiagnostic diag =
+ emitError(symbolOp->getLoc(),
+ Twine("doubly defined symbol @") + name.getValue());
+ diag.attachNote(collidingOp->getLoc()) << "previously defined here";
+ return diag;
}
+ }
- OpBuilder builder(&op);
- builder.setInsertionPoint(&op);
- builder.clone(*externalSymbol);
- symbol->erase();
+ for (auto *op : SmallVector<Operation *>{target, *other}) {
+ if (failed(mlir::verify(op)))
+ return emitError(op->getLoc(),
+ "failed to verify input op after renaming");
}
+ // Step 2:
+ //
+ // Move all ops from `other` into target and merge public symbols.
+ LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
+ {
+ SmallVector<SymbolOpInterface> opsToMove;
+ for (Operation &op : other->getRegion(0).front()) {
+ if (auto symbol = dyn_cast<SymbolOpInterface>(op))
+ opsToMove.push_back(symbol);
+ }
+
+ for (SymbolOpInterface op : opsToMove) {
+ // Remember potentially colliding op in the target module.
+ auto collidingOp = cast_or_null<SymbolOpInterface>(
+ targetSymbolTable.lookup(op.getNameAttr()));
+
+ // Move op even if we get a collision.
+ LLVM_DEBUG(DBGS() << " moving @" << op.getName());
+ op->moveAfter(&target->getRegion(0).front(),
+ target->getRegion(0).front().begin());
+
+ // If there is no collision, we are done.
+ if (!collidingOp) {
+ LLVM_DEBUG(llvm::dbgs() << " without collision\n");
+ continue;
+ }
+
+ // The two colliding ops must both be functions because we have already
+ // emitted errors otherwise earlier.
+ auto funcOp = cast<FunctionOpInterface>(op.getOperation());
+ auto collidingFuncOp =
+ cast<FunctionOpInterface>(collidingOp.getOperation());
+
+ // Both ops are in the target module now and can be treated symmetrically,
+ // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`.
+ if (!canMergeInto(funcOp, collidingFuncOp)) {
+ std::swap(funcOp, collidingFuncOp);
+ }
+ assert(canMergeInto(funcOp, collidingFuncOp));
+
+ LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
+ << collidingFuncOp.getLoc() << ":\n"
+ << collidingFuncOp << "\n");
+
+ // Update symbol table. This works with or without the previous `swap`.
+ targetSymbolTable.remove(funcOp);
+ targetSymbolTable.insert(collidingFuncOp);
+ assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
+
+ // Do the actual merging.
+ if (failed(mergeInto(funcOp, collidingFuncOp))) {
+ return failure();
+ }
+
+ assert(succeeded(mlir::verify(target)));
+ }
+ }
+
+ if (failed(mlir::verify(target)))
+ return emitError(target->getLoc(),
+ "failed to verify target op after merging symbols");
+
+ LLVM_DEBUG(DBGS() << "done merging ops\n");
return success();
}
+} // namespace
+
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
@@ -438,8 +630,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
diag.attachNote(target->getLoc()) << "pass anchor op";
return diag;
}
- if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
- libraryModule->get())))
+ if (failed(
+ mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
+ libraryModule->get()->clone())))
return failure();
}
@@ -499,8 +692,7 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
return success();
if (module && *module) {
- if (failed(defineDeclaredSymbols(*module->get().getBody(),
- parsedLibrary.get())))
+ if (failed(mergeSymbolsInto(module->get(), std::move(parsedLibrary))))
return failure();
} else {
libraryModule =
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
index 3d4cb0776982934..dd8d141e994da0e 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
@@ -11,4 +11,10 @@
// expected-remark @below {{message}}
// expected-remark @below {{unannotated}}
+// expected-remark @below {{internal colliding (without suffix)}}
+// expected-remark @below {{internal colliding_0}}
+// expected-remark @below {{internal colliding_1}}
+// expected-remark @below {{internal colliding_3}}
+// expected-remark @below {{internal colliding_4}}
+// expected-remark @below {{internal colliding_5}}
module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
index b21abbbdfd6d045..7452deb39b6c18d 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
@@ -1,16 +1,16 @@
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file
-// The definition of the @foo named sequence is provided in another file. It
+// The definition of the @print_message named sequence is provided in another file. It
// will be included because of the pass option.
module attributes {transform.with_named_sequence} {
// expected-error @below {{external definition has a mismatching signature}}
- transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly})
+ transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.op<"builtin.module">):
- include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+ include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
}
}
@@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} {
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
}
}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{doubly defined symbol @print_message}}
+ transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) {
+ transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+ transform.yield
+ }
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
index 04b6c5a02e0adf1..a9083fe3e70788a 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
@@ -4,29 +4,66 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s
-// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
-// RUN: --verify-diagnostics --split-input-file | FileCheck %s
-
-// The definition of the @foo named sequence is provided in another file. It
-// will be included because of the pass option. Repeated application of the
-// same pass, with or without the library option, should not be a problem.
+// The definition of the @print_message named sequence is provided in another
+// file. It will be included because of the pass option. Subsequent application
+// of the same pass works but only without the library file (since the first
+// application loads external symbols and loading them again would make them
+// clash).
// Note that the same diagnostic produced twice at the same location only
// needs to be matched once.
// expected-remark @below {{message}}
// expected-remark @below {{unannotated}}
+// ...
[truncated]
}

// Apply renaming.
LLVM_DEBUG(llvm::dbgs() << ", renaming to @" << newName.getValue() << "\n");
Nit: looks like the debug line starts with a comma for no reason.
This terminated a debug line that was started earlier but that wasn't visible in this context. I believe that after the refactoring of `renameToUnique`, the context is now much smaller such that the intent should be understood. Let me know if you disagree.
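For readers without the full diff at hand, the name search inside `renameToUnique` can be sketched as follows. This is a simplified model, not the patch itself: `makeUniqueName` is a hypothetical helper, and the two `std::set` arguments stand in for the two symbol tables.

```cpp
#include <cassert>
#include <set>
#include <string>

// Appends "_<n>" to `oldName`, advancing the shared counter `uniqueId`
// until the candidate is taken in neither name set. The counter is shared
// across calls, mirroring the `int &uniqueId` parameter in the patch.
inline std::string makeUniqueName(const std::string &oldName,
                                  const std::set<std::string> &names,
                                  const std::set<std::string> &otherNames,
                                  int &uniqueId) {
  const std::string prefix = oldName + "_";
  while (true) {
    std::string candidate = prefix + std::to_string(uniqueId++);
    if (!names.count(candidate) && !otherNames.count(candidate))
      return candidate;
  }
}
```

Because the counter is shared between calls, suffixes are not necessarily consecutive per symbol name.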
LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
/// Return whether `func1` can be merged into `func2`.
bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
return func1.isExternal() && (func2.isPublic() || func2.isExternal());
This condition is unclear to me. Why can we merge an external function into another external function? Neither has a body... Why should the other function be public? Is this about merging declarations?
Yes, this logic allows merging `func1` into `func2` if the former is a declaration and the latter is a public definition or it is a declaration as well. I added a comment along those lines. Let me know if anything else is needed.
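The cases this predicate distinguishes can be spelled out with a small stand-alone model. `ToyFunc` and `canMergeIntoToy` are hypothetical names used only for illustration; the boolean expression is the one quoted from the patch.

```cpp
#include <cassert>

// Hypothetical stand-in for a function-like symbol.
struct ToyFunc {
  bool isExternal; // declaration only, no body
  bool isPublic;   // public visibility
};

// Same boolean expression as `canMergeInto` in the patch: `func1` must be a
// declaration, and `func2` must be a public definition or a declaration.
inline bool canMergeIntoToy(const ToyFunc &func1, const ToyFunc &func2) {
  return func1.isExternal && (func2.isPublic || func2.isExternal);
}
```

Under this model, a declaration merges into a public definition or into another declaration, while a declaration never merges into a private definition and a definition never merges into anything.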
A more fitting name could be beneficial, but I can't come up with one immediately. `canMergeDeclarationInto`?
Yeah, that could work. (As you probably also felt, it's not perfect: the name suggests that it has to be a declaration but it actually checks if it is a declaration.)
One argument in favor of keeping it as is is that this is a formulation that generalizes: In the SPIR-V module combiner, they "merge" global variables and constants, and those are "mergeable" if they define the same thing. This generalization could be formalized into a `MergeableSymbol` interface (which extends `Symbol`) with two methods `canMergeInto` and `mergeInto`. At least for these two examples (SPIR-V and TD), that concept seems to work...
Kind of late to the party, but I agree that the meaning is difficult to grasp currently.
More specifically I believe the comment doesn't describe what mergeable means.
For me the confusing part is that merging implies that the result is bigger. So before reading the code I was wondering if we were doing some sort of function concatenation :).
<digression>LLVM has an optimization pass called merge globals that puts global variables in the same global structure to allow faster symbol resolution and that's the kind of thing this function name brings to mind for me.</digression>
Maybe `canCoalesceSymbols`?
I am open to using a different verb or name. From what I understand from the dictionary, unite, combine, merge, and coalesce all roughly mean "take several things and make a single one of them."
I guess that that is undoubtedly what we do with the outer ops (and the result is indeed larger usually). I'd argue that it is close to what we do with the symbols as well but maybe not exactly the same, at least not what we do with functions currently. Another angle to look at what we do with symbols is that we "deduplicate" them. Maybe that can yield a name?
Thanks a lot, @ftynse, for the detailed feedback! I just pushed a bunch of commits that should address your comments (modulo the questions I am submitting here).
Thanks for the review! Now all comments should be addressed. I'll merge once CI has passed. Should some of the minor things still bother you/us, then I can still submit a fix-up.
Yeah, that could work. (As you probably also felt, it's not perfect: the name suggests that it has to be a declaration but it actually checks if it is a declaration.)
One argument in favor of keeping it as is is that this formulation generalizes: in the SPIR-V module combiner, they "merge" global variables and constants, and those are "mergeable" if they define the same thing. This generalization could be formalized into a `MergeableSymbol` interface (which extends `Symbol`) with two methods, `canMergeInto` and `mergeInto`. At least for these two examples (SPIR-V and TD), that concept seems to work...
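To make the idea above concrete, here is a minimal, self-contained sketch of what such an interface could look like. It is plain C++ rather than MLIR: `MergeableSymbol` is only the name floated in this comment and does not exist upstream, and `ToyFunction` and `ToyConstant` are hypothetical stand-ins for the two examples mentioned (transform functions and SPIR-V-style constants).

```cpp
#include <string>
#include <utility>

// Hypothetical interface named after the idea in the comment above; this is
// NOT an existing MLIR interface.
class MergeableSymbol {
public:
  virtual ~MergeableSymbol() = default;
  // Returns true if `this` may be merged into `other`.
  virtual bool canMergeInto(const MergeableSymbol &other) const = 0;
  // Merges `this` into `other`; the caller erases `this` afterwards.
  virtual void mergeInto(MergeableSymbol &other) = 0;
};

// Toy function symbol: a declaration can be merged into a public definition
// or into another declaration of the same name (the predicate under review).
class ToyFunction : public MergeableSymbol {
public:
  ToyFunction(std::string n, bool isExternal, bool isPublic)
      : name(std::move(n)), external(isExternal), pub(isPublic) {}

  bool canMergeInto(const MergeableSymbol &other) const override {
    const auto *func = dynamic_cast<const ToyFunction *>(&other);
    return func && func->name == name && external &&
           (func->pub || func->external);
  }

  // A declaration has no body, so this toy model has nothing to copy.
  void mergeInto(MergeableSymbol &) override {}

private:
  std::string name;
  bool external;
  bool pub;
};

// Toy constant symbol: mergeable if both define the same thing.
class ToyConstant : public MergeableSymbol {
public:
  ToyConstant(std::string n, int v) : name(std::move(n)), value(v) {}

  bool canMergeInto(const MergeableSymbol &other) const override {
    const auto *cst = dynamic_cast<const ToyConstant *>(&other);
    return cst && cst->name == name && cst->value == value;
  }

  void mergeInto(MergeableSymbol &) override {}

private:
  std::string name;
  int value;
};
```

In this sketch, a linker loop would only need to call `canMergeInto` and `mergeInto` and would not have to know which kind of symbol it is handling.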
Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. If such sequences were not *also* declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules. This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of `FunctionOpInterface`) with compatible signatures, one is external, and the other one is public, then they are merged; (2) if one of them is private, that one is renamed; and (3) an error is raised otherwise.
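The three collision rules in this commit message can be sketched as a small decision function. This is a toy model in plain C++, not the code in `TransformInterpreterPassBase.cpp`: `ToySymbol`, `CollisionAction`, and `resolveCollision` are hypothetical names, signatures are compared as plain strings, and the merge predicate is the `canMergeInto` condition quoted in the review, tried in both directions.

```cpp
#include <string>

// Toy stand-in for a symbol op; the fields are assumptions for illustration.
struct ToySymbol {
  std::string signature; // e.g. "(!transform.any_op) -> ()"
  bool isFunction;       // implements a function-like interface
  bool isExternal;       // declaration without a body
  bool isPublic;         // false means private visibility
};

enum class CollisionAction { Merge, RenamePrivate, Error };

// Same condition as the `canMergeInto` predicate quoted in the review.
bool canMergeInto(const ToySymbol &a, const ToySymbol &b) {
  return a.isExternal && (b.isPublic || b.isExternal);
}

// Decide what to do with two symbols that share a name.
CollisionAction resolveCollision(const ToySymbol &a, const ToySymbol &b) {
  // (1) Functions with compatible signatures where one side is a
  //     declaration: merge them.
  if (a.isFunction && b.isFunction && a.signature == b.signature &&
      (canMergeInto(a, b) || canMergeInto(b, a)))
    return CollisionAction::Merge;
  // (2) At least one side is private: rename the private one.
  if (!a.isPublic || !b.isPublic)
    return CollisionAction::RenamePrivate;
  // (3) Two public symbols that cannot be merged: report an error.
  return CollisionAction::Error;
}
```

For example, a public declaration and a public definition with identical signatures yield `Merge`, while two public definitions of the same name yield `Error`.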
* Move all private functions of the CPP file into anonymous namespace.
* Remove test with second interpreter pass that reloads the library. I think that this shouldn't be possible.
* Factor out `renameToUnique`, `canMergeInto`, and `mergeInto` into proper functions.
* Use a single symbol table per input op and update it correctly whenever symbols or ops change.
* Make `other` arg an `OwningOpRef` and clone the arguments where necessary.
* Improve comments.

* Use `moveBefore` instead of `moveAfter` in order to work on empty targets as well.
* Do not verify the target after moving each op, since the last op may use symbols of ops that still have to be moved.
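The last point above (verify once, after all ops have been moved) can be illustrated with a toy model. The sketch below is plain C++ with hypothetical names (`ToyOp`, `verifyModule`, `mergeThenVerify`); it appends ops with `push_back` instead of using MLIR's `moveBefore`, and "verification" only checks that every referenced symbol is defined.

```cpp
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy op: defines one symbol and may reference others.
struct ToyOp {
  std::string symbol;            // symbol this op defines
  std::vector<std::string> uses; // symbols this op references
};

// Returns true if every referenced symbol is defined somewhere in `module`.
bool verifyModule(const std::vector<ToyOp> &module) {
  std::set<std::string> defined;
  for (const ToyOp &op : module)
    defined.insert(op.symbol);
  for (const ToyOp &op : module)
    for (const std::string &use : op.uses)
      if (!defined.count(use))
        return false;
  return true;
}

// Moves all ops of `library` into `target` and verifies only at the end.
// Verifying after each single move could fail spuriously, because an op
// moved early may reference a symbol that has not been moved yet.
bool mergeThenVerify(std::vector<ToyOp> &target, std::vector<ToyOp> library) {
  for (ToyOp &op : library)
    target.push_back(std::move(op));
  return verifyModule(target);
}
```

With a library in which `a` uses `b`, a module holding only `a` fails the check, whereas the fully merged module passes.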
Since the pass injects all definitions, providing the library again isn't needed. Since that injection isn't idempotent, it actually isn't even *possible* anymore, so this commit removes that argument.
I told @ingomueller-net I was going to look at this, so here we are :).
Sorry for the delay. LGTM too, couple of comments inlined.
Kind of late to the party, but I agree that the meaning is difficult to grasp currently.
More specifically I believe the comment doesn't describe what mergeable means.
For me the confusing part is that merging implies that the result is bigger. So before reading the code I was wondering if we were doing some sort of function concatenation :).
<digression>LLVM has an optimization pass called merge globals that puts global variables in the same global structure to allow faster symbol resolution and that's the kind of thing this function name brings to mind for me.</digression>
Maybe `canCoalesceSymbols`?
}
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
/// and mergable according to `canMergeInto`. The function erases `func1` such
/// that only `func2` exists when the function returns.
I'll add a comment about "metadata" merging rules (i.e., around consumed and readonly)
You will or you would? I.e., who makes that commit?
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
I think this comment would be easier to read if we break it down into several sentences (or bullet points?).
Yep, sounds good. Let's have whoever adds the comment you suggested above do this as well.
@@ -0,0 +1,14 @@
// RUN: mlir-opt %s
Do we need a run line here?
If `lit` doesn't skip this file, maybe we could move all the "include" files in a sub-dir and tell lit to ignore these files.
Edit: I see that we have patterns of files like that already (e.g., `mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir`). I don't particularly like this kind of "fake" run line, but maybe we are actually testing something here?
@ftynse what do you think we should do?
I copied this from other files and speculated that this is done to (1) make lit not complain and (2) maybe make a test fail if the syntax of this specific file gets out of date, rather than relying on a test of another file?
Thanks @qcolombet! The suggestions about the comments make sense. Who does them?
I am open to using a different verb or name. From what I understand from the dictionary, unite, combine, merge, and coalesce all roughly mean "take several things and make a single one of them."
I guess that is undoubtedly what we do with the outer ops (and the result is indeed usually larger). I'd argue that it is close to what we do with the symbols as well, but maybe not exactly the same, at least not what we currently do with functions. Another angle on what we do with symbols is that we "deduplicate" them. Maybe that can yield a name?
This is a new attempt at fixing transitive includes in the transform dialect interpreter (next to #67241) and a preparation for being able to load multiple transform library files (#67120).
Until now, the interpreter would only load those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may include other sequences on their own. Until now, if such sequences were not also declared in the main transform module, the interpreter would fail to resolve them. Forward declaring all of them is undesirable as it defeats the purpose of encapsulation into library modules.
This PR implements a kind of linker for transform scripts to solve this problem. The linker merges all symbols of the library module into the main module before interpreting the latter. Symbols whose names collide are handled as follows: (1) if they are both functions (in the sense of `FunctionOpInterface`) with compatible signatures, one is external, and the other one is public, then they are merged; (2) if one of them is private, that one is renamed; and (3) an error is raised otherwise.
TODO: XXX.