Skip to content

Commit 662c6fc

Browse files
authored
[mlir] [bufferize] fix bufferize deallocation error in nest symbol table (#98476)
In nested symbols, the dealloc_helper function generated by lower deallocations pass was incorrectly positioned, causing calls fail. This patch fixes this issue.
1 parent 3698453 commit 662c6fc

File tree

4 files changed

+82
-26
lines changed

4 files changed

+82
-26
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ class FuncOp;
1818
namespace bufferization {
1919
struct OneShotBufferizationOptions;
2020

21+
/// Maps from symbol table to its corresponding dealloc helper function.
22+
using DeallocHelperMap = llvm::DenseMap<Operation *, func::FuncOp>;
23+
2124
//===----------------------------------------------------------------------===//
2225
// Passes
2326
//===----------------------------------------------------------------------===//
@@ -46,7 +49,7 @@ std::unique_ptr<Pass> createLowerDeallocationsPass();
4649
/// Adds the conversion pattern of the `bufferization.dealloc` operation to the
4750
/// given pattern set for use in other transformation passes.
4851
void populateBufferizationDeallocLoweringPattern(
49-
RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc);
52+
RewritePatternSet &patterns, const DeallocHelperMap &deallocHelperFuncMap);
5053

5154
/// Construct the library function needed for the fully generic
5255
/// `bufferization.dealloc` lowering implemented in the LowerDeallocations pass.

mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,30 @@ struct BufferizationToMemRefPass
132132
return;
133133
}
134134

135-
func::FuncOp helperFuncOp;
135+
bufferization::DeallocHelperMap deallocHelperFuncMap;
136136
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
137137
OpBuilder builder =
138138
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
139-
SymbolTable symbolTable(module);
140139

141140
// Build dealloc helper function if there are deallocs.
142141
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
143-
if (deallocOp.getMemrefs().size() > 1) {
144-
helperFuncOp = bufferization::buildDeallocationLibraryFunction(
145-
builder, getOperation()->getLoc(), symbolTable);
146-
return WalkResult::interrupt();
142+
Operation *symtableOp =
143+
deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
144+
if (deallocOp.getMemrefs().size() > 1 &&
145+
!deallocHelperFuncMap.contains(symtableOp)) {
146+
SymbolTable symbolTable(symtableOp);
147+
func::FuncOp helperFuncOp =
148+
bufferization::buildDeallocationLibraryFunction(
149+
builder, getOperation()->getLoc(), symbolTable);
150+
deallocHelperFuncMap[symtableOp] = helperFuncOp;
147151
}
148-
return WalkResult::advance();
149152
});
150153
}
151154

152155
RewritePatternSet patterns(&getContext());
153156
patterns.add<CloneOpConversion>(patterns.getContext());
154-
bufferization::populateBufferizationDeallocLoweringPattern(patterns,
155-
helperFuncOp);
157+
bufferization::populateBufferizationDeallocLoweringPattern(
158+
patterns, deallocHelperFuncMap);
156159

157160
ConversionTarget target(getContext());
158161
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,

mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ class DeallocOpConversion
300300
MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()),
301301
retainCondsMemref);
302302

303+
Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
303304
rewriter.create<func::CallOp>(
304-
op.getLoc(), deallocHelperFunc,
305+
op.getLoc(), deallocHelperFuncMap.lookup(symtableOp),
305306
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
306307
castedCondsMemref, castedDeallocCondsMemref,
307308
castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
338339
}
339340

340341
public:
341-
DeallocOpConversion(MLIRContext *context, func::FuncOp deallocHelperFunc)
342+
DeallocOpConversion(
343+
MLIRContext *context,
344+
const bufferization::DeallocHelperMap &deallocHelperFuncMap)
342345
: OpConversionPattern<bufferization::DeallocOp>(context),
343-
deallocHelperFunc(deallocHelperFunc) {}
346+
deallocHelperFuncMap(deallocHelperFuncMap) {}
344347

345348
LogicalResult
346349
matchAndRewrite(bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
360363
if (adaptor.getMemrefs().size() == 1)
361364
return rewriteOneMemrefMultipleRetainCase(op, adaptor, rewriter);
362365

363-
if (!deallocHelperFunc)
366+
Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>();
367+
if (!deallocHelperFuncMap.contains(symtableOp))
364368
return op->emitError(
365369
"library function required for generic lowering, but cannot be "
366370
"automatically inserted when operating on functions");
@@ -369,7 +373,7 @@ class DeallocOpConversion
369373
}
370374

371375
private:
372-
func::FuncOp deallocHelperFunc;
376+
const bufferization::DeallocHelperMap &deallocHelperFuncMap;
373377
};
374378
} // namespace
375379

@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
385389
return;
386390
}
387391

388-
func::FuncOp helperFuncOp;
392+
bufferization::DeallocHelperMap deallocHelperFuncMap;
389393
if (auto module = dyn_cast<ModuleOp>(getOperation())) {
390394
OpBuilder builder =
391395
OpBuilder::atBlockBegin(&module.getBodyRegion().front());
392-
SymbolTable symbolTable(module);
393396

394397
// Build dealloc helper function if there are deallocs.
395398
getOperation()->walk([&](bufferization::DeallocOp deallocOp) {
396-
if (deallocOp.getMemrefs().size() > 1) {
397-
helperFuncOp = bufferization::buildDeallocationLibraryFunction(
398-
builder, getOperation()->getLoc(), symbolTable);
399-
return WalkResult::interrupt();
399+
Operation *symtableOp =
400+
deallocOp->getParentWithTrait<OpTrait::SymbolTable>();
401+
if (deallocOp.getMemrefs().size() > 1 &&
402+
!deallocHelperFuncMap.contains(symtableOp)) {
403+
SymbolTable symbolTable(symtableOp);
404+
func::FuncOp helperFuncOp =
405+
bufferization::buildDeallocationLibraryFunction(
406+
builder, getOperation()->getLoc(), symbolTable);
407+
deallocHelperFuncMap[symtableOp] = helperFuncOp;
400408
}
401-
return WalkResult::advance();
402409
});
403410
}
404411

405412
RewritePatternSet patterns(&getContext());
406-
bufferization::populateBufferizationDeallocLoweringPattern(patterns,
407-
helperFuncOp);
413+
bufferization::populateBufferizationDeallocLoweringPattern(
414+
patterns, deallocHelperFuncMap);
408415

409416
ConversionTarget target(getContext());
410417
target.addLegalDialect<memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
535542
}
536543

537544
void mlir::bufferization::populateBufferizationDeallocLoweringPattern(
538-
RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
539-
patterns.add<DeallocOpConversion>(patterns.getContext(), deallocLibraryFunc);
545+
RewritePatternSet &patterns,
546+
const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
547+
patterns.add<DeallocOpConversion>(patterns.getContext(),
548+
deallocHelperFuncMap);
540549
}
541550

542551
std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass() {

mlir/test/Dialect/Bufferization/Transforms/lower-deallocations.mlir

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,44 @@ func.func @conversion_dealloc_multiple_memrefs_and_retained(%arg0: memref<2xf32>
154154
// CHECK-NEXT: memref.store [[DEALLOC_COND]], [[DEALLOC_CONDS_OUT]][[[OUTER_ITER]]]
155155
// CHECK-NEXT: }
156156
// CHECK-NEXT: return
157+
158+
// -----
159+
160+
// This test check dealloc_helper function is generated on each nested symbol
161+
// table operation when needed and only generated once.
162+
module @conversion_nest_module_dealloc_helper {
163+
func.func @top_level_func(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
164+
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
165+
func.return %0#0, %0#1 : i1, i1
166+
}
167+
module @nested_module_not_need_dealloc_helper {
168+
func.func @nested_module_not_need_dealloc_helper_func(%arg0: memref<2xf32>, %arg1: memref<1xf32>, %arg2: i1, %arg3: memref<2xf32>) -> (i1, i1) {
169+
%0:2 = bufferization.dealloc (%arg0 : memref<2xf32>) if (%arg2) retain (%arg1, %arg3 : memref<1xf32>, memref<2xf32>)
170+
return %0#0, %0#1 : i1, i1
171+
}
172+
}
173+
module @nested_module_need_dealloc_helper {
174+
func.func @nested_module_need_dealloc_helper_func0(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
175+
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
176+
func.return %0#0, %0#1 : i1, i1
177+
}
178+
func.func @nested_module_need_dealloc_helper_func1(%arg0: memref<2xf32>, %arg1: memref<5xf32>, %arg2: memref<1xf32>, %arg3: i1, %arg4: i1, %arg5: memref<2xf32>) -> (i1, i1) {
179+
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xf32>, memref<5xf32>) if (%arg3, %arg4) retain (%arg2, %arg5 : memref<1xf32>, memref<2xf32>)
180+
func.return %0#0, %0#1 : i1, i1
181+
}
182+
}
183+
}
184+
185+
// CHECK: module @conversion_nest_module_dealloc_helper {
186+
// CHECK: func.func @top_level_func
187+
// CHECK: call @dealloc_helper
188+
// CHECK: module @nested_module_not_need_dealloc_helper {
189+
// CHECK: func.func @nested_module_not_need_dealloc_helper_func
190+
// CHECK-NOT: @dealloc_helper
191+
// CHECK: module @nested_module_need_dealloc_helper {
192+
// CHECK: func.func @nested_module_need_dealloc_helper_func0
193+
// CHECK: call @dealloc_helper
194+
// CHECK: func.func @nested_module_need_dealloc_helper_func1
195+
// CHECK: call @dealloc_helper
196+
// CHECK: func.func private @dealloc_helper
197+
// CHECK: func.func private @dealloc_helper

0 commit comments

Comments
 (0)