[mlir][transform] Fix handling of transitive includes in interpreter. #67241
@llvm/pr-subscribers-mlir
Changes
Until now, the interpreter only loaded those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may themselves include other sequences. If such sequences were not also declared in the main transform module, the interpreter failed to resolve them. Forward-declaring all of them is undesirable because it defeats the purpose of encapsulating them in library modules.
This PR extends the loading mechanism as follows: `defineDeclaredSymbols` not only inserts the definitions of symbols that are forward-declared in the main module, but also scans each inserted definition for further dependencies and processes those in the same way as the forward declarations from the main module.
Full diff: https://github.com/llvm/llvm-project/pull/67241.diff
3 Files Affected:
diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
index d5c65b23e3a2134..aa2b1157c254b47 100644
--- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
+++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
@@ -311,6 +311,9 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
auto readOnlyName =
StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);
+ // Collect symbols missing in the block.
+ SmallVector<SymbolOpInterface> missingSymbols;
+ LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
for (Operation &op : llvm::make_early_inc_range(block)) {
LLVM_DEBUG(DBGS() << op << "\n");
auto symbol = dyn_cast<SymbolOpInterface>(op);
@@ -318,25 +321,33 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
continue;
if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
continue;
+ LLVM_DEBUG(DBGS() << " -> symbol missing\n");
+ missingSymbols.push_back(symbol);
+ }
- LLVM_DEBUG(DBGS() << "looking for definition of symbol "
- << symbol.getNameAttr() << ":");
- SymbolTable symbolTable(definitions);
- Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+ // Resolve missing symbols until they are all resolved.
+ while (!missingSymbols.empty()) {
+ SymbolOpInterface symbol = missingSymbols.pop_back_val();
+ LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
+ << symbol.getNameAttr().getValue() << ": ");
+ SymbolTable definitionsSymbolTable(definitions);
+ Operation *externalSymbol =
+ definitionsSymbolTable.lookup(symbol.getNameAttr());
if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
externalSymbol->getRegion(0).empty()) {
LLVM_DEBUG(llvm::dbgs() << "not found\n");
continue;
}
- auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+ auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
if (!symbolFunc || !externalSymbolFunc) {
LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
continue;
}
- LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+ LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
+ << externalSymbol->getLoc() << "\n");
if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
return symbolFunc.emitError()
<< "external definition has a mismatching signature ("
@@ -367,10 +378,52 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
}
}
- OpBuilder builder(&op);
- builder.setInsertionPoint(&op);
+ OpBuilder builder(symbol);
+ builder.setInsertionPoint(symbol);
builder.clone(*externalSymbol);
symbol->erase();
+
+ LLVM_DEBUG(DBGS() << "scanning definition of @"
+ << externalSymbolFunc.getNameAttr().getValue()
+ << " for symbol usages\n");
+ externalSymbolFunc.walk([&](CallOpInterface callOp) {
+ LLVM_DEBUG(DBGS() << " found symbol usage in:\n" << callOp << "\n");
+ CallInterfaceCallable callable = callOp.getCallableForCallee();
+ if (!isa<SymbolRefAttr>(callable)) {
+ LLVM_DEBUG(DBGS() << " not a 'SymbolRefAttr'\n");
+ return WalkResult::advance();
+ }
+
+ StringRef callableSymbol =
+ cast<SymbolRefAttr>(callable).getLeafReference();
+ LLVM_DEBUG(DBGS() << " looking for @" << callableSymbol
+ << " in definitions: ");
+
+ Operation *callableOp = definitionsSymbolTable.lookup(callableSymbol);
+ if (!callableOp) {
+ LLVM_DEBUG(llvm::dbgs() << "not found\n");
+ return WalkResult::advance();
+ }
+ LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
+ << callableOp->getLoc() << "\n");
+
+ if (!block.getParent() || !block.getParent()->getParentOp()) {
+ LLVM_DEBUG(DBGS() << "could not get parent of provided block");
+ return WalkResult::advance();
+ }
+
+ SymbolTable targetSymbolTable(block.getParent()->getParentOp());
+ if (targetSymbolTable.lookup(callableSymbol)) {
+ LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol
+ << " already present in target\n");
+ return WalkResult::advance();
+ }
+
+ LLVM_DEBUG(DBGS() << " cloning op into target\n");
+ builder.clone(*callableOp);
+
+ return WalkResult::advance();
+ });
}
return success();
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
new file mode 100644
index 000000000000000..3a122ce2f77c3a8
--- /dev/null
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-transitive.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics --split-input-file | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+// expected-remark @below {{unannotated}}
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @bar(!transform.any_op {transform.readonly})
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
index 1149bda98ab8527..9aa2d46d5abb995 100644
--- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
@@ -1,6 +1,11 @@
// RUN: mlir-opt %s
module attributes {transform.with_named_sequence} {
+ transform.named_sequence @bar(%arg0: !transform.any_op) {
+ transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ transform.yield
+ }
+
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
transform.yield
if (!block.getParent() || !block.getParent()->getParentOp()) {
  LLVM_DEBUG(DBGS() << "could not get parent of provided block");
  return WalkResult::advance();
}

SymbolTable targetSymbolTable(block.getParent()->getParentOp());
if (targetSymbolTable.lookup(callableSymbol)) {
  LLVM_DEBUG(DBGS() << " symbol @" << callableSymbol
                    << " already present in target\n");
  return WalkResult::advance();
}
I am not perfectly happy with this part yet. I am wondering whether `defineDeclaredSymbols` shouldn't accept the parent op right away rather than a block. If that isn't acceptable for some reason, I wonder whether this shouldn't fail rather than just output a silent debug message.
We need to be careful here, as this is basically a proto-linker. One low-tech solution could be to just clone all symbols from the library file(s) into the main one, not just the declared/referenced ones. The difficult issue in any case is name clashes between files. AFAIR, `named_sequence` doesn't have the visibility attribute, and it probably should for this purpose. It would explicitly tell us when renaming a symbol to avoid a name clash is acceptable (private symbol) and when it isn't. When it isn't, the process needs to run in three stages: (1) collect all definitions (i.e., with bodies) of named sequences that need to be cloned without actually cloning them, and check that there is exactly one definition of each if renaming is not allowed; (2) if needed and allowed, rename private symbols and update their uses in each module individually; (3) actually clone (also let's consider moving them instead).
Limiting the process to only the symbols that are actually needed is an optimization. We can do it, but can also punt. In the general case, we need something like connected components for all named sequences in the main file, as it may also have unused declarations. There's `CallGraph` for that, and LLVM has algorithms to compute connected components, so there is no need to write this manually.
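The three stages described above could be sketched roughly as follows. This is a self-contained toy model, not MLIR code: `Symbol`, `Module`, and `linkInto` are hypothetical names, a symbol is reduced to a name, a visibility flag, a "has body" flag, and the names it uses, and for brevity only private symbols of the library (not of the target) are renamed.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a symbol op such as a named sequence.
struct Symbol {
  std::string name;
  bool isPrivate;
  bool hasBody;                  // false for a declaration
  std::vector<std::string> uses; // symbols referenced from the body
};
using Module = std::vector<Symbol>;

// Links the definitions of `library` into `target`. Returns false if two
// definitions clash and renaming is not allowed.
bool linkInto(Module &target, Module library) {
  // Stage 1: collect names and check for clashing public definitions.
  std::set<std::string> targetNames, targetDefs;
  for (const Symbol &s : target) {
    targetNames.insert(s.name);
    if (s.hasBody)
      targetDefs.insert(s.name);
  }
  for (const Symbol &s : library)
    if (s.hasBody && !s.isPrivate && targetDefs.count(s.name))
      return false;

  // Stage 2: rename clashing private library symbols and update their uses
  // within the library only.
  std::set<std::string> taken = targetNames;
  for (const Symbol &s : library)
    taken.insert(s.name);
  std::map<std::string, std::string> renames;
  for (const Symbol &s : library) {
    if (!s.isPrivate || !targetNames.count(s.name))
      continue;
    std::string fresh = s.name;
    for (int i = 0; taken.count(fresh); ++i)
      fresh = s.name + "_" + std::to_string(i);
    taken.insert(fresh);
    renames[s.name] = fresh;
  }
  for (Symbol &s : library) {
    auto self = renames.find(s.name);
    if (self != renames.end())
      s.name = self->second;
    for (std::string &use : s.uses) {
      auto it = renames.find(use);
      if (it != renames.end())
        use = it->second;
    }
  }

  // Stage 3: clone; a definition replaces a declaration of the same name.
  for (const Symbol &def : library) {
    if (!def.hasBody)
      continue;
    bool replaced = false;
    for (Symbol &s : target)
      if (s.name == def.name && !s.hasBody) {
        s = def;
        replaced = true;
      }
    if (!replaced)
      target.push_back(def);
  }
  return true;
}
```

Doing all checks and renames before any cloning means that a failed link leaves the target untouched, which is the main reason for separating the stages in this sketch.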
Yeah, this is kind of a linker. I am actually wondering if it shouldn't live in a more central place?
But, boy, it's the third time that I realize that this is more involved than I originally thought :P But I think I understand the issue and should be able to fix it.
Eventually, it could become a more generic utility. But I wouldn't go out of my way right now, as I'm not aware of any other use case; keeping it relatively separable by being interface-based (symbol tables, call graph) should be sufficient.
`NamedSequenceOp` has the `FunctionOpInterface`, which inherits from `SymbolOpInterface`, which is where `getVisibility` is declared/defined, so it seems that named sequences already have the visibility attribute.
OK, I think what you propose makes sense. Linking in all symbols instead of just the referenced ones is something that I need for #67120.
One question, though: `SymbolOpInterface` alone does not know the concept of "external" (i.e., the distinction between definition and declaration); only `FunctionOpInterface` does. However, the previous implementation of `defineDeclaredSymbols` tested for `SymbolOpInterface` and `symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()`, which corresponds to the definition of `isExternal` in `FunctionOpInterface`. I think the more general way to deal with that is to allow external functions (i.e., declarations) only for the special case of `FunctionOpInterface` and to disallow clashes of public symbols otherwise. Do you agree?
Sounds good as a first approximation. I suspect that one will have to lift the declaration/definition distinction out of `FunctionOpInterface` into either `SymbolOpInterface` or a new interface, as it makes sense for other kinds of symbols such as memref globals. We don't care about these for our use case, so let's stick with functions.
> Limiting the process to only the symbols that are actually needed is an optimization. We can do it, but can also punt. In the general case, we need something like connected components for all named sequences in the main file as it may also have unused declarations. There's `CallGraph` for that and LLVM has algorithms to compute connected components so no need to write this manually.
I briefly started to think about this. The call graph analysis works on `CallOpInterface` and `CallableOpInterface`, which I think should be enough for `transform.include` and `transform.sequence` at the moment. However, if some op used a symbol defined in the library module that does not implement the `CallableOpInterface`, we wouldn't capture that relationship with the call graph analysis. I don't know of any (transform) op that does that currently, so it may not be an issue. For full generality, wouldn't we have to build a "used by" graph of symbols (or even a mixture of that and the call graph)?
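To make the difference concrete, here is a small toy model (not based on MLIR's `CallGraph` API) of a symbol-use graph whose edges are tagged as calls or other references. `Use`, `UseGraph`, and `reachable` are hypothetical names. Restricting the traversal to call edges mimics what a pure call graph analysis would see.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// A reference from one symbol to another; `isCall` distinguishes call-like
// uses from other uses (for example, taking an address).
struct Use {
  std::string target;
  bool isCall;
};
using UseGraph = std::map<std::string, std::vector<Use>>;

// Returns all symbols reachable from `root`. With `callsOnly` set, only
// call edges are followed.
std::set<std::string> reachable(const UseGraph &graph, const std::string &root,
                                bool callsOnly) {
  std::set<std::string> visited;
  std::vector<std::string> stack(1, root);
  while (!stack.empty()) {
    std::string name = stack.back();
    stack.pop_back();
    if (!visited.insert(name).second)
      continue; // already visited
    auto it = graph.find(name);
    if (it == graph.end())
      continue; // no outgoing uses recorded
    for (const Use &use : it->second)
      if (use.isCall || !callsOnly)
        stack.push_back(use.target);
  }
  return visited;
}
```

If `@bar` calls `@foo` and additionally references a non-callable symbol, the call-only traversal from `@bar` misses the latter, while the traversal over all uses finds it.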
In the general case, yes, we would need to care about things like `addressof` that are not a call. I don't have an immediate plan for having something like that in the transform dialect.
06bd212 to a7d0ea4
a7d0ea4 to ab3e70b
Until now, the interpreter only loaded those symbols from the provided library files that were declared in the main transform module. However, sequences in the library may themselves include other sequences. If such sequences were not *also* declared in the main transform module, the interpreter failed to resolve them. Forward-declaring all of them is undesirable because it defeats the purpose of encapsulating them in library modules. This PR extends the loading mechanism as follows: `defineDeclaredSymbols` not only inserts the definitions of symbols that are forward-declared in the main module, but also scans each inserted definition for further dependencies and processes those in the same way as the forward declarations from the main module.
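The idea of resolving includes transitively can be illustrated with a worklist. The following is a simplified, self-contained sketch and not the code from the diff: `Definition`, `Library`, and `resolveTransitively` are hypothetical names, and a named sequence is modeled only by its name and the names it includes.

```cpp
#include <map>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for a library named sequence: the names of the
// sequences it includes.
struct Definition {
  std::vector<std::string> includes;
};
using Library = std::map<std::string, Definition>;

// Returns the names of all library definitions that end up in the main
// module, in the order in which they are resolved.
std::vector<std::string>
resolveTransitively(const std::vector<std::string> &declared,
                    const Library &library) {
  // Seed the worklist with the forward declarations of the main module.
  std::vector<std::string> worklist(declared.rbegin(), declared.rend());
  std::set<std::string> present;
  std::vector<std::string> resolved;
  while (!worklist.empty()) {
    std::string name = worklist.back();
    worklist.pop_back();
    if (present.count(name))
      continue; // already inserted, also guards against cycles
    auto it = library.find(name);
    if (it == library.end())
      continue; // no definition available; leave unresolved
    present.insert(name);
    resolved.push_back(name);
    // Dependencies are handled like forward declarations.
    for (const std::string &callee : it->second.includes)
      worklist.push_back(callee);
  }
  return resolved;
}
```

With a library in which `bar` includes `foo`, declaring only `bar` in the main module yields both `bar` and `foo`, which mirrors the situation in the added test file.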
ab3e70b to 2b7d6b4
Closing in favor of #67560, which has a better design and is more mature and complete.