Skip to content

[mlir][transform] Fix handling of transitive includes in interpreter. #67241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -311,32 +311,43 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
auto readOnlyName =
StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName);

// Collect symbols missing in the block.
SmallVector<SymbolOpInterface> missingSymbols;
LLVM_DEBUG(DBGS() << "searching block for missing symbols:\n");
for (Operation &op : llvm::make_early_inc_range(block)) {
LLVM_DEBUG(DBGS() << op << "\n");
auto symbol = dyn_cast<SymbolOpInterface>(op);
if (!symbol)
continue;
if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
continue;
LLVM_DEBUG(DBGS() << " -> symbol missing\n");
missingSymbols.push_back(symbol);
}

LLVM_DEBUG(DBGS() << "looking for definition of symbol "
<< symbol.getNameAttr() << ":");
SymbolTable symbolTable(definitions);
Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
// Resolve missing symbols until they are all resolved.
while (!missingSymbols.empty()) {
SymbolOpInterface symbol = missingSymbols.pop_back_val();
LLVM_DEBUG(DBGS() << "looking for definition of symbol @"
<< symbol.getNameAttr().getValue() << ": ");
SymbolTable definitionsSymbolTable(definitions);
Operation *externalSymbol =
definitionsSymbolTable.lookup(symbol.getNameAttr());
if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
externalSymbol->getRegion(0).empty()) {
LLVM_DEBUG(llvm::dbgs() << "not found\n");
continue;
}

auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
auto symbolFunc = dyn_cast<FunctionOpInterface>(symbol.getOperation());
auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
if (!symbolFunc || !externalSymbolFunc) {
LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
continue;
}

LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
LLVM_DEBUG(llvm::dbgs() << "found " << externalSymbol << " from "
<< externalSymbol->getLoc() << "\n");
if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
return symbolFunc.emitError()
<< "external definition has a mismatching signature ("
Expand Down Expand Up @@ -367,10 +378,53 @@ static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
}
}

OpBuilder builder(&op);
builder.setInsertionPoint(&op);
builder.clone(*externalSymbol);
OpBuilder builder(symbol);
builder.setInsertionPoint(symbol);
Operation *newSymbol = builder.clone(*externalSymbol);
builder.setInsertionPoint(newSymbol);
symbol->erase();

LLVM_DEBUG(DBGS() << "scanning definition of @"
<< externalSymbolFunc.getNameAttr().getValue()
<< " for symbol usages\n");
externalSymbolFunc.walk([&](CallOpInterface callOp) {
LLVM_DEBUG(DBGS() << " call op in:\n" << callOp << "\n");
CallInterfaceCallable callable = callOp.getCallableForCallee();
if (!isa<SymbolRefAttr>(callable)) {
LLVM_DEBUG(DBGS() << " not a symbol usage\n");
return WalkResult::advance();
}

StringRef callableSymbolName =
cast<SymbolRefAttr>(callable).getLeafReference();
LLVM_DEBUG(DBGS() << " looking for @" << callableSymbolName
<< " in definitions: ");

Operation *callableOp = definitionsSymbolTable.lookup(callableSymbolName);
if (!isa<SymbolRefAttr>(callable)) {
LLVM_DEBUG(llvm::dbgs() << "not found\n");
return WalkResult::advance();
}
LLVM_DEBUG(llvm::dbgs() << "found " << callableOp << " from "
<< callableOp->getLoc() << "\n");

if (!block.getParent() || !block.getParent()->getParentOp()) {
LLVM_DEBUG(DBGS() << "could not get parent op of provided block");
return WalkResult::advance();
}

SymbolTable targetSymbolTable(block.getParent()->getParentOp());
if (targetSymbolTable.lookup(callableSymbolName)) {
LLVM_DEBUG(DBGS() << " symbol @" << callableSymbolName
<< " already present in target\n");
return WalkResult::advance();
}

LLVM_DEBUG(DBGS() << " cloning op into target\n");
builder.clone(*callableOp);

return WalkResult::advance();
});
}

return success();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// The definition of the @bar named sequence is provided in another file. It
// will be included because of the pass option. That sequence uses another named
// sequence @foo, which should be made available here. Repeated application of
// the same pass, with or without the library option, should not be a problem.
// Note that the same diagnostic produced twice at the same location only
// needs to be matched once.

// expected-remark @below {{message}}
module attributes {transform.with_named_sequence} {
// CHECK-DAG: transform.named_sequence @foo
// CHECK-DAG: transform.named_sequence @bar
transform.named_sequence private @bar(!transform.any_op {transform.readonly})

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @bar failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
// RUN: mlir-opt %s

module attributes {transform.with_named_sequence} {
transform.named_sequence @bar(%arg0: !transform.any_op) {
transform.include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
transform.yield
}

transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) {
transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
transform.yield
Expand Down