Skip to content

Commit 7bee741

Browse files
[mlir][Transforms] Dialect Conversion: Convert entry block only
1 parent 8785595 commit 7bee741

File tree

4 files changed

+89
-135
lines changed

4 files changed

+89
-135
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 20 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,10 +1105,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
11051105
/// A set of operations that were modified by the current pattern.
11061106
SetVector<Operation *> patternModifiedOps;
11071107

1108-
/// A set of blocks that were inserted (newly-created blocks or moved blocks)
1109-
/// by the current pattern.
1110-
SetVector<Block *> patternInsertedBlocks;
1111-
11121108
/// A list of unresolved materializations that were created by the current
11131109
/// pattern.
11141110
DenseSet<UnrealizedConversionCastOp> patternMaterializations;
@@ -2046,8 +2042,6 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
20462042
if (!config.allowPatternRollback && config.listener)
20472043
config.listener->notifyBlockInserted(block, previous, previousIt);
20482044

2049-
patternInsertedBlocks.insert(block);
2050-
20512045
if (wasDetached) {
20522046
// If the block was detached, it is most likely a newly created block.
20532047
if (config.allowPatternRollback) {
@@ -2399,17 +2393,12 @@ class OperationLegalizer {
23992393
bool canApplyPattern(Operation *op, const Pattern &pattern);
24002394

24012395
/// Legalize the resultant IR after successfully applying the given pattern.
2402-
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
2403-
const RewriterState &curState,
2404-
const SetVector<Operation *> &newOps,
2405-
const SetVector<Operation *> &modifiedOps,
2406-
const SetVector<Block *> &insertedBlocks);
2407-
2408-
/// Legalizes the actions registered during the execution of a pattern.
24092396
LogicalResult
2410-
legalizePatternBlockRewrites(Operation *op,
2411-
const SetVector<Block *> &insertedBlocks,
2412-
const SetVector<Operation *> &newOps);
2397+
legalizePatternResult(Operation *op, const Pattern &pattern,
2398+
const RewriterState &curState,
2399+
const SetVector<Operation *> &newOps,
2400+
const SetVector<Operation *> &modifiedOps);
2401+
24132402
LogicalResult
24142403
legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
24152404
LogicalResult
@@ -2608,7 +2597,6 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26082597
auto cleanup = llvm::make_scope_exit([&]() {
26092598
rewriterImpl.patternNewOps.clear();
26102599
rewriterImpl.patternModifiedOps.clear();
2611-
rewriterImpl.patternInsertedBlocks.clear();
26122600
});
26132601

26142602
// Upon failure, undo all changes made by the folder.
@@ -2662,24 +2650,16 @@ LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
26622650
static void
26632651
reportNewIrLegalizationFatalError(const Pattern &pattern,
26642652
const SetVector<Operation *> &newOps,
2665-
const SetVector<Operation *> &modifiedOps,
2666-
const SetVector<Block *> &insertedBlocks) {
2653+
const SetVector<Operation *> &modifiedOps) {
26672654
auto newOpNames = llvm::map_range(
26682655
newOps, [](Operation *op) { return op->getName().getStringRef(); });
26692656
auto modifiedOpNames = llvm::map_range(
26702657
modifiedOps, [](Operation *op) { return op->getName().getStringRef(); });
2671-
StringRef detachedBlockStr = "(detached block)";
2672-
auto insertedBlockNames = llvm::map_range(insertedBlocks, [&](Block *block) {
2673-
if (block->getParentOp())
2674-
return block->getParentOp()->getName().getStringRef();
2675-
return detachedBlockStr;
2676-
});
2677-
llvm::report_fatal_error(
2678-
"pattern '" + pattern.getDebugName() +
2679-
"' produced IR that could not be legalized. " + "new ops: {" +
2680-
llvm::join(newOpNames, ", ") + "}, " + "modified ops: {" +
2681-
llvm::join(modifiedOpNames, ", ") + "}, " + "inserted block into ops: {" +
2682-
llvm::join(insertedBlockNames, ", ") + "}");
2658+
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
2659+
"' produced IR that could not be legalized. " +
2660+
"new ops: {" + llvm::join(newOpNames, ", ") + "}, " +
2661+
"modified ops: {" +
2662+
llvm::join(modifiedOpNames, ", ") + "}");
26832663
}
26842664

26852665
LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
@@ -2743,7 +2723,6 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27432723
}
27442724
rewriterImpl.patternNewOps.clear();
27452725
rewriterImpl.patternModifiedOps.clear();
2746-
rewriterImpl.patternInsertedBlocks.clear();
27472726
LLVM_DEBUG({
27482727
logFailure(rewriterImpl.logger, "pattern failed to match");
27492728
if (rewriterImpl.config.notifyCallback) {
@@ -2777,15 +2756,12 @@ LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
27772756
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
27782757
SetVector<Operation *> modifiedOps =
27792758
moveAndReset(rewriterImpl.patternModifiedOps);
2780-
SetVector<Block *> insertedBlocks =
2781-
moveAndReset(rewriterImpl.patternInsertedBlocks);
2782-
auto result = legalizePatternResult(op, pattern, curState, newOps,
2783-
modifiedOps, insertedBlocks);
2759+
auto result =
2760+
legalizePatternResult(op, pattern, curState, newOps, modifiedOps);
27842761
appliedPatterns.erase(&pattern);
27852762
if (failed(result)) {
27862763
if (!rewriterImpl.config.allowPatternRollback)
2787-
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
2788-
insertedBlocks);
2764+
reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps);
27892765
rewriterImpl.resetState(curState, pattern.getDebugName());
27902766
}
27912767
if (config.listener)
@@ -2823,8 +2799,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op,
28232799
LogicalResult OperationLegalizer::legalizePatternResult(
28242800
Operation *op, const Pattern &pattern, const RewriterState &curState,
28252801
const SetVector<Operation *> &newOps,
2826-
const SetVector<Operation *> &modifiedOps,
2827-
const SetVector<Block *> &insertedBlocks) {
2802+
const SetVector<Operation *> &modifiedOps) {
28282803
[[maybe_unused]] auto &impl = rewriter.getImpl();
28292804
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
28302805

@@ -2843,8 +2818,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28432818
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
28442819

28452820
// Legalize each of the actions registered during application.
2846-
if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
2847-
failed(legalizePatternRootUpdates(modifiedOps)) ||
2821+
if (failed(legalizePatternRootUpdates(modifiedOps)) ||
28482822
failed(legalizePatternCreatedOperations(newOps))) {
28492823
return failure();
28502824
}
@@ -2853,53 +2827,6 @@ LogicalResult OperationLegalizer::legalizePatternResult(
28532827
return success();
28542828
}
28552829

2856-
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2857-
Operation *op, const SetVector<Block *> &insertedBlocks,
2858-
const SetVector<Operation *> &newOps) {
2859-
ConversionPatternRewriterImpl &impl = rewriter.getImpl();
2860-
SmallPtrSet<Operation *, 16> alreadyLegalized;
2861-
2862-
// If the pattern moved or created any blocks, make sure the types of block
2863-
// arguments get legalized.
2864-
for (Block *block : insertedBlocks) {
2865-
if (impl.erasedBlocks.contains(block))
2866-
continue;
2867-
2868-
// Only check blocks outside of the current operation.
2869-
Operation *parentOp = block->getParentOp();
2870-
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
2871-
continue;
2872-
2873-
// If the region of the block has a type converter, try to convert the block
2874-
// directly.
2875-
if (auto *converter = impl.regionToConverter.lookup(block->getParent())) {
2876-
std::optional<TypeConverter::SignatureConversion> conversion =
2877-
converter->convertBlockSignature(block);
2878-
if (!conversion) {
2879-
LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved "
2880-
"block"));
2881-
return failure();
2882-
}
2883-
impl.applySignatureConversion(block, converter, *conversion);
2884-
continue;
2885-
}
2886-
2887-
// Otherwise, try to legalize the parent operation if it was not generated
2888-
// by this pattern. This is because we will attempt to legalize the parent
2889-
// operation, and blocks in regions created by this pattern will already be
2890-
// legalized later on.
2891-
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2892-
if (failed(legalize(parentOp))) {
2893-
LLVM_DEBUG(logFailure(
2894-
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2895-
parentOp->getName(), parentOp));
2896-
return failure();
2897-
}
2898-
}
2899-
}
2900-
return success();
2901-
}
2902-
29032830
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
29042831
const SetVector<Operation *> &newOps) {
29052832
for (Operation *op : newOps) {
@@ -3800,10 +3727,11 @@ static LogicalResult convertFuncOpTypes(FunctionOpInterface funcOp,
38003727
TypeConverter::SignatureConversion result(type.getNumInputs());
38013728
SmallVector<Type, 1> newResults;
38023729
if (failed(typeConverter.convertSignatureArgs(type.getInputs(), result)) ||
3803-
failed(typeConverter.convertTypes(type.getResults(), newResults)) ||
3804-
failed(rewriter.convertRegionTypes(&funcOp.getFunctionBody(),
3805-
typeConverter, &result)))
3730+
failed(typeConverter.convertTypes(type.getResults(), newResults)))
38063731
return failure();
3732+
if (!funcOp.getFunctionBody().empty())
3733+
rewriter.applySignatureConversion(&funcOp.getFunctionBody().front(), result,
3734+
&typeConverter);
38073735

38083736
// Update the function signature in-place.
38093737
auto newType = FunctionType::get(rewriter.getContext(),
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
2+
3+
// CHECK-LABEL: func @dropped_input_in_use
4+
// CHECK-KIND-LABEL: func @dropped_input_in_use
5+
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
6+
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
7+
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
8+
// CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
9+
// CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
10+
// expected-remark@+1 {{op 'work' is not legalizable}}
11+
"work"(%arg) : (i16) -> ()
12+
}
13+
14+
// -----
15+
16+
// CHECK-KIND-LABEL: func @test_lookup_without_converter
17+
// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
18+
// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
19+
// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
20+
// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
21+
func.func @test_lookup_without_converter() {
22+
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
23+
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()
24+
// Make sure that the second "replace_with_valid_consumer" lowering does not
25+
// lookup the materialization that was created for the above op.
26+
"test.replace_with_valid_consumer"(%0) : (i64) -> ()
27+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
28+
return
29+
}
30+
31+
// -----
32+
33+
// CHECK-LABEL: func @remap_moved_region_args
34+
func.func @remap_moved_region_args() {
35+
// CHECK-NEXT: return
36+
// CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
37+
// CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
38+
// CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
39+
// CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
40+
// CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
41+
// CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
42+
"test.region"() ({
43+
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
44+
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
45+
}) : () -> ()
46+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
47+
return
48+
}
49+
50+
// -----
51+
52+
// CHECK-LABEL: func @remap_cloned_region_args
53+
func.func @remap_cloned_region_args() {
54+
// CHECK-NEXT: return
55+
// CHECK-NEXT: ^bb1(%[[arg0:.*]]: i64, %[[arg1:.*]]: i16, %[[arg2:.*]]: i64, %[[arg3:.*]]: f32):
56+
// CHECK-NEXT: %[[cast1:.*]]:2 = builtin.unrealized_conversion_cast %[[arg3]] : f32 to f16, f16
57+
// CHECK-NEXT: %[[cast2:.*]] = builtin.unrealized_conversion_cast %[[arg2]] : i64 to f64
58+
// CHECK-NEXT: %[[cast3:.*]] = builtin.unrealized_conversion_cast %[[arg0]] : i64 to f64
59+
// CHECK-NEXT: %[[cast4:.*]] = "test.cast"(%[[cast1]]#0, %[[cast1]]#1) : (f16, f16) -> f32
60+
// CHECK-NEXT: "test.valid"(%[[cast3]], %[[cast2]], %[[cast4]]) : (f64, f64, f32)
61+
"test.region"() ({
62+
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
63+
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
64+
}) {legalizer.should_clone} : () -> ()
65+
// expected-remark@+1 {{op 'func.return' is not legalizable}}
66+
return
67+
}

mlir/test/Transforms/test-legalizer.mlir

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s
22
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
33
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s
4-
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0 build-materializations=0 attach-debug-materialization-kind=1" -verify-diagnostics %s | FileCheck %s --check-prefix=CHECK-KIND
54

65
// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
76
// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
@@ -146,36 +145,6 @@ func.func @no_remap_nested() {
146145

147146
// -----
148147

149-
// CHECK-LABEL: func @remap_moved_region_args
150-
func.func @remap_moved_region_args() {
151-
// CHECK-NEXT: return
152-
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
153-
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
154-
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
155-
"test.region"() ({
156-
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
157-
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
158-
}) : () -> ()
159-
// expected-remark@+1 {{op 'func.return' is not legalizable}}
160-
return
161-
}
162-
163-
// -----
164-
165-
// CHECK-LABEL: func @remap_cloned_region_args
166-
func.func @remap_cloned_region_args() {
167-
// CHECK-NEXT: return
168-
// CHECK-NEXT: ^bb1(%{{.*}}: f64, %{{.*}}: f64, %{{.*}}: f16, %{{.*}}: f16):
169-
// CHECK-NEXT: "test.cast"{{.*}} : (f16, f16) -> f32
170-
// CHECK-NEXT: "test.valid"{{.*}} : (f64, f64, f32)
171-
"test.region"() ({
172-
^bb1(%i0: i64, %unused: i16, %i1: i64, %2: f32):
173-
"test.invalid"(%i0, %i1, %2) : (i64, i64, f32) -> ()
174-
}) {legalizer.should_clone} : () -> ()
175-
// expected-remark@+1 {{op 'func.return' is not legalizable}}
176-
return
177-
}
178-
179148
// CHECK-LABEL: func @remap_drop_region
180149
func.func @remap_drop_region() {
181150
// CHECK-NEXT: return
@@ -191,12 +160,9 @@ func.func @remap_drop_region() {
191160
// -----
192161

193162
// CHECK-LABEL: func @dropped_input_in_use
194-
// CHECK-KIND-LABEL: func @dropped_input_in_use
195163
func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
196164
// CHECK-NEXT: %[[cast:.*]] = "test.cast"() : () -> i16
197165
// CHECK-NEXT: "work"(%[[cast]]) : (i16)
198-
// CHECK-KIND-NEXT: %[[cast:.*]] = builtin.unrealized_conversion_cast to i16 {__kind__ = "source"}
199-
// CHECK-KIND-NEXT: "work"(%[[cast]]) : (i16)
200166
// expected-remark@+1 {{op 'work' is not legalizable}}
201167
"work"(%arg) : (i16) -> ()
202168
}
@@ -452,11 +418,6 @@ func.func @test_multiple_1_to_n_replacement() {
452418
// CHECK: %[[cast:.*]] = "test.cast"(%[[producer]]) : (i16) -> f64
453419
// CHECK: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
454420
// CHECK: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
455-
// CHECK-KIND-LABEL: func @test_lookup_without_converter
456-
// CHECK-KIND: %[[producer:.*]] = "test.valid_producer"() : () -> i16
457-
// CHECK-KIND: %[[cast:.*]] = builtin.unrealized_conversion_cast %[[producer]] : i16 to f64 {__kind__ = "target"}
458-
// CHECK-KIND: "test.valid_consumer"(%[[cast]]) : (f64) -> ()
459-
// CHECK-KIND: "test.valid_consumer"(%[[producer]]) : (i16) -> ()
460421
func.func @test_lookup_without_converter() {
461422
%0 = "test.replace_with_valid_producer"() {type = i16} : () -> (i64)
462423
"test.replace_with_valid_consumer"(%0) {with_converter} : (i64) -> ()

mlir/test/lib/Dialect/Test/TestPatterns.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1553,8 +1553,7 @@ struct TestLegalizePatternDriver
15531553
[](Type type) { return type.isF32(); });
15541554
});
15551555
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
1556-
return converter.isSignatureLegal(op.getFunctionType()) &&
1557-
converter.isLegal(&op.getBody());
1556+
return converter.isSignatureLegal(op.getFunctionType());
15581557
});
15591558
target.addDynamicallyLegalOp<func::CallOp>(
15601559
[&](func::CallOp op) { return converter.isLegal(op); });
@@ -2156,8 +2155,7 @@ struct TestTypeConversionDriver
21562155
recursiveType.getName() == "outer_converted_type");
21572156
});
21582157
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
2159-
return converter.isSignatureLegal(op.getFunctionType()) &&
2160-
converter.isLegal(&op.getBody());
2158+
return converter.isSignatureLegal(op.getFunctionType());
21612159
});
21622160
target.addDynamicallyLegalOp<TestCastOp>([&](TestCastOp op) {
21632161
// Allow casts from F64 to F32.

0 commit comments

Comments
 (0)