Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][Transforms] Dialect conversion: Align handling of dropped values #106760

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 30 additions & 129 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -624,10 +624,9 @@ class ModifyOperationRewrite : public OperationRewrite {
class ReplaceOperationRewrite : public OperationRewrite {
public:
ReplaceOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
Operation *op, const TypeConverter *converter,
bool changedResults)
Operation *op, const TypeConverter *converter)
: OperationRewrite(Kind::ReplaceOperation, rewriterImpl, op),
converter(converter), changedResults(changedResults) {}
converter(converter) {}

static bool classof(const IRRewrite *rewrite) {
return rewrite->getKind() == Kind::ReplaceOperation;
Expand All @@ -641,15 +640,10 @@ class ReplaceOperationRewrite : public OperationRewrite {

const TypeConverter *getConverter() const { return converter; }

bool hasChangedResults() const { return changedResults; }

private:
/// An optional type converter that can be used to materialize conversions
/// between the new and old values if necessary.
const TypeConverter *converter;

/// A boolean flag that indicates whether result types have changed or not.
bool changedResults;
};

class CreateOperationRewrite : public OperationRewrite {
Expand Down Expand Up @@ -941,6 +935,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// to modify/access them is invalid rewriter API usage.
SetVector<Operation *> replacedOps;

/// A set of all unresolved materializations.
DenseSet<Operation *> unresolvedMaterializations;

/// The current type converter, or nullptr if no type converter is currently
/// active.
const TypeConverter *currentTypeConverter = nullptr;
Expand Down Expand Up @@ -1066,6 +1063,7 @@ void UnresolvedMaterializationRewrite::rollback() {
for (Value input : op->getOperands())
rewriterImpl.mapping.erase(input);
}
rewriterImpl.unresolvedMaterializations.erase(op);
op->erase();
}

Expand Down Expand Up @@ -1347,6 +1345,7 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
auto convertOp =
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
unresolvedMaterializations.insert(convertOp);
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
return convertOp.getResult(0);
}
Expand Down Expand Up @@ -1379,22 +1378,28 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
assert(newValues.size() == op->getNumResults());
assert(!ignoredOps.contains(op) && "operation was already replaced");

// Track if any of the results changed, e.g. erased and replaced with null.
bool resultChanged = false;

// Create mappings for each of the new result values.
for (auto [newValue, result] : llvm::zip(newValues, op->getResults())) {
if (!newValue) {
resultChanged = true;
continue;
// This result was dropped and no replacement value was provided.
if (unresolvedMaterializations.contains(op)) {
// Do not create another materializations if we are erasing a
// materialization.
continue;
}

// Materialize a replacement value "out of thin air".
newValue = buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result),
result.getLoc(), /*inputs=*/ValueRange(),
/*outputType=*/result.getType(), currentTypeConverter);
}

// Remap, and check for any result type changes.
mapping.map(result, newValue);
resultChanged |= (newValue.getType() != result.getType());
}

appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
resultChanged);
appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter);

// Mark this operation and all nested ops as replaced.
op->walk([&](Operation *op) { replacedOps.insert(op); });
Expand Down Expand Up @@ -2359,11 +2364,6 @@ struct OperationConverter {
ConversionPatternRewriterImpl &rewriterImpl,
DenseMap<Value, SmallVector<Value>> &inverseMapping);

/// Legalize an operation result that was marked as "erased".
LogicalResult
legalizeErasedResult(Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl);

/// Dialect conversion configuration.
ConversionConfig config;

Expand Down Expand Up @@ -2455,77 +2455,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
return failure();
}

/// Erase all dead unrealized_conversion_cast ops. An op is dead if its results
/// are not used (transitively) by any op that is not in the given list of
/// cast ops.
///
/// In particular, this function erases cyclic casts that may be inserted
/// during the dialect conversion process. E.g.:
/// %0 = unrealized_conversion_cast(%1)
/// %1 = unrealized_conversion_cast(%0)
// Note: This step will become unnecessary when
// https://github.com/llvm/llvm-project/pull/106760 has been merged.
static void eraseDeadUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
// Ops that have already been visited or are currently being visited.
DenseSet<Operation *> visited;
// Set of all cast ops for faster lookups.
DenseSet<Operation *> castOpSet;
// Set of all cast ops that have been determined to be alive.
DenseSet<Operation *> live;

for (UnrealizedConversionCastOp op : castOps)
castOpSet.insert(op);

// Visit a cast operation. Return "true" if the operation is live.
std::function<bool(Operation *)> visit = [&](Operation *op) -> bool {
// No need to traverse any IR if the op was already marked as live.
if (live.contains(op))
return true;

// Do not visit ops multiple times. If we find a circle, no live user was
// found on the current path.
if (!visited.insert(op).second)
return false;

// Visit all users.
for (Operation *user : op->getUsers()) {
// If the user is not an unrealized_conversion_cast op, then the given op
// is live.
if (!castOpSet.contains(user)) {
live.insert(op);
return true;
}
// Otherwise, it is live if a live op can be reached from one of its
// users (which must all be unrealized_conversion_cast ops).
if (visit(user)) {
live.insert(op);
return true;
}
}

return false;
};

// Visit all cast ops.
for (UnrealizedConversionCastOp op : castOps) {
visit(op);
visited.clear();
}

// Erase all cast ops that are dead.
for (UnrealizedConversionCastOp op : castOps) {
if (live.contains(op)) {
if (remainingCastOps)
remainingCastOps->push_back(op);
continue;
}
op->dropAllUses();
op->erase();
}
}

LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
Expand Down Expand Up @@ -2584,14 +2513,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
// patterns.)
SmallVector<UnrealizedConversionCastOp> remainingCastOps1, remainingCastOps2;
eraseDeadUnrealizedCasts(allCastOps, &remainingCastOps1);
reconcileUnrealizedCasts(remainingCastOps1, &remainingCastOps2);
SmallVector<UnrealizedConversionCastOp> remainingCastOps;
reconcileUnrealizedCasts(allCastOps, &remainingCastOps);

// Try to legalize all unresolved materializations.
if (config.buildMaterializations) {
IRRewriter rewriter(rewriterImpl.context, config.listener);
for (UnrealizedConversionCastOp castOp : remainingCastOps2) {
for (UnrealizedConversionCastOp castOp : remainingCastOps) {
auto it = rewriteMap.find(castOp.getOperation());
assert(it != rewriteMap.end() && "inconsistent state");
if (failed(legalizeUnresolvedMaterialization(rewriter, it->second)))
Expand Down Expand Up @@ -2646,30 +2574,22 @@ LogicalResult OperationConverter::legalizeConvertedOpResultTypes(
for (unsigned i = 0; i < rewriterImpl.rewrites.size(); ++i) {
auto *opReplacement =
dyn_cast<ReplaceOperationRewrite>(rewriterImpl.rewrites[i].get());
if (!opReplacement || !opReplacement->hasChangedResults())
if (!opReplacement)
continue;
Operation *op = opReplacement->getOperation();
for (OpResult result : op->getResults()) {
Value newValue = rewriterImpl.mapping.lookupOrNull(result);

// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
if (!newValue) {
if (failed(legalizeErasedResult(op, result, rewriterImpl)))
return failure();
// If the type of this op result changed and the result is still live,
// we need to materialize a conversion.
if (rewriterImpl.mapping.lookupOrNull(result, result.getType()))
continue;
}

// Otherwise, check to see if the type of the result changed.
if (result.getType() == newValue.getType())
continue;

Operation *liveUser =
findLiveUserOfReplaced(result, rewriterImpl, inverseMapping);
if (!liveUser)
continue;

// Legalize this result.
Value newValue = rewriterImpl.mapping.lookupOrNull(result);
assert(newValue && "replacement value not found");
Value castValue = rewriterImpl.buildUnresolvedMaterialization(
MaterializationKind::Source, computeInsertPoint(result), op->getLoc(),
/*inputs=*/newValue, /*outputType=*/result.getType(),
Expand Down Expand Up @@ -2727,25 +2647,6 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
return success();
}

LogicalResult OperationConverter::legalizeErasedResult(
Operation *op, OpResult result,
ConversionPatternRewriterImpl &rewriterImpl) {
// If the operation result was replaced with null, all of the uses of this
// value should be replaced.
auto liveUserIt = llvm::find_if_not(result.getUsers(), [&](Operation *user) {
return rewriterImpl.isOpIgnored(user);
});
if (liveUserIt != result.user_end()) {
InFlightDiagnostic diag = op->emitError("failed to legalize operation '")
<< op->getName() << "' marked as erased";
diag.attachNote(liveUserIt->getLoc())
<< "found live user of result #" << result.getResultNumber() << ": "
<< *liveUserIt;
return failure();
}
return success();
}

//===----------------------------------------------------------------------===//
// Reconcile Unrealized Casts
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// Test that an error is emitted when an operation is marked as "erased", but
// has users that live across the conversion.
func.func @remove_all_ops(%arg0: i32) -> i32 {
// expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
// expected-error@below {{failed to legalize unresolved materialization from () to 'i32' that remained live after conversion}}
%0 = "test.illegal_op_a"() : () -> i32
// expected-note@below {{found live user of result #0: func.return %0 : i32}}
// expected-note@below {{see existing live user here}}
return %0 : i32
}
Loading