diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 70a76ab9670f9..f28205a255070 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -43,6 +43,15 @@ def Transform_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = "transform.readonly"; + /// Names of the attributes indicating whether an argument of an external + /// transform dialect symbol is consumed or only read. + StringAttr getConsumedAttrName() const { + return StringAttr::get(getContext(), kArgConsumedAttrName); + } + StringAttr getReadOnlyAttrName() const { + return StringAttr::get(getContext(), kArgReadOnlyAttrName); + } + template const DataTy &getExtraData() const { return *static_cast( diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h index 6102417ceda1a..a6f0dddebd7ea 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -62,9 +62,11 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// transform script. If empty, `debugTransformRootTag` is considered or the /// pass root operation must contain a single top-level transform op that /// will be interpreted. -/// - transformLibraryFileName: if non-empty, the name of the file containing -/// definitions of external symbols referenced in the transform script. -/// These definitions will be used to replace declarations. +/// - transformLibraryFileName: if non-empty, the module in this file will be +/// merged into the main transform script run by the interpreter before +/// execution. This allows to provide definitions for external functions +/// used in the main script. Other public symbols in the library module may +/// lead to collisions with public symbols in the main script. /// - debugPayloadRootTag: if non-empty, the value of the attribute named /// `kTransformDialectTagAttrName` indicating the single op that is /// considered the payload root of the transform interpreter; otherwise, the @@ -85,7 +87,7 @@ LogicalResult interpreterBaseRunOnOperationImpl( /// as template arguments. They are *not* expected to to implement `initialize` /// or `runOnOperation`. They *are* expected to call the copy constructor of /// this class in their copy constructors, short of which the file-based -/// transform dialect script injection facility will become nonoperational. +/// transform dialect script injection facility will become non-operational. /// /// Concrete passes may implement the `runBeforeInterpreter` and /// `runAfterInterpreter` to customize the behavior of the pass. diff --git a/mlir/include/mlir/IR/SymbolTable.h b/mlir/include/mlir/IR/SymbolTable.h index 33427788a075e..7f21f22eba951 100644 --- a/mlir/include/mlir/IR/SymbolTable.h +++ b/mlir/include/mlir/IR/SymbolTable.h @@ -55,6 +55,23 @@ class SymbolTable { /// after insertion as attribute. StringAttr insert(Operation *symbol, Block::iterator insertPt = {}); + /// Renames the given op or the op refered to by the given name to the given + /// new name and updates the symbol table and all usages of the symbol + /// accordingly. Fails if the updating of the usages fails. + LogicalResult rename(StringAttr from, StringAttr to); + LogicalResult rename(Operation *op, StringAttr to); + LogicalResult rename(StringAttr from, StringRef to); + LogicalResult rename(Operation *op, StringRef to); + + /// Renames the given op or the op refered to by the given name to the a name + /// that is unique within this and the provided other symbol tables and + /// updates the symbol table and all usages of the symbol accordingly. Returns + /// the new name or failure if the renaming fails. + FailureOr renameToUnique(StringAttr from, + ArrayRef others); + FailureOr renameToUnique(Operation *op, + ArrayRef others); + /// Return the name of the attribute used for symbol names. static StringRef getSymbolAttrName() { return "sym_name"; } diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 23640c92457a8..68a735e7ef8e0 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -161,17 +161,9 @@ static llvm::raw_ostream & printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, const Pass::Option &debugPayloadRootTag, const Pass::Option &debugTransformRootTag, - const Pass::Option &transformLibraryFileName, StringRef binaryName) { - std::string transformLibraryOption = ""; - if (!transformLibraryFileName.empty()) { - transformLibraryOption = - llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(), - transformLibraryFileName.getValue()) - .str(); - } os << llvm::formatv( - "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName, + "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName, passName, debugPayloadRootTag.getArgStr(), debugPayloadRootTag.empty() ? StringRef(kTransformDialectTagPayloadRootValue) @@ -180,14 +172,15 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName, debugTransformRootTag.empty() ? StringRef(kTransformDialectTagTransformContainerValue) : debugTransformRootTag, - transformLibraryOption, binaryName); + binaryName); return os; } /// Prints the module rooted at `root` to `os` and appends /// `transformContainer` if it is not nested in `root`. -llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root, - Operation *transform) { +static llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, + Operation *root, + Operation *transform) { root->print(os); if (!root->isAncestor(transform)) transform->print(os); @@ -196,12 +189,13 @@ llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root, /// Saves the payload and the transform IR into a temporary file and reports /// the file name to `os`. -void saveReproToTempFile( - llvm::raw_ostream &os, Operation *target, Operation *transform, - StringRef passName, const Pass::Option &debugPayloadRootTag, - const Pass::Option &debugTransformRootTag, - const Pass::Option &transformLibraryFileName, - StringRef binaryName) { +static void +saveReproToTempFile(llvm::raw_ostream &os, Operation *target, + Operation *transform, StringRef passName, + const Pass::Option &debugPayloadRootTag, + const Pass::Option &debugTransformRootTag, + const Pass::Option &transformLibraryFileName, + StringRef binaryName) { using llvm::sys::fs::TempFile; Operation *root = getRootOperation(target); @@ -226,8 +220,7 @@ void saveReproToTempFile( os << "=== Transform Interpreter Repro ===\n"; printReproCall(os, root->getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, - transformLibraryFileName, binaryName) + debugPayloadRootTag, debugTransformRootTag, binaryName) << " " << filename << "\n"; os << "===================================\n"; } @@ -281,8 +274,7 @@ static void performOptionalDebugActions( llvm::dbgs() << "=== Transform Interpreter Repro ===\n"; printReproCall(llvm::dbgs() << "cat <getName().getStringRef(), passName, - debugPayloadRootTag, debugTransformRootTag, - transformLibraryFileName, binaryName) + debugPayloadRootTag, debugTransformRootTag, binaryName) << "\n"; printModuleForRepro(llvm::dbgs(), root, transform); llvm::dbgs() << "\nEOF\n"; @@ -302,77 +294,236 @@ static void performOptionalDebugActions( transform->removeAttr(kTransformDialectTagAttrName); } -/// Replaces external symbols in `block` with their (non-external) definitions -/// from the given module. -static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) { - MLIRContext &ctx = *definitions->getContext(); - auto consumedName = - StringAttr::get(&ctx, transform::TransformDialect::kArgConsumedAttrName); - auto readOnlyName = - StringAttr::get(&ctx, transform::TransformDialect::kArgReadOnlyAttrName); - - for (Operation &op : llvm::make_early_inc_range(block)) { - LLVM_DEBUG(DBGS() << op << "\n"); - auto symbol = dyn_cast(op); - if (!symbol) - continue; - if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty()) - continue; +/// Return whether `func1` can be merged into `func2`. For that to work `func1` +/// has to be a declaration (aka has to be external) and `func2` either has to +/// be a declaration as well, or it has to be public (otherwise, it wouldn't +/// be visible by `func1`). +static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) { + return func1.isExternal() && (func2.isPublic() || func2.isExternal()); +} - LLVM_DEBUG(DBGS() << "looking for definition of symbol " - << symbol.getNameAttr() << ":"); - SymbolTable symbolTable(definitions); - Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr()); - if (!externalSymbol || externalSymbol->getNumRegions() != 1 || - externalSymbol->getRegion(0).empty()) { - LLVM_DEBUG(llvm::dbgs() << "not found\n"); - continue; - } +/// Merge `func1` into `func2`. The two ops must be inside the same parent op +/// and mergable according to `canMergeInto`. The function erases `func1` such +/// that only `func2` exists when the function returns. +static LogicalResult mergeInto(FunctionOpInterface func1, + FunctionOpInterface func2) { + assert(canMergeInto(func1, func2)); + assert(func1->getParentOp() == func2->getParentOp() && + "expected func1 and func2 to be in the same parent op"); + + // Check that function signatures match. + if (func1.getFunctionType() != func2.getFunctionType()) { + return func1.emitError() + << "external definition has a mismatching signature (" + << func2.getFunctionType() << ")"; + } - auto symbolFunc = dyn_cast(op); - auto externalSymbolFunc = dyn_cast(externalSymbol); - if (!symbolFunc || !externalSymbolFunc) { - LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n"); + // Check and merge argument attributes. + MLIRContext *context = func1->getContext(); + auto *td = context->getLoadedDialect(); + StringAttr consumedName = td->getConsumedAttrName(); + StringAttr readOnlyName = td->getReadOnlyAttrName(); + for (unsigned i = 0, e = func1.getNumArguments(); i < e; ++i) { + bool isExternalConsumed = func2.getArgAttr(i, consumedName) != nullptr; + bool isExternalReadonly = func2.getArgAttr(i, readOnlyName) != nullptr; + bool isConsumed = func1.getArgAttr(i, consumedName) != nullptr; + bool isReadonly = func1.getArgAttr(i, readOnlyName) != nullptr; + if (!isExternalConsumed && !isExternalReadonly) { + if (isConsumed) + func2.setArgAttr(i, consumedName, UnitAttr::get(context)); + else if (isReadonly) + func2.setArgAttr(i, readOnlyName, UnitAttr::get(context)); continue; } - LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n"); - if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) { - return symbolFunc.emitError() - << "external definition has a mismatching signature (" - << externalSymbolFunc.getFunctionType() << ")"; + if ((isExternalConsumed && !isConsumed) || + (isExternalReadonly && !isReadonly)) { + return func1.emitError() + << "external definition has mismatching consumption " + "annotations for argument #" + << i; } + } + + // `func1` is the external one, so we can remove it. + assert(func1.isExternal()); + func1->erase(); - for (unsigned i = 0, e = symbolFunc.getNumArguments(); i < e; ++i) { - bool isExternalConsumed = - externalSymbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isExternalReadonly = - externalSymbolFunc.getArgAttr(i, readOnlyName) != nullptr; - bool isConsumed = symbolFunc.getArgAttr(i, consumedName) != nullptr; - bool isReadonly = symbolFunc.getArgAttr(i, readOnlyName) != nullptr; - if (!isExternalConsumed && !isExternalReadonly) { - if (isConsumed) - externalSymbolFunc.setArgAttr(i, consumedName, UnitAttr::get(&ctx)); - else if (isReadonly) - externalSymbolFunc.setArgAttr(i, readOnlyName, UnitAttr::get(&ctx)); + return success(); +} + +/// Merge all symbols from `other` into `target`. Both ops need to implement the +/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be +/// modified by this function and might not verify after the function returns. +/// Upon merging, private symbols may be renamed in order to avoid collisions in +/// the result. Public symbols may not collide, with the exception of +/// instances of `SymbolOpInterface`, where collisions are allowed if at least +/// one of the two is external, in which case the other op preserved (or any one +/// of the two if both are external). +// TODO: Reconsider cloning individual ops rather than forcing users of the +// function to clone (or move) `other` in order to improve efficiency. +// This might primarily make sense if we can also prune the symbols that +// are merged to a subset (such as those that are actually used). +static LogicalResult mergeSymbolsInto(Operation *target, + OwningOpRef other) { + assert(target->hasTrait() && + "requires target to implement the 'SymbolTable' trait"); + assert(other->hasTrait() && + "requires target to implement the 'SymbolTable' trait"); + + SymbolTable targetSymbolTable(target); + SymbolTable otherSymbolTable(*other); + + // Step 1: + // + // Rename private symbols in both ops in order to resolve conflicts that can + // be resolved that way. + LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); + // TODO: Do we *actually* need to test in both directions? + for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( + SmallVector{&targetSymbolTable, &otherSymbolTable}, + SmallVector{&otherSymbolTable, + &targetSymbolTable})) { + Operation *symbolTableOp = symbolTable->getOp(); + for (Operation &op : symbolTableOp->getRegion(0).front()) { + auto symbolOp = dyn_cast(op); + if (!symbolOp) continue; + StringAttr name = symbolOp.getNameAttr(); + LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); + + // Check if there is a colliding op in the other module. + auto collidingOp = + cast_or_null(otherSymbolTable->lookup(name)); + if (!collidingOp) + continue; + + LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); + + // Collisions are fine if both opt are functions and can be merged. + if (auto funcOp = dyn_cast(op), + collidingFuncOp = + dyn_cast(collidingOp.getOperation()); + funcOp && collidingFuncOp) { + if (canMergeInto(funcOp, collidingFuncOp) || + canMergeInto(collidingFuncOp, funcOp)) { + LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " + "will be merged\n"); + continue; + } + + // If they can't be merged, proceed like any other collision. + LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); } - if ((isExternalConsumed && !isConsumed) || - (isExternalReadonly && !isReadonly)) { - return symbolFunc.emitError() - << "external definition has mismatching consumption annotations " - "for argument #" - << i; + // Collision can be resolved by renaming if one of the ops is private. + auto renameToUnique = + [&](SymbolOpInterface op, SymbolOpInterface otherOp, + SymbolTable &symbolTable, + SymbolTable &otherSymbolTable) -> LogicalResult { + LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); + FailureOr maybeNewName = + symbolTable.renameToUnique(op, {&otherSymbolTable}); + if (failed(maybeNewName)) { + InFlightDiagnostic diag = op->emitError("failed to rename symbol"); + diag.attachNote(otherOp->getLoc()) + << "attempted renaming due to collision with this op"; + return diag; + } + LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() + << "\n"); + return success(); + }; + + if (symbolOp.isPrivate()) { + if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable, + *otherSymbolTable))) + return failure(); + continue; } + if (collidingOp.isPrivate()) { + if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable, + *symbolTable))) + return failure(); + continue; + } + + LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); + InFlightDiagnostic diag = symbolOp.emitError() + << "doubly defined symbol @" << name.getValue(); + diag.attachNote(collidingOp->getLoc()) << "previously defined here"; + return diag; } + } + + // TODO: This duplicates pass infrastructure. We should split this pass into + // several and let the pass infrastructure do the verification. + for (auto *op : SmallVector{target, *other}) { + if (failed(mlir::verify(op))) + return op->emitError() << "failed to verify input op after renaming"; + } + + // Step 2: + // + // Move all ops from `other` into target and merge public symbols. + LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); + { + SmallVector opsToMove; + for (Operation &op : other->getRegion(0).front()) { + if (auto symbol = dyn_cast(op)) + opsToMove.push_back(symbol); + } + + for (SymbolOpInterface op : opsToMove) { + // Remember potentially colliding op in the target module. + auto collidingOp = cast_or_null( + targetSymbolTable.lookup(op.getNameAttr())); + + // Move op even if we get a collision. + LLVM_DEBUG(DBGS() << " moving @" << op.getName()); + op->moveBefore(&target->getRegion(0).front(), + target->getRegion(0).front().end()); + + // If there is no collision, we are done. + if (!collidingOp) { + LLVM_DEBUG(llvm::dbgs() << " without collision\n"); + continue; + } + + // The two colliding ops must both be functions because we have already + // emitted errors otherwise earlier. + auto funcOp = cast(op.getOperation()); + auto collidingFuncOp = + cast(collidingOp.getOperation()); + + // Both ops are in the target module now and can be treated symmetrically, + // so w.l.o.g. we can reduce to merging `funcOp` into `collidingFuncOp`. + if (!canMergeInto(funcOp, collidingFuncOp)) { + std::swap(funcOp, collidingFuncOp); + } + assert(canMergeInto(funcOp, collidingFuncOp)); + + LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " + << collidingFuncOp.getLoc() << ":\n" + << collidingFuncOp << "\n"); - OpBuilder builder(&op); - builder.setInsertionPoint(&op); - builder.clone(*externalSymbol); - symbol->erase(); + // Update symbol table. This works with or without the previous `swap`. + targetSymbolTable.remove(funcOp); + targetSymbolTable.insert(collidingFuncOp); + assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp); + + // Do the actual merging. + if (failed(mergeInto(funcOp, collidingFuncOp))) { + return failure(); + } + } } + if (failed(mlir::verify(target))) + return target->emitError() + << "failed to verify target op after merging symbols"; + + LLVM_DEBUG(DBGS() << "done merging ops\n"); return success(); } @@ -443,8 +594,9 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl( diag.attachNote(target->getLoc()) << "pass anchor op"; return diag; } - if (failed(defineDeclaredSymbols(*transformRoot->getBlock(), - transformLibraryModule->get()))) + if (failed( + mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot), + transformLibraryModule->get()->clone()))) return failure(); } @@ -506,8 +658,8 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( return success(); if (sharedTransformModule && *sharedTransformModule) { - if (failed(defineDeclaredSymbols(*sharedTransformModule->get().getBody(), - parsedLibraryModule.get()))) + if (failed(mergeSymbolsInto(sharedTransformModule->get(), + std::move(parsedLibraryModule)))) return failure(); } else { transformLibraryModule = diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 2494cb7086f0d..8ff1859e1383f 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -218,6 +218,79 @@ StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) { return getSymbolName(symbol); } +LogicalResult SymbolTable::rename(StringAttr from, StringAttr to) { + Operation *op = lookup(from); + return rename(op, to); +} + +LogicalResult SymbolTable::rename(Operation *op, StringAttr to) { + StringAttr from = getNameIfSymbol(op); + + assert(from && "expected valid 'name' attribute"); + assert(op->getParentOp() == symbolTableOp && + "expected this operation to be inside of the operation with this " + "SymbolTable"); + assert(lookup(from) == op && "current name does not resolve to op"); + assert(lookup(to) == nullptr && "new name already exists"); + + if (failed(SymbolTable::replaceAllSymbolUses(op, to, getOp()))) + return failure(); + + // Remove op with old name, change name, add with new name. The order is + // important here due to how `remove` and `insert` rely on the op name. + remove(op); + setSymbolName(op, to); + insert(op); + + assert(lookup(to) == op && "new name does not resolve to renamed op"); + assert(lookup(from) == nullptr && "old name still exists"); + + return success(); +} + +LogicalResult SymbolTable::rename(StringAttr from, StringRef to) { + auto toAttr = StringAttr::get(getOp()->getContext(), to); + return rename(from, toAttr); +} + +LogicalResult SymbolTable::rename(Operation *op, StringRef to) { + auto toAttr = StringAttr::get(getOp()->getContext(), to); + return rename(op, toAttr); +} + +FailureOr +SymbolTable::renameToUnique(StringAttr oldName, + ArrayRef others) { + + // Determine new name that is unique in all symbol tables. + StringAttr newName; + { + MLIRContext *context = oldName.getContext(); + SmallString<64> prefix = oldName.getValue(); + int uniqueId = 0; + prefix.push_back('_'); + while (true) { + newName = StringAttr::get(context, prefix + Twine(uniqueId++)); + auto lookupNewName = [&](SymbolTable *st) { return st->lookup(newName); }; + if (!lookupNewName(this) && llvm::none_of(others, lookupNewName)) { + break; + } + } + } + + // Apply renaming. + if (failed(rename(oldName, newName))) + return failure(); + return newName; +} + +FailureOr +SymbolTable::renameToUnique(Operation *op, ArrayRef others) { + StringAttr from = getNameIfSymbol(op); + assert(from && "expected valid 'name' attribute"); + return renameToUnique(from, others); +} + /// Returns the name of the given symbol operation. StringAttr SymbolTable::getSymbolName(Operation *symbol) { StringAttr name = getNameIfSymbol(symbol); diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir index 3d4cb07769829..dd8d141e994da 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir @@ -11,4 +11,10 @@ // expected-remark @below {{message}} // expected-remark @below {{unannotated}} +// expected-remark @below {{internal colliding (without suffix)}} +// expected-remark @below {{internal colliding_0}} +// expected-remark @below {{internal colliding_1}} +// expected-remark @below {{internal colliding_3}} +// expected-remark @below {{internal colliding_4}} +// expected-remark @below {{internal colliding_5}} module {} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir index b21abbbdfd6d0..7452deb39b6c1 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir @@ -1,16 +1,16 @@ -// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ +// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def-invalid.mlir}, test-transform-dialect-interpreter)" \ // RUN: --verify-diagnostics --split-input-file -// The definition of the @foo named sequence is provided in another file. It +// The definition of the @print_message named sequence is provided in another file. It // will be included because of the pass option. module attributes {transform.with_named_sequence} { // expected-error @below {{external definition has a mismatching signature}} - transform.named_sequence private @foo(!transform.op<"builtin.module"> {transform.readonly}) + transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly}) transform.sequence failures(propagate) { ^bb0(%arg0: !transform.op<"builtin.module">): - include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> () + include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> () } } @@ -37,3 +37,18 @@ module attributes {transform.with_named_sequence} { include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> () } } + +// ----- + +module attributes {transform.with_named_sequence} { + // expected-error @below {{doubly defined symbol @print_message}} + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op + transform.yield + } + + transform.sequence failures(suppress) { + ^bb0(%arg0: !transform.any_op): + include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> () + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir index 04b6c5a02e0ad..7d0837abebde3 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir @@ -4,29 +4,68 @@ // RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \ // RUN: --verify-diagnostics --split-input-file | FileCheck %s -// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \ -// RUN: --verify-diagnostics --split-input-file | FileCheck %s - -// The definition of the @foo named sequence is provided in another file. It -// will be included because of the pass option. Repeated application of the -// same pass, with or without the library option, should not be a problem. +// The definition of the @print_message named sequence is provided in another +// file. It will be included because of the pass option. Subsequent application +// of the same pass works but only without the library file (since the first +// application loads external symbols and loading them again woul make them +// clash). // Note that the same diagnostic produced twice at the same location only // needs to be matched once. // expected-remark @below {{message}} // expected-remark @below {{unannotated}} +// expected-remark @below {{internal colliding (without suffix)}} +// expected-remark @below {{internal colliding_0}} +// expected-remark @below {{internal colliding_1}} +// expected-remark @below {{internal colliding_3}} +// expected-remark @below {{internal colliding_4}} +// expected-remark @below {{internal colliding_5}} module attributes {transform.with_named_sequence} { - // CHECK: transform.named_sequence @foo - // CHECK: test_print_remark_at_operand %{{.*}}, "message" - transform.named_sequence private @foo(!transform.any_op {transform.readonly}) + // CHECK-DAG: transform.named_sequence @print_message( + // CHECK-DAG: transform.include @private_helper + transform.named_sequence private @print_message(!transform.any_op {transform.readonly}) + + // These ops collide with ops from the other module before or after renaming. + transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding (without suffix)" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_0" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_1(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_1" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_3" : !transform.any_op + transform.yield + } + // This symbol is public and thus can't be renamed. + // CHECK-DAG: transform.named_sequence @colliding_4( + transform.named_sequence @colliding_4(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_4" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_5(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "internal colliding_5" : !transform.any_op + transform.yield + } - // CHECK: transform.named_sequence @unannotated - // CHECK: test_print_remark_at_operand %{{.*}}, "unannotated" - transform.named_sequence private @unannotated(!transform.any_op {transform.readonly}) + // CHECK-DAG: transform.named_sequence @unannotated( + // CHECK-DAG: test_print_remark_at_operand %{{.*}}, "unannotated" + transform.named_sequence @unannotated(!transform.any_op {transform.readonly}) transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - include @foo failures(propagate) (%arg0) : (!transform.any_op) -> () + include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> () include @unannotated failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_1 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> () + include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> () } } diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir new file mode 100644 index 0000000000000..1d9ef1dbead63 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def-invalid.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt %s + +module attributes {transform.with_named_sequence} { + // expected-note @below {{previously defined here}} + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op + transform.yield + } + + transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) { + transform.test_consume_operand %arg0 : !transform.any_op + transform.yield + } +} \ No newline at end of file diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir index 1149bda98ab85..66f0f1f62683b 100644 --- a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir @@ -1,11 +1,42 @@ // RUN: mlir-opt %s module attributes {transform.with_named_sequence} { - transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) { + transform.named_sequence private @private_helper(%arg0: !transform.any_op {transform.readonly}) { transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op transform.yield } + // These ops collide with ops from the other module before or after renaming. + transform.named_sequence private @colliding(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding (without suffix)" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_0(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_0" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_2(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_2" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_3(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_3" : !transform.any_op + transform.yield + } + transform.named_sequence private @colliding_4(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_4" : !transform.any_op + transform.yield + } + transform.named_sequence @colliding_5(%arg0: !transform.any_op {transform.readonly}) { + transform.test_print_remark_at_operand %arg0, "external colliding_5" : !transform.any_op + transform.yield + } + + transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly}) { + transform.include @private_helper failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } + transform.named_sequence @consuming(%arg0: !transform.any_op {transform.consumed}) { transform.test_consume_operand %arg0 : !transform.any_op transform.yield @@ -15,4 +46,14 @@ module attributes {transform.with_named_sequence} { transform.test_print_remark_at_operand %arg0, "unannotated" : !transform.any_op transform.yield } + + transform.named_sequence @symbol_user(%arg0: !transform.any_op {transform.readonly}) { + transform.include @colliding failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_0 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_2 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_3 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_4 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.include @colliding_5 failures(propagate) (%arg0) : (!transform.any_op) -> () + transform.yield + } } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index f73deef9d5fd4..578d9abe4a56e 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -218,9 +218,9 @@ class TestTransformDialectInterpreterPass "select the container of the top-level transform op.")}; Option transformLibraryFileName{ *this, "transform-library-file-name", llvm::cl::init(""), - llvm::cl::desc( - "Optional name of the file containing transform dialect symbol " - "definitions to be injected into the transform module.")}; + llvm::cl::desc("Optional name of a file with a module that should be " + "merged into the transform module to provide the " + "definitions of external named sequences.")}; Option testModuleGeneration{ *this, "test-module-generation", llvm::cl::init(false),