@@ -300,8 +300,9 @@ class DeallocOpConversion
300
300
MemRefType::get ({ShapedType::kDynamic }, rewriter.getI1Type ()),
301
301
retainCondsMemref);
302
302
303
+ Operation *symtableOp = op->getParentWithTrait <OpTrait::SymbolTable>();
303
304
rewriter.create <func::CallOp>(
304
- op.getLoc (), deallocHelperFunc ,
305
+ op.getLoc (), deallocHelperFuncMap. lookup (symtableOp) ,
305
306
SmallVector<Value>{castedDeallocMemref, castedRetainMemref,
306
307
castedCondsMemref, castedDeallocCondsMemref,
307
308
castedRetainCondsMemref});
@@ -338,9 +339,11 @@ class DeallocOpConversion
338
339
}
339
340
340
341
public:
341
- DeallocOpConversion (MLIRContext *context, func::FuncOp deallocHelperFunc)
342
+ DeallocOpConversion (
343
+ MLIRContext *context,
344
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap)
342
345
: OpConversionPattern<bufferization::DeallocOp>(context),
343
- deallocHelperFunc (deallocHelperFunc ) {}
346
+ deallocHelperFuncMap (deallocHelperFuncMap ) {}
344
347
345
348
LogicalResult
346
349
matchAndRewrite (bufferization::DeallocOp op, OpAdaptor adaptor,
@@ -360,7 +363,8 @@ class DeallocOpConversion
360
363
if (adaptor.getMemrefs ().size () == 1 )
361
364
return rewriteOneMemrefMultipleRetainCase (op, adaptor, rewriter);
362
365
363
- if (!deallocHelperFunc)
366
+ Operation *symtableOp = op->getParentWithTrait <OpTrait::SymbolTable>();
367
+ if (!deallocHelperFuncMap.contains (symtableOp))
364
368
return op->emitError (
365
369
" library function required for generic lowering, but cannot be "
366
370
" automatically inserted when operating on functions" );
@@ -369,7 +373,7 @@ class DeallocOpConversion
369
373
}
370
374
371
375
private:
372
- func::FuncOp deallocHelperFunc ;
376
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap ;
373
377
};
374
378
} // namespace
375
379
@@ -385,26 +389,29 @@ struct LowerDeallocationsPass
385
389
return ;
386
390
}
387
391
388
- func::FuncOp helperFuncOp ;
392
+ bufferization::DeallocHelperMap deallocHelperFuncMap ;
389
393
if (auto module = dyn_cast<ModuleOp>(getOperation ())) {
390
394
OpBuilder builder =
391
395
OpBuilder::atBlockBegin (&module.getBodyRegion ().front ());
392
- SymbolTable symbolTable (module);
393
396
394
397
// Build dealloc helper function if there are deallocs.
395
398
getOperation ()->walk ([&](bufferization::DeallocOp deallocOp) {
396
- if (deallocOp.getMemrefs ().size () > 1 ) {
397
- helperFuncOp = bufferization::buildDeallocationLibraryFunction (
398
- builder, getOperation ()->getLoc (), symbolTable);
399
- return WalkResult::interrupt ();
399
+ Operation *symtableOp =
400
+ deallocOp->getParentWithTrait <OpTrait::SymbolTable>();
401
+ if (deallocOp.getMemrefs ().size () > 1 &&
402
+ !deallocHelperFuncMap.contains (symtableOp)) {
403
+ SymbolTable symbolTable (symtableOp);
404
+ func::FuncOp helperFuncOp =
405
+ bufferization::buildDeallocationLibraryFunction (
406
+ builder, getOperation ()->getLoc (), symbolTable);
407
+ deallocHelperFuncMap[symtableOp] = helperFuncOp;
400
408
}
401
- return WalkResult::advance ();
402
409
});
403
410
}
404
411
405
412
RewritePatternSet patterns (&getContext ());
406
- bufferization::populateBufferizationDeallocLoweringPattern (patterns,
407
- helperFuncOp );
413
+ bufferization::populateBufferizationDeallocLoweringPattern (
414
+ patterns, deallocHelperFuncMap );
408
415
409
416
ConversionTarget target (getContext ());
410
417
target.addLegalDialect <memref::MemRefDialect, arith::ArithDialect,
@@ -535,8 +542,10 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction(
535
542
}
536
543
537
544
void mlir::bufferization::populateBufferizationDeallocLoweringPattern (
538
- RewritePatternSet &patterns, func::FuncOp deallocLibraryFunc) {
539
- patterns.add <DeallocOpConversion>(patterns.getContext (), deallocLibraryFunc);
545
+ RewritePatternSet &patterns,
546
+ const bufferization::DeallocHelperMap &deallocHelperFuncMap) {
547
+ patterns.add <DeallocOpConversion>(patterns.getContext (),
548
+ deallocHelperFuncMap);
540
549
}
541
550
542
551
std::unique_ptr<Pass> mlir::bufferization::createLowerDeallocationsPass () {
0 commit comments