Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][transform] Allow passing various library files to interpreter. #67120

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace detail {
/// Template-free implementation of TransformInterpreterPassBase::initialize.
LogicalResult interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
Expand All @@ -48,7 +48,7 @@ LogicalResult interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName);
Expand All @@ -62,11 +62,12 @@ LogicalResult interpreterBaseRunOnOperationImpl(
/// transform script. If empty, `debugTransformRootTag` is considered or the
/// pass root operation must contain a single top-level transform op that
/// will be interpreted.
/// - transformLibraryFileName: if non-empty, the module in this file will be
/// - transformLibraryPaths: if non-empty, the modules in these files will be
/// merged into the main transform script run by the interpreter before
/// execution. This allows to provide definitions for external functions
/// used in the main script. Other public symbols in the library module may
/// lead to collisions with public symbols in the main script.
/// used in the main script. Other public symbols in the library modules may
/// lead to collisions with public symbols in the main script and among each
/// other.
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
/// `kTransformDialectTagAttrName` indicating the single op that is
/// considered the payload root of the transform interpreter; otherwise, the
Expand Down Expand Up @@ -118,16 +119,26 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
REQUIRE_PASS_OPTION(transformFileName);
REQUIRE_PASS_OPTION(debugPayloadRootTag);
REQUIRE_PASS_OPTION(debugTransformRootTag);
REQUIRE_PASS_OPTION(transformLibraryFileName);

#undef REQUIRE_PASS_OPTION

#define REQUIRE_PASS_LIST_OPTION(NAME) \
static_assert( \
std::is_same_v< \
std::remove_reference_t<decltype(std::declval<Concrete &>().NAME)>, \
Pass::ListOption<std::string>>, \
"required " #NAME " string pass option is missing")

REQUIRE_PASS_LIST_OPTION(transformLibraryPaths);

#undef REQUIRE_PASS_LIST_OPTION

StringRef transformFileName =
static_cast<Concrete *>(this)->transformFileName;
StringRef transformLibraryFileName =
static_cast<Concrete *>(this)->transformLibraryFileName;
ArrayRef<std::string> transformLibraryPaths =
static_cast<Concrete *>(this)->transformLibraryPaths;
return detail::interpreterBaseInitializeImpl(
context, transformFileName, transformLibraryFileName,
context, transformFileName, transformLibraryPaths,
sharedTransformModule, transformLibraryModule,
[this](OpBuilder &builder, Location loc) {
return static_cast<Concrete *>(this)->constructTransformModule(
Expand Down Expand Up @@ -162,7 +173,7 @@ class TransformInterpreterPassBase : public GeneratedBase<Concrete> {
op, pass->getArgument(), sharedTransformModule,
transformLibraryModule,
/*extraMappings=*/{}, options, pass->transformFileName,
pass->transformLibraryFileName, pass->debugPayloadRootTag,
pass->transformLibraryPaths, pass->debugPayloadRootTag,
pass->debugTransformRootTag, binaryName)) ||
failed(pass->runAfterInterpreter(op))) {
return pass->signalPassFailure();
Expand Down
157 changes: 126 additions & 31 deletions mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
Expand Down Expand Up @@ -194,7 +195,7 @@ saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
using llvm::sys::fs::TempFile;
Operation *root = getRootOperation(target);
Expand Down Expand Up @@ -231,7 +232,7 @@ static void performOptionalDebugActions(
Operation *target, Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
StringRef binaryName) {
MLIRContext *context = target->getContext();

Expand Down Expand Up @@ -284,7 +285,7 @@ static void performOptionalDebugActions(
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
debugPayloadRootTag, debugTransformRootTag,
transformLibraryFileName, binaryName);
transformLibraryPaths, binaryName);
});

// Remove temporary attributes if they were set.
Expand Down Expand Up @@ -534,7 +535,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
const Pass::Option<std::string> &transformLibraryFileName,
const Pass::ListOption<std::string> &transformLibraryPaths,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName) {
Expand Down Expand Up @@ -597,7 +598,8 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
if (failed(
mergeSymbolsInto(SymbolTable::getNearestSymbolTable(transformRoot),
transformLibraryModule->get()->clone())))
return failure();
return emitError(transformRoot->getLoc(),
"failed to merge library symbols into transform root");
}

// Step 4
Expand All @@ -606,7 +608,7 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
// repro to stderr and/or a file.
performOptionalDebugActions(target, transformRoot, passName,
debugPayloadRootTag, debugTransformRootTag,
transformLibraryFileName, binaryName);
transformLibraryPaths, binaryName);

// Step 5
// ------
Expand All @@ -615,55 +617,148 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
extraMappings, options);
}

/// Expands the given list of `paths` to a list of `.mlir` files.
///
/// Each entry in `paths` may either be a regular file, in which case it ends up
/// in the result list, or a directory, in which case all (regular) `.mlir`
/// files in that directory are added. Any other file types lead to a failure.
static LogicalResult
expandPathsToMLIRFiles(ArrayRef<std::string> &paths, MLIRContext *const context,
SmallVectorImpl<std::string> &fileNames) {
for (const std::string &path : paths) {
auto loc = FileLineColLoc::get(context, path, 0, 0);

if (llvm::sys::fs::is_regular_file(path)) {
LLVM_DEBUG(DBGS() << "Adding '" << path << "' to list of files\n");
fileNames.push_back(path);
continue;
}

if (!llvm::sys::fs::is_directory(path)) {
return emitError(loc)
<< "'" << path << "' is neither a file nor a directory";
}

LLVM_DEBUG(DBGS() << "Looking for files in '" << path << "':\n");

std::error_code ec;
for (llvm::sys::fs::directory_iterator it(path, ec), itEnd;
it != itEnd && !ec; it.increment(ec)) {
const std::string &fileName = it->path();

if (it->type() != llvm::sys::fs::file_type::regular_file) {
LLVM_DEBUG(DBGS() << " Skipping non-regular file '" << fileName
<< "'\n");
continue;
}

if (!StringRef(fileName).endswith(".mlir")) {
LLVM_DEBUG(DBGS() << " Skipping '" << fileName
<< "' because it does not end with '.mlir'\n");
continue;
}

LLVM_DEBUG(DBGS() << " Adding '" << fileName << "' to list of files\n");
fileNames.push_back(fileName);
}

if (ec)
return emitError(loc) << "error while opening files in '" << path
<< "': " << ec.message();
}

return success();
}

LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
ArrayRef<std::string> transformLibraryPaths,
std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
std::shared_ptr<OwningOpRef<ModuleOp>> &transformLibraryModule,
function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
moduleBuilder) {
OwningOpRef<ModuleOp> parsedTransformModule;
if (failed(parseTransformModuleFromFile(context, transformFileName,
parsedTransformModule)))
return failure();
if (parsedTransformModule && failed(mlir::verify(*parsedTransformModule)))
return failure();
auto unknownLoc = UnknownLoc::get(context);

OwningOpRef<ModuleOp> parsedLibraryModule;
if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
parsedLibraryModule)))
return failure();
if (parsedLibraryModule && failed(mlir::verify(*parsedLibraryModule)))
// Parse module from file.
OwningOpRef<ModuleOp> moduleFromFile;
{
auto loc = FileLineColLoc::get(context, transformFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, transformFileName,
moduleFromFile)))
return emitError(loc) << "failed to parse transform module";
if (moduleFromFile && failed(mlir::verify(*moduleFromFile)))
return emitError(loc) << "failed to verify transform module";
}

// Assemble list of library files.
SmallVector<std::string> libraryFileNames;
if (failed(expandPathsToMLIRFiles(transformLibraryPaths, context,
libraryFileNames)))
return failure();

if (parsedTransformModule) {
sharedTransformModule = std::make_shared<OwningOpRef<ModuleOp>>(
std::move(parsedTransformModule));
// Parse modules from library files.
SmallVector<OwningOpRef<ModuleOp>> parsedLibraries;
for (const std::string &libraryFileName : libraryFileNames) {
OwningOpRef<ModuleOp> parsedLibrary;
auto loc = FileLineColLoc::get(context, libraryFileName, 0, 0);
if (failed(parseTransformModuleFromFile(context, libraryFileName,
parsedLibrary)))
return emitError(loc) << "failed to parse transform library module";
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
return emitError(loc) << "failed to verify transform library module";
parsedLibraries.push_back(std::move(parsedLibrary));
}

// Build shared transform module.
if (moduleFromFile) {
sharedTransformModule =
std::make_shared<OwningOpRef<ModuleOp>>(std::move(moduleFromFile));
} else if (moduleBuilder) {
// TODO: better location story.
auto location = UnknownLoc::get(context);
auto loc = FileLineColLoc::get(context, "<shared-transform-module>", 0, 0);
auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
ModuleOp::create(location, "__transform"));
ModuleOp::create(unknownLoc, "__transform"));

OpBuilder b(context);
b.setInsertionPointToEnd(localModule->get().getBody());
if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
if (std::optional<LogicalResult> result = moduleBuilder(b, loc)) {
if (failed(*result))
return failure();
return (*localModule)->emitError()
<< "failed to create shared transform module";
sharedTransformModule = std::move(localModule);
}
}

if (!parsedLibraryModule || !*parsedLibraryModule)
if (parsedLibraries.empty())
return success();

// Merge parsed libraries into one module.
auto loc = FileLineColLoc::get(context, "<shared-library-module>", 0, 0);
OwningOpRef<ModuleOp> mergedParsedLibraries =
ModuleOp::create(loc, "__transform");
{
mergedParsedLibraries.get()->setAttr("transform.with_named_sequence",
UnitAttr::get(context));
IRRewriter rewriter(context);
// TODO: extend `mergeSymbolsInto` to support multiple `other` modules.
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
if (failed(mergeSymbolsInto(mergedParsedLibraries.get(),
std::move(parsedLibrary))))
return mergedParsedLibraries->emitError()
<< "failed to verify merged transform module";
}
}

// Use parsed libaries to resolve symbols in shared transform module or return
// as separate library module.
if (sharedTransformModule && *sharedTransformModule) {
if (failed(mergeSymbolsInto(sharedTransformModule->get(),
std::move(parsedLibraryModule))))
return failure();
std::move(mergedParsedLibraries))))
return (*sharedTransformModule)->emitError()
<< "failed to merge symbols from library files "
"into shared transform module";
} else {
transformLibraryModule =
std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibraryModule));
transformLibraryModule = std::make_shared<OwningOpRef<ModuleOp>>(
std::move(mergedParsedLibraries));
}
return success();
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/LLVM/lower-to-llvm-e2e.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

// RUN: mlir-opt %s -test-lower-to-llvm -cse | FileCheck %s

// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-file-name=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
// RUN: mlir-opt %s -test-transform-dialect-interpreter="transform-library-paths=%p/lower-to-llvm-transform-symbol-def.mlir debug-payload-root-tag=payload" \
// RUN: -test-transform-dialect-erase-schedule -cse \
// RUN: | FileCheck %s

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics

// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-paths=%p/test-interpreter-library/definitions-self-contained.mlir})" \
// RUN: --verify-diagnostics

// The external transform script has a declaration to the named sequence @foo,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library/definitions-self-contained.mlir,%p%{fs-sep}test-interpreter-library/definitions-with-unresolved.mlir})" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-paths=%p%{fs-sep}test-interpreter-library}, test-transform-dialect-interpreter)" \
// RUN: --verify-diagnostics --split-input-file | FileCheck %s

// The definition of the @foo named sequence is provided in another file. It
// will be included because of the pass option. Repeated application of the
// same pass, with or without the library option, should not be a problem.
// Note that the same diagnostic produced twice at the same location only
// needs to be matched once.

// expected-remark @below {{message}}
module attributes {transform.with_named_sequence} {
// CHECK: transform.named_sequence @print_message
transform.named_sequence @print_message(%arg0: !transform.any_op {transform.readonly})

transform.named_sequence @reference_other_module(!transform.any_op {transform.readonly})

transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()
include @reference_other_module failures(propagate) (%arg0) : (!transform.any_op) -> ()
}
}
Loading