[mlir][linalg][bufferize][NFC] Clean up comments and minor code refactorings

Differential Revision: https://reviews.llvm.org/D116451
matthias-springer committed Jan 6, 2022
Commit 75d6529 (1 parent: 635f8f3)
Showing 12 changed files with 295 additions and 225 deletions.
@@ -64,14 +64,14 @@ struct AllocationCallbacks {
std::unique_ptr<AllocationCallbacks> defaultAllocationCallbacks();

/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
/// executed after the analysis, but before bufferization. They can be used
/// executed after the analysis, but before bufferization. They can be used to
/// implement custom dialect-specific optimizations.
struct PostAnalysisStep {
virtual ~PostAnalysisStep() {}

/// Run the post analysis step. This function may modify the IR, but must keep
/// `aliasInfo` (inside `state`) consistent. Newly created operations and
/// operations that should be re-analyzed must be stored in `newOps`.
/// `aliasInfo` consistent. Newly created operations and operations that
/// should be re-analyzed must be added to `newOps`.
virtual LogicalResult run(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
SmallVector<Operation *> &newOps) = 0;
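To make the contract above concrete, here is a minimal, hypothetical sketch of a custom post-analysis step. It is not part of this commit; only the `run` signature is taken from `PostAnalysisStep`, and the usual `mlir` / `comprehensive_bufferize` namespaces and includes are assumed to be in scope.

// Hypothetical example, not part of this commit.
struct MyDialectCleanupStep : public PostAnalysisStep {
  LogicalResult run(Operation *op, BufferizationState &state,
                    BufferizationAliasInfo &aliasInfo,
                    SmallVector<Operation *> &newOps) override {
    // This step may modify the IR, but it must keep `aliasInfo` consistent.
    op->walk([&](Operation *nestedOp) {
      // Apply a dialect-specific optimization here. Any newly created op
      // (or op that must be re-analyzed) has to be pushed to `newOps`.
    });
    return success();
  }
};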
@@ -102,7 +102,8 @@ struct BufferizationOptions {
}

/// Allow-list the given dialects in the dialect filter. Only ops from
/// allow-listed dialects will be bufferized.
/// allow-listed dialects will be bufferized. If no dialect is added, ops from
/// any dialect will be bufferized.
template <typename... DialectTs>
void addToDialectFilter() {
// The following expands a call to addToDialectFilterImpl for each dialect
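A hedged usage sketch of the dialect filter follows. It is not part of this commit, and the dialect list is purely illustrative; per the comment above, leaving the filter empty bufferizes ops from any dialect.

// Hypothetical example, not part of this commit.
void configureOptions(BufferizationOptions &options) {
  // Only ops from these allow-listed dialects will be bufferized.
  options.addToDialectFilter<linalg::LinalgDialect, tensor::TensorDialect,
                             arith::ArithmeticDialect>();
}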
@@ -288,17 +289,7 @@ struct DialectBufferizationState {
};

/// BufferizationState provides a variety of helper functions for dealing with
/// tensor values and memref buffers. In particular,
/// `BufferizableOpInterface::bufferize` implementation should utilize the
/// following helper functions.
///
/// * `createAlloc` / `createDealloc` / `createAllocDeallocPair` creates ops
/// that allocate and/or deallocate memref buffers.
/// * `lookupBuffer` returns the memref buffer of a given tensor value.
/// * `getResultBuffer` returns the memref buffer for a given tensor OpResult.
/// Based on inplace bufferization decisions of the analysis, it may either
/// directly return a mapped buffer or allocate a brand new buffer.
/// * `replaceOp` replaces an op with new values.
/// tensor values and memref buffers.
class BufferizationState {
public:
BufferizationState(Operation *op, const BufferizationOptions &options);
@@ -396,7 +387,8 @@ class BufferizationState {
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value getResultBuffer(RewriterBase &rewriter, OpResult result) const;
FailureOr<Value> getResultBuffer(RewriterBase &rewriter,
OpResult result) const;
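Since the return type changes from `Value` to `FailureOr<Value>`, callers now propagate failure instead of checking for a null `Value`. A hypothetical caller sketch (the op, function name, and result index are illustrative, not from this commit):

// Hypothetical example, not part of this commit.
LogicalResult bufferizeSomeOp(RewriterBase &rewriter, Operation *op,
                              const BufferizationState &state) {
  FailureOr<Value> resultBuffer =
      state.getResultBuffer(rewriter, op->getOpResult(0));
  if (failed(resultBuffer))
    return failure();
  // Use *resultBuffer to build the memref-based replacement for `op`.
  Value buffer = *resultBuffer;
  (void)buffer;
  return success();
}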

/// Return dialect-specific bufferization state.
template <typename StateT>
@@ -455,12 +447,9 @@ MemRefType getContiguousMemRefType(ShapedType shapedType,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});

/// Return a contiguous MemRefType (i.e. with canonical/empty layout map)
/// with the same shape as `shapedType` and specified `layout` and
/// `addressSpace` or an UnrankedMemRefType otherwise.
Type getContiguousOrUnrankedMemRefType(Type type,
MemRefLayoutAttrInterface layout = {},
Attribute memorySpace = {});
/// Return an UnrankedMemRefType with the given element type and memory space.
UnrankedMemRefType getUnrankedMemRefType(Type elementType,
Attribute memorySpace = {});

/// Return a MemRefType to which the `tensorType` can be bufferized in a
/// composable fashion. The layout must be the most dynamic possible and
@@ -493,7 +482,7 @@ struct AllocationHoistingBarrierOnly

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return false;
return true;
}

SmallVector<OpOperand *>
@@ -23,6 +23,12 @@ class BufferizationAliasInfo;
namespace linalg_ext {

struct InitTensorEliminationStep : public PostAnalysisStep {
/// A function that matches anchor OpOperands for InitTensorOp elimination.
using AnchorMatchFn = std::function<bool(OpOperand &)>;

/// A function that rewrites matched anchors.
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;

/// Try to eliminate InitTensorOps inside `op`.
///
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
@@ -33,12 +39,11 @@ struct InitTensorEliminationStep : public PostAnalysisStep {
/// InitTensorOp.
/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
/// This analysis can be skipped with `skipAnalysis`.
LogicalResult eliminateInitTensors(
Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
std::function<bool(OpOperand &)> anchorMatchFunc,
std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
SmallVector<Operation *> &newOps);
LogicalResult eliminateInitTensors(Operation *op, BufferizationState &state,
BufferizationAliasInfo &aliasInfo,
AnchorMatchFn anchorMatchFunc,
RewriteFn rewriteFunc,
SmallVector<Operation *> &newOps);
};
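With the new `AnchorMatchFn` / `RewriteFn` aliases, the call site becomes noticeably shorter. A hedged sketch of wiring up the callbacks follows; the anchor predicate and rewrite body are placeholders, not code from this commit.

// Hypothetical example, not part of this commit.
LogicalResult runElimination(InitTensorEliminationStep &step, Operation *op,
                             BufferizationState &state,
                             BufferizationAliasInfo &aliasInfo,
                             SmallVector<Operation *> &newOps) {
  // Illustrative anchor: OpOperands owned by tensor.insert_slice ops.
  InitTensorEliminationStep::AnchorMatchFn matchFn = [](OpOperand &operand) {
    return isa<tensor::InsertSliceOp>(operand.getOwner());
  };
  // Illustrative rewrite: return a placeholder replacement value.
  InitTensorEliminationStep::RewriteFn rewriteFn =
      [](OpBuilder &b, Location loc, OpOperand &operand) -> Value {
    return operand.get();
  };
  return step.eliminateInitTensors(op, state, aliasInfo, matchFn, rewriteFn,
                                   newOps);
}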

/// Try to eliminate InitTensorOps inside `op` that are anchored on an
@@ -13,6 +13,8 @@

void mlir::linalg::comprehensive_bufferize::affine_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
// AffineParallelOp bufferization not implemented yet. However, never hoist
// memref allocations across AffineParallelOp boundaries.
registry.addOpInterface<AffineParallelOp,
AllocationHoistingBarrierOnly<AffineParallelOp>>();
}
@@ -20,23 +20,30 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace arith_ext {

/// Bufferization of arith.constant. Replace with memref.get_global.
struct ConstantOpInterface
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
arith::ConstantOp> {
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto constantOp = cast<arith::ConstantOp>(op);
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
"not a constant ranked tensor");

// Only ranked tensors are supported.
if (!constantOp.getType().isa<RankedTensorType>())
return failure();

// Only constants inside a module are supported.
auto moduleOp = constantOp->getParentOfType<ModuleOp>();
if (!moduleOp)
return constantOp.emitError(
"cannot bufferize constants not within builtin.module op");
return failure();

// Create global memory segment and replace tensor with memref pointing to
// that memory segment.
GlobalCreator globalCreator(moduleOp);
auto globalMemref = globalCreator.getGlobalFor(constantOp);
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
rewriter, op, globalMemref.type(), globalMemref.getName());

return success();
}

@@ -74,6 +74,21 @@ mlir::linalg::comprehensive_bufferize::defaultAllocationCallbacks() {
BufferizationOptions::BufferizationOptions()
: allocationFns(defaultAllocationCallbacks()) {}

BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
if (isOpAllowed(op))
return dyn_cast<BufferizableOpInterface>(op);
return nullptr;
}

BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Value value) const {
if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
if (isOpAllowed(bufferizableOp.getOperation()))
return bufferizableOp;
return nullptr;
}
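For context, a short hypothetical caller of these moved helpers: ops that are filtered out, or that do not implement the interface, yield a null interface handle and can simply be skipped.

// Hypothetical example, not part of this commit.
void visitOp(Operation *op, const BufferizationOptions &options) {
  if (BufferizableOpInterface bufferizableOp =
          options.dynCastBufferizableOp(op)) {
    // `op` implements BufferizableOpInterface and passes the dialect filter.
    (void)bufferizableOp;
  }
}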

//===----------------------------------------------------------------------===//
// BufferizationAliasInfo
//===----------------------------------------------------------------------===//
@@ -180,21 +195,6 @@ static void setInsertionPointAfter(OpBuilder &b, Value value) {
}
}

BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Operation *op) const {
if (isOpAllowed(op))
return dyn_cast<BufferizableOpInterface>(op);
return nullptr;
}

BufferizableOpInterface mlir::linalg::comprehensive_bufferize::
BufferizationOptions::dynCastBufferizableOp(Value value) const {
if (auto bufferizableOp = value.getDefiningOp<BufferizableOpInterface>())
if (isOpAllowed(bufferizableOp.getOperation()))
return bufferizableOp;
return nullptr;
}

/// Determine which OpOperand* will alias with `result` if the op is bufferized
/// in place. Return an empty vector if the op is not bufferizable.
SmallVector<OpOperand *>
@@ -358,8 +358,9 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
/// Return the result buffer (memref) for a given OpResult (tensor). Allocate
/// a new buffer and copy over data from the existing buffer if out-of-place
/// bufferization is necessary.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
getResultBuffer(RewriterBase &rewriter, OpResult result) const {
FailureOr<Value>
mlir::linalg::comprehensive_bufferize::BufferizationState::getResultBuffer(
RewriterBase &rewriter, OpResult result) const {
OpBuilder::InsertionGuard guard(rewriter);
Operation *op = result.getOwner();
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
@@ -375,10 +376,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
if (aliasingOperands.size() > 1 &&
!llvm::all_of(aliasingOperands, [&](OpOperand *o) {
return lookupBuffer(rewriter, o->get()) == operandBuffer;
})) {
op->emitError("result buffer is ambiguous");
return Value();
}
}))
return FailureOr<Value>(op->emitError("result buffer is ambiguous"));

// If bufferizing out-of-place, allocate a new buffer.
if (!aliasInfo.isInPlace(result)) {
@@ -610,10 +609,13 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
// Insert to_memref op.
OpBuilder::InsertionGuard g(rewriter);
setInsertionPointAfter(rewriter, tensor);
Type memrefType =
tensor.getType().isa<RankedTensorType>()
? getDynamicMemRefType(tensor.getType().cast<RankedTensorType>())
: getContiguousOrUnrankedMemRefType(tensor.getType());
Type memrefType;
if (auto rankedTensorType = tensor.getType().dyn_cast<RankedTensorType>()) {
memrefType = getDynamicMemRefType(rankedTensorType);
} else {
memrefType = getUnrankedMemRefType(
tensor.getType().cast<TensorType>().getElementType());
}
return rewriter.create<bufferization::ToMemrefOp>(tensor.getLoc(), memrefType,
tensor);
}
@@ -630,13 +632,9 @@ MemRefType mlir::linalg::comprehensive_bufferize::getContiguousMemRefType(
layout, memorySpace);
}

Type mlir::linalg::comprehensive_bufferize::getContiguousOrUnrankedMemRefType(
Type type, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
if (type.isa<RankedTensorType, MemRefType>())
return getContiguousMemRefType(type.cast<ShapedType>(), layout,
memorySpace);
assert(!layout && "expected empty layout with UnrankedMemRefType");
return UnrankedMemRefType::get(getElementTypeOrSelf(type), memorySpace);
UnrankedMemRefType mlir::linalg::comprehensive_bufferize::getUnrankedMemRefType(
Type elementType, Attribute memorySpace) {
return UnrankedMemRefType::get(elementType, memorySpace);
}

MemRefType mlir::linalg::comprehensive_bufferize::getDynamicMemRefType(
@@ -25,6 +25,9 @@ namespace bufferization_ext {
// TODO: These ops should implement BufferizableOpInterface directly when moved
// to the Bufferization dialect.

/// Bufferization of bufferization.to_memref. to_memref(to_tensor(x)) is folded
/// to x. Other to_memref ops are ignored during bufferization.
///
/// ToMemrefOp casts a tensor into a memref. The resulting memref is the memory
/// location of the incoming tensor once it is bufferized. In the analysis,
/// the incoming tensor is assumed to bufferize to a memory read and to an
@@ -41,7 +44,7 @@ struct ToMemrefOpInterface
bufferization::ToMemrefOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
// It is unknown whether the resulting MemRef will be read or not.
// It is unknown whether the resulting memref will be read or not.
return true;
}

@@ -58,27 +61,33 @@ struct ToMemrefOpInterface
if (auto toTensorOp =
toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
Value buffer = toTensorOp.memref();

// Insert cast in case to_memref(to_tensor(x))'s type is different from
// x's type.
if (toTensorOp.memref().getType() != toMemrefOp.getType())
buffer = rewriter.create<memref::CastOp>(toMemrefOp.getLoc(), buffer,
toMemrefOp.getType());
rewriter.replaceOp(toMemrefOp, buffer);
replaceOpWithBufferizedValues(rewriter, toMemrefOp, buffer);
return success();
}

return failure();
}
};

/// ToTensorOp conceptually loads a tensor from a memory location. Such ops do
/// not lower any further, and they should have disappeared by the time the
/// input is fully bufferized.
/// Bufferization of bufferization.to_tensor. Such ops cannot be bufferized.
/// However, other ops that are using to_tensor's result will eventually be
/// bufferized. At that point, they will start using to_tensor's memref operand.
/// Once all users of to_tensor are bufferized, the op will not have any users
/// anymore and DCE away.
///
/// The analysis has no information about the memref that is loaded from by the
/// ToTensorOp. We have to assume that the loaded tensor may after bufferization
/// potentially alias with any other bufferized tensor. Since ToTensorOp and
/// ToMemrefOp have no aliasing OpOperand/OpResult pairs, this cannot be encoded
/// directly in the analysis. However, declaring ToTensorOp results as not
/// writable also enforces a buffer copy and has the same effect.
/// ToTensorOp conceptually loads a tensor from a memory location. The analysis
/// has no information about the memref that is loaded from by ToTensorOp. We
/// have to assume that the loaded tensor may after bufferization potentially
/// alias with any other bufferized tensor. Since ToTensorOp and ToMemrefOp have
/// no aliasing OpOperand/OpResult pairs, this cannot be encoded directly in the
/// analysis. However, declaring ToTensorOp results as not writable enforces a
/// buffer copy and has the same effect.
struct ToTensorOpInterface
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
bufferization::ToTensorOp> {
Expand All @@ -89,7 +98,7 @@ struct ToTensorOpInterface

bool isWritable(Operation *op, Value value,
const BufferizationState &state) const {
// It is unknown whether the MemRef operand is writable or not.
// It is unknown whether the memref operand is writable or not.
return false;
}
};