Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transform] Provide a minimal set of utils that allow implementing a simple transform dialect interpreter pass #68330

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ def Transform_Dialect : Dialect {

let hasOperationAttrVerify = 1;
let extraClassDeclaration = [{
/// Symbol name for the default entry point "named sequence".
constexpr const static ::llvm::StringLiteral
kTransformEntryPointSymbolName = "__transform_main";

/// Name of the attribute attachable to the symbol table operation
/// containing named sequences. This is used to trigger verification.
constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName =
"transform.with_named_sequence";
constexpr const static ::llvm::StringLiteral
kWithNamedSequenceAttrName = "transform.with_named_sequence";

/// Name of the attribute attachable to an operation so it can be
/// identified as root by the default interpreter pass.
Expand Down Expand Up @@ -74,6 +78,22 @@ def Transform_Dialect : Dialect {
using ExtensionTypePrintingHook =
std::function<void(::mlir::Type, ::mlir::AsmPrinter &)>;

/// Appends the given module as a transform symbol library available to
/// all dialect users.
void registerLibraryModule(::mlir::OwningOpRef<::mlir::ModuleOp> &&
library) {
libraryModules.push_back(std::move(library));
}

/// Returns a range of registered library modules.
auto getLibraryModules() const {
return ::llvm::map_range(
libraryModules,
[](const ::mlir::OwningOpRef<::mlir::ModuleOp> &library) {
return library.get();
});
}

private:
/// Registers operations specified as template parameters with this
/// dialect. Checks that they implement the required interfaces.
Expand Down Expand Up @@ -132,6 +152,11 @@ def Transform_Dialect : Dialect {
/// lookups when the type is fully constructed.
::llvm::DenseMap<::mlir::TypeID, ExtensionTypePrintingHook>
typePrintingHooks;

/// Modules containing symbols, e.g. named sequences, that will be
/// resolved by the interpreter when used.
::llvm::SmallVector<::mlir::OwningOpRef<::mlir::ModuleOp>, 2>
libraryModules;
}];
}

Expand Down
5 changes: 3 additions & 2 deletions mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,8 @@ class TransformOptions {
LogicalResult
applyTransforms(Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping = {},
const TransformOptions &options = TransformOptions());
const TransformOptions &options = TransformOptions(),
bool enforceToplevelTransformOp = true);

/// The state maintained across applications of various ops implementing the
/// TransformOpInterface. The operations implementing this interface and the
Expand Down Expand Up @@ -193,7 +194,7 @@ class TransformState {

friend LogicalResult applyTransforms(Operation *, TransformOpInterface,
const RaggedArray<MappedValue> &,
const TransformOptions &);
const TransformOptions &, bool);

friend TransformState
detail::makeTransformStateForTesting(Region *region, Operation *payloadRoot);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//===- TransformInterpreterUtils.h - Transform Utils ------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
#define MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include <memory>

namespace mlir {
struct LogicalResult;
class MLIRContext;
class ModuleOp;
class Operation;
template <typename>
class OwningOpRef;
class Region;

namespace transform {
namespace detail {
/// Utility to parse and verify the content of a `transformFileName` MLIR file
/// containing a transform dialect specification.
LogicalResult
parseTransformModuleFromFile(MLIRContext *context,
llvm::StringRef transformFileName,
OwningOpRef<ModuleOp> &transformModule);

/// Utility to load a transform interpreter `module` from a module that has
/// already been preloaded in the context.
/// This mode is useful in cases where explicit parsing of a transform library
/// from file is expected to be prohibitively expensive.
/// In such cases, the transform module is expected to be found in the preloaded
/// library modules of the transform dialect.
/// Returns null if the module is not found.
ModuleOp getPreloadedTransformModule(MLIRContext *context);

/// Finds the first TransformOpInterface named `kTransformEntryPointSymbolName`
/// that is either:
/// 1. nested under `root` (takes precedence).
/// 2. nested under `module`, if not found in `root`.
/// Reports errors and returns null if no such operation found.
TransformOpInterface findTransformEntryPoint(
Operation *root, ModuleOp module,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

/// Merge all symbols from `other` into `target`. Both ops need to implement the
/// `SymbolTable` trait. Operations are moved from `other`, i.e., `other` may be
/// modified by this function and might not verify after the function returns.
/// Upon merging, private symbols may be renamed in order to avoid collisions in
/// the result. Public symbols may not collide, with the exception of
/// instances of `SymbolOpInterface`, where collisions are allowed if at least
/// one of the two is external, in which case the other op preserved (or any one
/// of the two if both are external).
// TODO: Reconsider cloning individual ops rather than forcing users of the
// function to clone (or move) `other` in order to improve efficiency.
// This might primarily make sense if we can also prune the symbols that
// are merged to a subset (such as those that are actually used).
LogicalResult mergeSymbolsInto(Operation *target,
OwningOpRef<Operation *> other);
} // namespace detail

/// Standalone util to apply the named sequence `entryPoint` to the payload.
/// This is done in 3 steps:
/// 1. lookup the `entryPoint` symbol in `{payload, sharedTransformModule}` by
/// calling detail::findTransformEntryPoint.
/// 2. if the entry point is found and not nested under
/// `sharedTransformModule`, call `detail::defineDeclaredSymbols` to "link" in
/// the `sharedTransformModule`. Note: this may modify the transform IR
/// embedded with the payload IR.
/// 3. apply the transform IR to the payload IR, relaxing the requirement that
/// the transform IR is a top-level transform op. We are applying a named
/// sequence anyway.
LogicalResult applyTransformNamedSequence(
Operation *payload, ModuleOp transformModule,
const TransformOptions &options,
StringRef entryPoint = TransformDialect::kTransformEntryPointSymbolName);

} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_TRANSFORMS_TRANSFORMINTERPRETERUTILS_H
26 changes: 13 additions & 13 deletions mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2079,20 +2079,20 @@ LogicalResult transform::detail::verifyTransformOpInterface(Operation *op) {
// Entry point.
//===----------------------------------------------------------------------===//

LogicalResult
transform::applyTransforms(Operation *payloadRoot,
TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options) {
#ifndef NDEBUG
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
transform->emitError()
<< "expected transform to start at the top-level transform op";
llvm::report_fatal_error("could not run transforms",
/*gen_crash_diag=*/false);
LogicalResult transform::applyTransforms(
Operation *payloadRoot, TransformOpInterface transform,
const RaggedArray<MappedValue> &extraMapping,
const TransformOptions &options, bool enforceToplevelTransformOp) {
if (enforceToplevelTransformOp) {
nicolasvasilache marked this conversation as resolved.
Show resolved Hide resolved
if (!transform->hasTrait<PossibleTopLevelTransformOpTrait>() ||
transform->getNumOperands() != 0) {
return transform->emitError()
<< "expected transform to start at the top-level transform op";
}
} else if (failed(
detail::verifyPossibleTopLevelTransformOpTrait(transform))) {
return failure();
}
#endif // NDEBUG

TransformState state(transform->getParentRegion(), payloadRoot, extraMapping,
options);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_dialect_library(MLIRTransformDialectTransforms
CheckUses.cpp
InferEffects.cpp
TransformInterpreterPassBase.cpp
TransformInterpreterUtils.cpp

DEPENDS
MLIRTransformDialectTransformsIncGen
Expand Down
Loading