Merged
112 changes: 20 additions & 92 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// A set of operations that were modified by the current pattern.
SetVector<Operation *> patternModifiedOps;

/// A set of blocks that were inserted (newly-created blocks or moved blocks)
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;

/// A list of unresolved materializations that were created by the current
/// pattern.
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
if (!config.allowPatternRollback && config.listener)
config.listener->notifyBlockInserted(block, previous, previousIt);

patternInsertedBlocks.insert(block);

if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
bool canApplyPattern(Operation *op, const Pattern &pattern);

/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks);

/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
legalizePatternResult(Operation *op, const Pattern &pattern,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps);

LogicalResult
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto cleanup = llvm::make_scope_exit([&]() {
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
});

// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
static void
reportNewIrLegalizationFatalError(const Pattern &pattern,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
const SetVector<Operation *> &modifiedOps) {
auto newOpNames = llvm::map_range(
newOps, [](Operation *op) { return op->getName().getStringRef(); });
auto modifiedOpNames = llvm::map_range(
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
StringRef detachedBlockStr = "(detached block)";
auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
if (block->getParentOp())
return block->getParentOp()->getName().getStringRef();
return detachedBlockStr;
});
llvm::report_fatal_error(
"pattern '" + pattern.getDebugName() +
"' produced IR that could not be legalized. " + "new ops: {" +
llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
llvm::join(insertedBlockNames, ", ") + "}");
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' produced IR that could not be legalized. " +
"new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
"modified ops: {" +
llvm::join(modifiedOpNames, ", ") + "}");
}

LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
}
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
auto result =
legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
appliedPatterns.erase(&pattern);
if (failed(result)) {
if (!rewriterImpl.config.allowPatternRollback)
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
insertedBlocks);
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
rewriterImpl.resetState(curState, pattern.getDebugName());
}
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
LogicalResult OperationLegalizer::legalizePatternResult(
Operation *op, const Pattern &pattern, const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
const SetVector<Operation *> &modifiedOps) {
[[maybe_unused]] auto &impl = rewriter.getImpl();
assert(impl.pendingRootUpdates.empty() && "dangling root updates");

@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS

// Legalize each of the actions registered during application.
if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
failed(legalizePatternRootUpdates(modifiedOps)) ||
if (failed(legalizePatternRootUpdates(modifiedOps)) ||
failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
return success();
}

LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;

// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
if (impl.erasedBlocks.contains(block))
continue;

// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
continue;

// If the region of the block has a type converter, try to convert the block
// directly.
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
std::optional<TypeConverter::SignatureConversion> conversion =
converter->convertBlockSignature(block);
if (!conversion) {
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
"block"));
return failure();
}
impl.applySignatureConversion(block, converter, *conversion);
continue;
}

// Otherwise, try to legalize the parent operation if it was not generated
// by this pattern. This is because we will attempt to legalize the parent
// operation, and blocks in regions created by this pattern will already be
// legalized later on.
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
if (failed(legalize(parentOp))) {
LLVM_DEBUG(logFailure(
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
parentOp->getName(), parentOp));
return failure();
}
}
}
return success();
}

LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
const SetVector<Operation *> &newOps) {
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
TypeConverter::SignatureConversion result(type.getNumInputs());
SmallVector<Type, 1> newResults;
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
typeConverter, &result)))
failed(typeConverter.convertTypes(type.getResults(), newResults)))
return failure();
if (!funcOp.getFunctionBody().empty())
rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
&typeConverter);

// Update the function signature in-place.
auto newType = FunctionType::get(rewriter.getContext(),
67 changes: 67 additions & 0 deletions mlir/test/Transforms/test-legalizer-no-materializations.mlir
@@ -0,0 +1,67 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
Contributor:
Splitting up the legalization test looks unrelated; do we have to do it as part of this change?

Member Author:
The last two test cases that I moved produce different IR and error messages when running with build-materializations=0. That's because the block signatures are no longer getting converted. (And the blocks are dead, so they are not converted by the CF structural type conversion pattern.)
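
For illustration only, here is a sketch of the resulting dead block, adapted from the check lines of the moved @remap_moved_region_args case below (not a verbatim quote of either test; the SSA value names are placeholders): previously the dead block's argument types were rewritten in place (roughly ^bb1(%a: f64, %b: f64, %c: f16, %d: f16)), whereas with build-materializations=0 the block keeps its original signature and unresolved casts bridge to the converted types.

^bb1(%i0: i64, %unused: i16, %i1: i64, %f: f32):  // dead block, signature left unconverted
  // Unresolved casts remain because materializations are not built.
  %c0:2 = builtin.unrealized_conversion_cast %f : f32 to f16, f16
  %c1 = builtin.unrealized_conversion_cast %i1 : i64 to f64
  %c2 = builtin.unrealized_conversion_cast %i0 : i64 to f64
  %c3 = "test.cast"(%c0#0, %c0#1) : (f16, f16) -> f32
  "test.valid"(%c2, %c1, %c3) : (f64, f64, f32) -> ()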


// CHECK-LABEL: func @dropped_input_in_use
// CHECK-KIND-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
// CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
// CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg) : (i16) -> ()
}

// -----

// CHECK-KIND-LABEL: func @test_lookup_without_converter
// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
func.func @test_lookup_without_converter() {
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
// Make sure that the second "replace_with_valid_consumer" lowering does not
// lookup the materialization that was created for the above op.
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}

// -----

// CHECK-LABEL: func @remap_moved_region_args
func.func @remap_moved_region_args() {
// CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
// CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
// CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
// CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
// CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
// CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
"test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}

// -----

// CHECK-LABEL: func @remap_cloned_region_args
func.func @remap_cloned_region_args() {
// CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
// CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
// CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
// CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
// CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
// CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
"test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) {legalizer.should_clone} : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}
39 changes: 0 additions & 39 deletions mlir/test/Transforms/test-legalizer.mlir
@@ -1,7 +1,6 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND

// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
@@ -146,36 +145,6 @@ func.func @no_remap_nested() {

// -----

// CHECK-LABEL: func @remap_moved_region_args
func.func @remap_moved_region_args() {
// CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
"test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}

// -----

// CHECK-LABEL: func @remap_cloned_region_args
func.func @remap_cloned_region_args() {
// CHECK-NEXT: return
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
"test.region"() ({
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
}) {legalizer.should_clone} : () -> ()
// expected-remark@+1 {{op 'func.return' is not legalizable}}
return
}

// CHECK-LABEL: func @remap_drop_region
func.func @remap_drop_region() {
// CHECK-NEXT: return
@@ -191,12 +160,9 @@ func.func @remap_drop_region() {
// -----

// CHECK-LABEL: func @dropped_input_in_use
// CHECK-KIND-LABEL: func @dropped_input_in_use
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
// CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
// CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
// expected-remark@+1 {{op 'work' is not legalizable}}
"work"(%arg) : (i16) -> ()
}
@@ -452,11 +418,6 @@ func.func @test_multiple_1_to_n_replacement() {
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
// CHECK-KIND-LABEL: func @test_lookup_without_converter
// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
func.func @test_lookup_without_converter() {
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
6 changes: 2 additions & 4 deletions mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1553,8 +1553,7 @@ struct TestLegalizePatternDriver
[](Type type) { return type.isF32(); });
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType()) &&
converter.isLegal(&op.getBody());
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<func::CallOp>(
[&](func::CallOp op) { return converter.isLegal(op); });
@@ -2156,8 +2155,7 @@ struct TestTypeConversionDriver
recursiveType.getName() == "outer_converted_type");
});
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
return converter.isSignatureLegal(op.getFunctionType()) &&
converter.isLegal(&op.getBody());
return converter.isSignatureLegal(op.getFunctionType());
});
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
// Allow casts from F64 to F32.