Skip to content

Commit 61d5fdf

Browse files
authored
[MLIR] Add bufferization state class to OneShotBufferization pass (#141019)
Follow-up on #138143, which was reverted due to a missing update a method signature (more specifically, the bufferization interface for `tensor::ConcatOp`) that was not catched before merging. The old PR description is reported in the next lines. This PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The information currently consists of a cached list of symbol tables, which aims to solve the quadratic scaling of the bufferization task with respect to the number of symbols. The PR breaks API compatibility: the bufferize method of the BufferizableOpInterface has been enriched with a reference to a BufferizationState object. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Symbol trait is inserted or replaced, its parent SymbolTable must be updated accordingly (see, for example, the bufferization of arith::ConstantOp, where the symbol table of the module gets the new global symbol inserted). Similarly, the invalidation of a symbol table must be performed if an operation with the SymbolTable trait is removed (this can be performed using the invalidateSymbolTable method, introduced in #138014).
1 parent 3d02834 commit 61d5fdf

27 files changed

+215
-87
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,20 @@ class AnalysisState {
578578
insideMutuallyExclusiveRegionsCache;
579579
};
580580

581+
/// BufferizationState provides information about the state of the IR during the
582+
/// bufferization process.
583+
class BufferizationState {
584+
public:
585+
/// Get a reference to the collection of cached symbol tables.
586+
SymbolTableCollection &getSymbolTables();
587+
588+
private:
589+
/// The cached symbol tables.
590+
/// The user is expected to update / invalidate the cached symbol tables if
591+
/// the bufferized operation has the Symbol or SymbolTable traits.
592+
SymbolTableCollection symbolTables;
593+
};
594+
581595
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
582596
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
583597
/// undefined contents is allocated.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options,
430+
"::mlir::bufferization::BufferizationState &":$state),
430431
/*methodBody=*/"",
431432
/*defaultImplementation=*/[{
432433
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options);
96+
const BufferizationOptions &options,
97+
BufferizationState &state);
9798

9899
bool resultBufferizesToMemoryWrite(OpResult opResult,
99100
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
282283

283284
let extraClassDeclaration = [{
284285
LogicalResult bufferize(RewriterBase &rewriter,
285-
const BufferizationOptions &options);
286+
const BufferizationOptions &options,
287+
BufferizationState &state);
286288

287289
bool bufferizesToMemoryRead(OpOperand &opOperand,
288290
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
375377
}
376378

377379
LogicalResult bufferize(RewriterBase &rewriter,
378-
const BufferizationOptions &options);
380+
const BufferizationOptions &options,
381+
BufferizationState &state);
379382
}];
380383
}
381384

@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
458461
//===------------------------------------------------------------------===//
459462

460463
LogicalResult bufferize(RewriterBase &rewriter,
461-
const BufferizationOptions &options) const {
464+
const BufferizationOptions &options,
465+
BufferizationState &state) const {
462466
// to_tensor/to_buffer pairs fold away after bufferization.
463467
return success();
464468
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
550554
}
551555

552556
LogicalResult bufferize(RewriterBase &rewriter,
553-
const BufferizationOptions &options);
557+
const BufferizationOptions &options,
558+
BufferizationState &state);
554559
}];
555560

556561
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32+
class BufferizationState;
3233

3334
/// A simple analysis that detects allocation operations.
3435
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
122123
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
123124
// names. Duplicates are avoided.
124125
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126+
SymbolTableCollection &symbolTables,
125127
uint64_t alignment,
126128
Attribute memorySpace = {});
127129

130+
void removeSymbol(Operation *op, BufferizationState &state);
131+
132+
void insertSymbol(Operation *op, BufferizationState &state);
133+
128134
} // namespace bufferization
129135
} // namespace mlir
130136

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48+
BufferizationState &bufferizationState,
4849
BufferizationStatistics *statistics = nullptr);
4950

5051
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273+
BufferizationState &state,
273274
BufferizationStatistics *statistics = nullptr);
274275

275276
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23+
class BufferizationState;
2324

2425
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2526
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3839
/// will be inserted only to these FuncOps.
3940
llvm::LogicalResult
4041
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
4143
BufferizationStatistics *statistics = nullptr);
4244

4345
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5052
llvm::LogicalResult runOneShotModuleBufferize(
5153
ModuleOp moduleOp,
5254
const bufferization::OneShotBufferizationOptions &options,
53-
BufferizationStatistics *statistics = nullptr);
55+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5456

5557
} // namespace bufferization
5658
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33+
class BufferizationState;
3334
} // namespace bufferization
3435

3536
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options) const {
27+
const BufferizationOptions &options,
28+
BufferizationState &state) const {
2829
auto constantOp = cast<arith::ConstantOp>(op);
2930
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3031

@@ -46,7 +47,8 @@ struct ConstantOpInterface
4647
// Create global memory segment and replace tensor with memref pointing to
4748
// that memory segment.
4849
FailureOr<memref::GlobalOp> globalOp =
49-
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5052
if (failed(globalOp))
5153
return failure();
5254
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
8385
}
8486

8587
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86-
const BufferizationOptions &options) const {
88+
const BufferizationOptions &options,
89+
BufferizationState &state) const {
8790
auto castOp = cast<arith::IndexCastOp>(op);
8891
auto resultTensorType = cast<TensorType>(castOp.getType());
8992

@@ -131,7 +134,8 @@ struct SelectOpInterface
131134
}
132135

133136
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
134-
const BufferizationOptions &options) const {
137+
const BufferizationOptions &options,
138+
BufferizationState &state) const {
135139
auto selectOp = cast<arith::SelectOp>(op);
136140
Location loc = selectOp.getLoc();
137141

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128+
SymbolTableCollection &BufferizationState::getSymbolTables() {
129+
return symbolTables;
130+
}
131+
128132
Region *bufferization::getNextEnclosingRepetitiveRegion(
129133
Region *region, const BufferizationOptions &options) {
130134
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ void mlir::bufferization::populateDynamicDimSizes(
149149
//===----------------------------------------------------------------------===//
150150

151151
LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
152-
const BufferizationOptions &options) {
152+
const BufferizationOptions &options,
153+
BufferizationState &state) {
153154
OpBuilder::InsertionGuard g(rewriter);
154155
Location loc = getLoc();
155156

@@ -529,7 +530,8 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
529530
//===----------------------------------------------------------------------===//
530531

531532
LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
532-
const BufferizationOptions &options) {
533+
const BufferizationOptions &options,
534+
BufferizationState &state) {
533535
FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options);
534536
if (failed(buffer))
535537
return failure();
@@ -576,7 +578,8 @@ MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
576578

577579
LogicalResult
578580
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
579-
const BufferizationOptions &options) {
581+
const BufferizationOptions &options,
582+
BufferizationState &state) {
580583
bool tensorDest = isa<TensorType>(getDest().getType());
581584
Value buffer;
582585
if (tensorDest) {
@@ -861,7 +864,8 @@ void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
861864
}
862865

863866
LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
864-
const BufferizationOptions &options) {
867+
const BufferizationOptions &options,
868+
BufferizationState &state) {
865869
// Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
866870
(void)foldToBufferToTensorPair(rewriter, *this, options);
867871
// Note: The return value of `bufferize` indicates whether there was an error

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,21 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
8383
}
8484

8585
auto payloadOps = state.getPayloadOps(getTarget());
86+
BufferizationState bufferizationState;
87+
8688
for (Operation *target : payloadOps) {
8789
if (!isa<ModuleOp, FunctionOpInterface>(target))
8890
return emitSilenceableError() << "expected module or function target";
8991
auto moduleOp = dyn_cast<ModuleOp>(target);
9092
if (options.bufferizeFunctionBoundaries) {
9193
if (!moduleOp)
9294
return emitSilenceableError() << "expected module target";
93-
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
95+
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96+
bufferizationState)))
9497
return emitSilenceableError() << "bufferization failed";
9598
} else {
96-
if (failed(bufferization::runOneShotBufferize(target, options)))
99+
if (failed(bufferization::runOneShotBufferize(target, options,
100+
bufferizationState)))
97101
return emitSilenceableError() << "bufferization failed";
98102
}
99103
}
@@ -162,6 +166,7 @@ class BufferizationTransformDialectExtension
162166
registerTransformOps<
163167
#define GET_OP_LIST
164168
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
169+
165170
>();
166171
}
167172
};

mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,9 @@ BufferPlacementTransformationBase::BufferPlacementTransformationBase(
103103
//===----------------------------------------------------------------------===//
104104

105105
FailureOr<memref::GlobalOp>
106-
bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
107-
Attribute memorySpace) {
106+
bufferization::getGlobalFor(arith::ConstantOp constantOp,
107+
SymbolTableCollection &symbolTables,
108+
uint64_t alignment, Attribute memorySpace) {
108109
auto type = cast<RankedTensorType>(constantOp.getType());
109110
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
110111
if (!moduleOp)
@@ -127,7 +128,7 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
127128
// Create a builder without an insertion point. We will insert using the
128129
// symbol table to guarantee unique names.
129130
OpBuilder globalBuilder(moduleOp.getContext());
130-
SymbolTable symbolTable(moduleOp);
131+
SymbolTable &symbolTable = symbolTables.getSymbolTable(moduleOp);
131132

132133
// Create a pretty name.
133134
SmallString<64> buf;
@@ -158,3 +159,19 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment,
158159
global->moveBefore(&moduleOp.front());
159160
return global;
160161
}
162+
163+
namespace mlir::bufferization {
164+
void removeSymbol(Operation *op, BufferizationState &state) {
165+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
166+
op->getParentWithTrait<OpTrait::SymbolTable>());
167+
168+
symbolTable.remove(op);
169+
}
170+
171+
void insertSymbol(Operation *op, BufferizationState &state) {
172+
SymbolTable &symbolTable = state.getSymbolTables().getSymbolTable(
173+
op->getParentWithTrait<OpTrait::SymbolTable>());
174+
175+
symbolTable.insert(op);
176+
}
177+
} // namespace mlir::bufferization

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,12 @@ struct OneShotBufferizePass
161161
return signalPassFailure();
162162
}
163163

164+
BufferizationState state;
164165
BufferizationStatistics statistics;
165166
ModuleOp moduleOp = getOperation();
166167
if (opt.bufferizeFunctionBoundaries) {
167-
if (failed(runOneShotModuleBufferize(moduleOp, opt, &statistics))) {
168+
if (failed(
169+
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
168170
signalPassFailure();
169171
return;
170172
}
@@ -175,7 +177,7 @@ struct OneShotBufferizePass
175177
"'bufferize-function-boundaries'");
176178
return signalPassFailure();
177179
}
178-
if (failed(runOneShotBufferize(moduleOp, opt, &statistics))) {
180+
if (failed(runOneShotBufferize(moduleOp, opt, state, &statistics))) {
179181
signalPassFailure();
180182
return;
181183
}
@@ -275,6 +277,7 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
275277

276278
LogicalResult bufferization::bufferizeOp(Operation *op,
277279
const BufferizationOptions &options,
280+
BufferizationState &bufferizationState,
278281
BufferizationStatistics *statistics) {
279282
if (options.copyBeforeWrite) {
280283
AnalysisState state(options);
@@ -331,7 +334,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
331334
<< "//===-------------------------------------------===//\n"
332335
<< "IR after bufferizing: " << nextOp->getName() << "\n");
333336
rewriter.setInsertionPoint(nextOp);
334-
if (failed(bufferizableOp.bufferize(rewriter, options))) {
337+
if (failed(
338+
bufferizableOp.bufferize(rewriter, options, bufferizationState))) {
335339
LLVM_DEBUG(llvm::dbgs()
336340
<< "failed to bufferize\n"
337341
<< "//===-------------------------------------------===//\n");

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ struct CallOpInterface
239239
/// All function arguments are writable. It is the responsibility of the
240240
/// CallOp to insert buffer copies where necessary.
241241
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
242-
const BufferizationOptions &options) const {
242+
const BufferizationOptions &options,
243+
BufferizationState &state) const {
243244
func::CallOp callOp = cast<func::CallOp>(op);
244245

245246
// 1. Compute the result types of the new CallOp.
@@ -349,7 +350,8 @@ struct ReturnOpInterface
349350
}
350351

351352
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
352-
const BufferizationOptions &options) const {
353+
const BufferizationOptions &options,
354+
BufferizationState &state) const {
353355
#ifndef NDEBUG
354356
auto returnOp = cast<func::ReturnOp>(op);
355357
assert(isa<FuncOp>(returnOp->getParentOp()) &&
@@ -418,7 +420,8 @@ struct FuncOpInterface
418420
/// All function bbArgs are writable unless they are explicitly marked as
419421
/// read-only. Callers must insert copies when needed.
420422
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
421-
const BufferizationOptions &options) const {
423+
const BufferizationOptions &options,
424+
BufferizationState &state) const {
422425
auto funcOp = cast<FuncOp>(op);
423426
FunctionType funcType = funcOp.getFunctionType();
424427

0 commit comments

Comments
 (0)