@@ -120,6 +120,91 @@ class DeallocOpConversion
120120 return success ();
121121 }
122122
123+ // / A special case lowering for the deallocation operation with exactly one
124+ // / memref, but arbitrary number of retained values. This avoids the helper
125+ // / function that the general case needs and thus also avoids storing indices
126+ // / to specifically allocated memrefs. The size of the code produced by this
127+ // / lowering is linear to the number of retained values.
128+ // /
129+ // / Example:
130+ // / ```mlir
131+ // / %0:2 = bufferization.dealloc (%m : memref<2xf32>) if (%cond)
132+ // retain (%r0, %r1 : memref<1xf32>, memref<2xf32>)
133+ // / return %0#0, %0#1 : i1, i1
134+ // / ```
135+ // / ```mlir
136+ // / %m_base_pointer = memref.extract_aligned_pointer_as_index %m
137+ // / %r0_base_pointer = memref.extract_aligned_pointer_as_index %r0
138+ // / %r0_does_not_alias = arith.cmpi ne, %m_base_pointer, %r0_base_pointer
139+ // / %r1_base_pointer = memref.extract_aligned_pointer_as_index %r1
140+ // / %r1_does_not_alias = arith.cmpi ne, %m_base_pointer, %r1_base_pointer
141+ // / %not_retained = arith.andi %r0_does_not_alias, %r1_does_not_alias : i1
142+ // / %should_dealloc = arith.andi %not_retained, %cond : i1
143+ // / scf.if %should_dealloc {
144+ // / memref.dealloc %m : memref<2xf32>
145+ // / }
146+ // / %true = arith.constant true
147+ // / %r0_does_alias = arith.xori %r0_does_not_alias, %true : i1
148+ // / %r0_ownership = arith.andi %r0_does_alias, %cond : i1
149+ // / %r1_does_alias = arith.xori %r1_does_not_alias, %true : i1
150+ // / %r1_ownership = arith.andi %r1_does_alias, %cond : i1
151+ // / return %r0_ownership, %r1_ownership : i1, i1
152+ // / ```
153+ LogicalResult rewriteOneMemrefMultipleRetainCase (
154+ bufferization::DeallocOp op, OpAdaptor adaptor,
155+ ConversionPatternRewriter &rewriter) const {
156+ assert (adaptor.getMemrefs ().size () == 1 && " expected only one memref" );
157+
158+ // Compute the base pointer indices, compare all retained indices to the
159+ // memref index to check if they alias.
160+ SmallVector<Value> doesNotAliasList;
161+ Value memrefAsIdx = rewriter.create <memref::ExtractAlignedPointerAsIndexOp>(
162+ op->getLoc (), adaptor.getMemrefs ()[0 ]);
163+ for (Value retained : adaptor.getRetained ()) {
164+ Value retainedAsIdx =
165+ rewriter.create <memref::ExtractAlignedPointerAsIndexOp>(op->getLoc (),
166+ retained);
167+ Value doesNotAlias = rewriter.create <arith::CmpIOp>(
168+ op->getLoc (), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx);
169+ doesNotAliasList.push_back (doesNotAlias);
170+ }
171+
172+ // AND-reduce the list of booleans from above.
173+ Value prev = doesNotAliasList.front ();
174+ for (Value doesNotAlias : ArrayRef (doesNotAliasList).drop_front ())
175+ prev = rewriter.create <arith::AndIOp>(op->getLoc (), prev, doesNotAlias);
176+
177+ // Also consider the condition given by the dealloc operation and perform a
178+ // conditional deallocation guarded by that value.
179+ Value shouldDealloc = rewriter.create <arith::AndIOp>(
180+ op->getLoc (), prev, adaptor.getConditions ()[0 ]);
181+
182+ rewriter.create <scf::IfOp>(
183+ op.getLoc (), shouldDealloc, [&](OpBuilder &builder, Location loc) {
184+ builder.create <memref::DeallocOp>(loc, adaptor.getMemrefs ()[0 ]);
185+ builder.create <scf::YieldOp>(loc);
186+ });
187+
188+ // Compute the replacement values for the dealloc operation results. This
189+ // inserts an already canonicalized form of
190+ // `select(does_alias_with_memref(r), memref_cond, false)` for each retained
191+ // value r.
192+ SmallVector<Value> replacements;
193+ Value trueVal = rewriter.create <arith::ConstantOp>(
194+ op->getLoc (), rewriter.getBoolAttr (true ));
195+ for (Value doesNotAlias : doesNotAliasList) {
196+ Value aliases =
197+ rewriter.create <arith::XOrIOp>(op->getLoc (), doesNotAlias, trueVal);
198+ Value result = rewriter.create <arith::AndIOp>(op->getLoc (), aliases,
199+ adaptor.getConditions ()[0 ]);
200+ replacements.push_back (result);
201+ }
202+
203+ rewriter.replaceOp (op, replacements);
204+
205+ return success ();
206+ }
207+
123208 // / Lowering that supports all features the dealloc operation has to offer. It
124209 // / computes the base pointer of each memref (as an index), stores it in a
125210 // / new memref helper structure and passes it to the helper function generated
@@ -310,12 +395,20 @@ class DeallocOpConversion
310395 matchAndRewrite (bufferization::DeallocOp op, OpAdaptor adaptor,
311396 ConversionPatternRewriter &rewriter) const override {
312397 // Lower the trivial case.
313- if (adaptor.getMemrefs ().empty ())
314- return rewriter.eraseOp (op), success ();
398+ if (adaptor.getMemrefs ().empty ()) {
399+ Value falseVal = rewriter.create <arith::ConstantOp>(
400+ op.getLoc (), rewriter.getBoolAttr (false ));
401+ rewriter.replaceOp (
402+ op, SmallVector<Value>(adaptor.getRetained ().size (), falseVal));
403+ return success ();
404+ }
315405
316406 if (adaptor.getMemrefs ().size () == 1 && adaptor.getRetained ().empty ())
317407 return rewriteOneMemrefNoRetainCase (op, adaptor, rewriter);
318408
409+ if (adaptor.getMemrefs ().size () == 1 )
410+ return rewriteOneMemrefMultipleRetainCase (op, adaptor, rewriter);
411+
319412 return rewriteGeneralCase (op, adaptor, rewriter);
320413 }
321414
@@ -535,8 +628,7 @@ struct BufferizationToMemRefPass
535628 // Build dealloc helper function if there are deallocs.
536629 func::FuncOp helperFuncOp;
537630 getOperation ()->walk ([&](bufferization::DeallocOp deallocOp) {
538- if (deallocOp.getMemrefs ().size () > 1 ||
539- !deallocOp.getRetained ().empty ()) {
631+ if (deallocOp.getMemrefs ().size () > 1 ) {
540632 helperFuncOp = DeallocOpConversion::buildDeallocationHelperFunction (
541633 builder, getOperation ()->getLoc (), symbolTable);
542634 return WalkResult::interrupt ();
0 commit comments