Commit 75d6529

[mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings
Differential Revision: https://reviews.llvm.org/D116451
Parent: 635f8f3

12 files changed: +295 -225 lines

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

Lines changed: 12 additions & 23 deletions
```diff
@@ -64,14 +64,14 @@ struct AllocationCallbacks {
 std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();
 
 /// PostAnalysisSteps can be registered with `BufferizationOptions` and are
-/// executed after the analysis, but before bufferization. They can be used
+/// executed after the analysis, but before bufferization. They can be used to
 /// implement custom dialect-specific optimizations.
 struct PostAnalysisStep {
   virtual ~PostAnalysisStep() {}
 
   /// Run the post analysis step. This function may modify the IR, but must keep
-  /// `aliasInfo` (inside `state`) consistent. Newly created operations and
-  /// operations that should be re-analyzed must be stored in `newOps`.
+  /// `aliasInfo` consistent. Newly created operations and operations that
+  /// should be re-analyzed must be added to `newOps`.
   virtual LogicalResult run(Operation *op, BufferizationState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) = 0;
```
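The `run` contract above is compact enough to sketch. The subclass below is a hypothetical illustration (the name and body are not part of this commit); it shows the obligations the comment describes: keep `aliasInfo` consistent and record ops that need re-analysis in `newOps`.

```cpp
// Hypothetical example: a dialect-specific optimization as a PostAnalysisStep.
struct MyDialectOptimization : public PostAnalysisStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    // Rewrite the IR here. Every op created by the rewrite must be added to
    // `newOps` so that it is re-analyzed, and `aliasInfo` must be updated to
    // reflect any aliasing the rewrite introduces.
    return success();
  }
};
```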
```diff
@@ -102,7 +102,8 @@ struct BufferizationOptions {
   }
 
   /// Allow-list the given dialects in the dialect filter. Only ops from
-  /// allow-listed dialects will be bufferized.
+  /// allow-listed dialects will be bufferized. If no dialect is added, ops from
+  /// any dialect will be bufferized.
   template <typename... DialectTs>
   void addToDialectFilter() {
     // The following expands a call to addToDialectFilterImpl for each dialect
```
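A minimal usage sketch of the filter (the dialect choices are arbitrary examples, not taken from this commit):

```cpp
// Bufferize only linalg and tensor ops; ops from other dialects are skipped.
BufferizationOptions options;
options.addToDialectFilter<linalg::LinalgDialect, tensor::TensorDialect>();
```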
```diff
@@ -288,17 +289,7 @@ struct DialectBufferizationState {
 };
 
 /// BufferizationState provides a variety of helper functions for dealing with
-/// tensor values and memref buffers. In particular,
-/// `BufferizableOpInterface::bufferize` implementation should utilize the
-/// following helper functions.
-///
-/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
-///   that allocate and/or deallocate memref buffers.
-/// * `lookupBuffer` returns the memref buffer of a given tensor value.
-/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
-///   Based on inplace bufferization decisions of the analysis, it may either
-///   directly return a mapped buffer or allocate a new brand new buffer.
-/// * `replaceOp` replaces an op with new values.
+/// tensor values and memref buffers.
 class BufferizationState {
 public:
   BufferizationState(Operation *op, const BufferizationOptions &options);
@@ -396,7 +387,8 @@ class BufferizationState {
   /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
   /// a new buffer and copy over data from the existing buffer if out-of-place
   /// bufferization is necessary.
-  Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
+  FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
+                                   OpResult result) const;
 
   /// Return dialect-specific bufferization state.
   template <typename StateT>
@@ -455,12 +447,9 @@ MemRefType getContiguousMemRefType(ShapedType shapedType,
                                    MemRefLayoutAttrInterface layout = {},
                                    Attribute memorySpace = {});
 
-/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
-/// with the same shape as `shapedType` and specified `layout` and
-/// `addressSpace` or an UnrankedMemRefType otherwise.
-Type getContiguousOrUnrankedMemRefType(Type type,
-                                       MemRefLayoutAttrInterface layout = {},
-                                       Attribute memorySpace = {});
+/// Return an UnrankedMemRefType with the given element type and memory space.
+UnrankedMemRefType getUnrankedMemRefType(Type elementType,
+                                         Attribute memorySpace = {});
 
 /// Return a MemRefType to which the `tensorType` can be bufferized in a
 /// composable fashion. The layout must be the most dynamic possible and
@@ -493,7 +482,7 @@ struct AllocationHoistingBarrierOnly
 
   bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
                                const BufferizationState &state) const {
-    return false;
+    return true;
   }
 
   SmallVector<OpOperand *>
```
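Since `getResultBuffer` now returns `FailureOr<Value>`, call sites switch from checking for a null `Value` to checking `failed()`. A minimal sketch, assuming `state`, `rewriter`, and `result` are in scope inside some `bufferize` implementation:

```cpp
// Propagate failure instead of testing for a null Value.
FailureOr<Value> maybeBuffer = state.getResultBuffer(rewriter, result);
if (failed(maybeBuffer))
  return failure();
Value resultBuffer = *maybeBuffer;
```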

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h

Lines changed: 11 additions & 6 deletions
```diff
@@ -23,6 +23,12 @@ class BufferizationAliasInfo;
 namespace linalg_ext {
 
 struct InitTensorEliminationStep : public PostAnalysisStep {
+  /// A function that matches anchor OpOperands for InitTensorOp elimination.
+  using AnchorMatchFn = std::function<bool(OpOperand &)>;
+
+  /// A function that rewrites matched anchors.
+  using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
+
   /// Try to eliminate InitTensorOps inside `op`.
   ///
   /// * `rewriteFunc` generates the replacement for the InitTensorOp.
@@ -33,12 +39,11 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
   ///   InitTensorOp.
   /// * The result of `rewriteFunc` must usually be analyzed for inplacability.
   ///   This analysis can be skipped with `skipAnalysis`.
-  LogicalResult eliminateInitTensors(
-      Operation *op, BufferizationState &state,
-      BufferizationAliasInfo &aliasInfo,
-      std::function<bool(OpOperand &)> anchorMatchFunc,
-      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-      SmallVector<Operation *> &newOps);
+  LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
+                                     BufferizationAliasInfo &aliasInfo,
+                                     AnchorMatchFn anchorMatchFunc,
+                                     RewriteFn rewriteFunc,
+                                     SmallVector<Operation *> &newOps);
 };
 
 /// Try to eliminate InitTensorOps inside `op` that are anchored on an
```
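A hedged sketch of how a concrete step might supply the two callbacks; the subclass name and the trivial lambda bodies are illustrative, not from this commit:

```cpp
// Hypothetical subclass showing the AnchorMatchFn/RewriteFn shapes.
struct MyInitTensorElimination : public InitTensorEliminationStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    return eliminateInitTensors(
        op, state, aliasInfo,
        /*anchorMatchFunc=*/
        [](OpOperand &operand) {
          // Return true if `operand` anchors an InitTensorOp elimination.
          return false;
        },
        /*rewriteFunc=*/
        [](OpBuilder &b, Location loc, OpOperand &operand) {
          // Build and return the replacement value for the InitTensorOp.
          return Value();
        },
        newOps);
  }
};
```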

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/AffineInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,8 @@
 
 void mlir::linalg::comprehensive_bufferize::affine_ext::
     registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
+  // AffineParallelOp bufferization not implemented yet. However, never hoist
+  // memref allocations across AffineParallelOp boundaries.
   registry.addOpInterface<AffineParallelOp,
                           AllocationHoistingBarrierOnly<AffineParallelOp>>();
 }
```
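For context, a client makes these external models available by calling the registration hook while assembling its `DialectRegistry`. The wrapper function below is a hypothetical sketch; only the inner call is taken from the code above:

```cpp
void registerMyExternalModels(DialectRegistry &registry) {
  mlir::linalg::comprehensive_bufferize::affine_ext::
      registerBufferizableOpInterfaceExternalModels(registry);
}
```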

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp

Lines changed: 11 additions & 4 deletions
```diff
@@ -20,23 +20,30 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace arith_ext {
 
+/// Bufferization of arith.constant. Replace with memref.get_global.
 struct ConstantOpInterface
     : public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
                                                     arith::ConstantOp> {
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationState &state) const {
     auto constantOp = cast<arith::ConstantOp>(op);
-    assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
-           "not a constant ranked tensor");
+
+    // Only ranked tensors are supported.
+    if (!constantOp.getType().isa<RankedTensorType>())
+      return failure();
+
+    // Only constants inside a module are supported.
     auto moduleOp = constantOp->getParentOfType<ModuleOp>();
     if (!moduleOp)
-      return constantOp.emitError(
-          "cannot bufferize constants not within builtin.module op");
+      return failure();
 
+    // Create global memory segment and replace tensor with memref pointing to
+    // that memory segment.
     GlobalCreator globalCreator(moduleOp);
     auto globalMemref = globalCreator.getGlobalFor(constantOp);
     replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
         rewriter, op, globalMemref.type(), globalMemref.getName());
+
     return success();
   }
```
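The model turns a ranked-tensor `arith.constant` into a `memref.get_global` of a module-level global created by `GlobalCreator`. How the model is attached to `arith::ConstantOp` is outside this excerpt; presumably it mirrors the affine registration shown earlier, along these lines:

```cpp
// Hypothetical sketch; the actual arith registration is not part of this diff.
registry.addOpInterface<arith::ConstantOp, ConstantOpInterface>();
```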

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Lines changed: 30 additions & 32 deletions
```diff
@@ -74,6 +74,21 @@ mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
 BufferizationOptions::BufferizationOptions()
     : allocationFns(defaultAllocationCallbacks()) {}
 
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
+  if (isOpAllowed(op))
+    return dyn_cast<BufferizableOpInterface>(op);
+  return nullptr;
+}
+
+BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
+    BufferizationOptions::dynCastBufferizableOp(Value value) const {
+  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
+    if (isOpAllowed(bufferizableOp.getOperation()))
+      return bufferizableOp;
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // BufferizationAliasInfo
 //===----------------------------------------------------------------------===//
@@ -180,21 +195,6 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
   }
 }
 
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
-    BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
-  if (isOpAllowed(op))
-    return dyn_cast<BufferizableOpInterface>(op);
-  return nullptr;
-}
-
-BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
-    BufferizationOptions::dynCastBufferizableOp(Value value) const {
-  if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
-    if (isOpAllowed(bufferizableOp.getOperation()))
-      return bufferizableOp;
-  return nullptr;
-}
-
 /// Determine which OpOperand* will alias with `result` if the op is bufferized
 /// in place. Return an empty vector if the op is not bufferizable.
 SmallVector<OpOperand *>
@@ -358,8 +358,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
 /// Return the result buffer (memref) for a given OpResult (tensor). Allocate
 /// a new buffer and copy over data from the existing buffer if out-of-place
 /// bufferization is necessary.
-Value mlir::linalg::comprehensive_bufferize::BufferizationState::
-    getResultBuffer(RewriterBase &rewriter, OpResult result) const {
+FailureOr<Value>
+mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
+    RewriterBase &rewriter, OpResult result) const {
   OpBuilder::InsertionGuard guard(rewriter);
   Operation *op = result.getOwner();
   SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
@@ -375,10 +376,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
   if (aliasingOperands.size() > 1 &&
       !llvm::all_of(aliasingOperands, [&](OpOperand *o) {
        return lookupBuffer(rewriter, o->get()) == operandBuffer;
-      })) {
-    op->emitError("result buffer is ambiguous");
-    return Value();
-  }
+      }))
+    return FailureOr<Value>(op->emitError("result buffer is ambiguous"));
 
   // If bufferizing out-of-place, allocate a new buffer.
   if (!aliasInfo.isInPlace(result)) {
@@ -610,10 +609,13 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
   // Insert to_memref op.
   OpBuilder::InsertionGuard g(rewriter);
   setInsertionPointAfter(rewriter, tensor);
-  Type memrefType =
-      tensor.getType().isa<RankedTensorType>()
-          ? getDynamicMemRefType(tensor.getType().cast<RankedTensorType>())
-          : getContiguousOrUnrankedMemRefType(tensor.getType());
+  Type memrefType;
+  if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
+    memrefType = getDynamicMemRefType(rankedTensorType);
+  } else {
+    memrefType = getUnrankedMemRefType(
+        tensor.getType().cast<TensorType>().getElementType());
+  }
   return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
                                                     tensor);
 }
@@ -630,13 +632,9 @@ MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
                              layout, memorySpace);
 }
 
-Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
-    Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
-  if (type.isa<RankedTensorType, MemRefType>())
-    return getContiguousMemRefType(type.cast<ShapedType>(), layout,
-                                   memorySpace);
-  assert(!layout && "expected empty layout with UnrankedMemRefType");
-  return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
+UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
+    Type elementType, Attribute memorySpace) {
+  return UnrankedMemRefType::get(elementType, memorySpace);
 }
 
 MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
```
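The relocated `dynCastBufferizableOp` helpers are the intended way to query the interface so that the dialect filter in `BufferizationOptions` is respected. A short sketch, assuming `options` and `op` are in scope:

```cpp
// Yields a null interface if `op` is filtered out or not bufferizable.
if (auto bufferizableOp = options.dynCastBufferizableOp(op)) {
  // `op` implements BufferizableOpInterface and passes the dialect filter.
}
```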

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp

Lines changed: 21 additions & 12 deletions
```diff
@@ -25,6 +25,9 @@ namespace bufferization_ext {
 // TODO: These ops should implement BufferizableOpInterface directly when moved
 // to the Bufferization dialect.
 
+/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
+/// to x. Other to_memref ops are ignored during bufferization.
+///
 /// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
 /// location of the incoming tensor once it will be bufferized. In the anlysis,
 /// the incoming tensor is assumed to bufferize to a memory read and to an
@@ -41,7 +44,7 @@ struct ToMemrefOpInterface
                                                 bufferization::ToMemrefOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const BufferizationState &state) const {
-    // It is unknown whether the resulting MemRef will be read or not.
+    // It is unknown whether the resulting memref will be read or not.
     return true;
   }
 
@@ -58,27 +61,33 @@ struct ToMemrefOpInterface
     if (auto toTensorOp =
             toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
       Value buffer = toTensorOp.memref();
+
+      // Insert cast in case to_memref(to_tensor(x))'s type is different from
+      // x's type.
       if (toTensorOp.memref().getType() != toMemrefOp.getType())
         buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
                                                  toMemrefOp.getType());
-      rewriter.replaceOp(toMemrefOp, buffer);
+      replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
       return success();
     }
 
     return failure();
   }
 };
 
-/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
-/// not lower any further, and they should have disappeared by the time the
-/// input is fully bufferized.
+/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
+/// However, other ops that are using to_tensor's result will eventually be
+/// bufferized. At that point, they will start using to_tensor's memref operand.
+/// Once all users of to_tensor are bufferized, the op will not have any users
+/// anymore and DCE away.
 ///
-/// The analysis has no information about the memref that is loaded from by the
-/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
-/// potentially alias with any other bufferized tensor. Since ToTensorOp and
-/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
-/// directly in the analysis. However, declaring ToTensorOp results as not
-/// writable also enforces a buffer copy and has the same effect.
+/// ToTensorOp conceptually loads a tensor from a memory location. The analysis
+/// has no information about the memref that is loaded from by ToTensorOp. We
+/// have to assume that the loaded tensor may after bufferization potentially
+/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have
+/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the
+/// analysis. However, declaring ToTensorOp results as not writable enforces a
+/// buffer copy and has the same effect.
 struct ToTensorOpInterface
     : public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
                                                     bufferization::ToTensorOp> {
@@ -89,7 +98,7 @@ struct ToTensorOpInterface
 
   bool isWritable(Operation *op, Value value,
                   const BufferizationState &state) const {
-    // It is unknown whether the MemRef operand is writable or not.
+    // It is unknown whether the memref operand is writable or not.
     return false;
   }
 };
```
