Skip to content

Commit

Permalink
[mlir][linalg][bufferize] Support std.select bufferization
Browse files Browse the repository at this point in the history
This op is an example for how to deal with ops who's OpResult may aliasing with one of multiple OpOperands.

Differential Revision: https://reviews.llvm.org/D116868
  • Loading branch information
matthias-springer committed Jan 12, 2022
1 parent 5642ce5 commit 6c654b5
Show file tree
Hide file tree
Showing 17 changed files with 333 additions and 60 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//===- LinalgInterfaceImpl.h - Linalg Impl. of BufferizableOpInterface ----===//
//===- AffineInterfaceImpl.h - Affine Impl. of BufferizableOpInterface ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,18 +346,18 @@ class BufferizationState {
/// In the above example, Values with a star satisfy the condition. When
/// starting the traversal from Value 1, the resulting SetVector is:
/// { 2, 7, 8, 5 }
llvm::SetVector<Value> findValueInReverseUseDefChain(
SetVector<Value> findValueInReverseUseDefChain(
Value value, llvm::function_ref<bool(Value)> condition) const;

/// Find the Value of the last preceding write of a given Value.
/// Find the Values of the last preceding write of a given Value.
///
/// Note: Unknown ops are handled conservatively and assumed to be writes.
/// Furthermore, BlockArguments are also assumed to be writes. There is no
/// analysis across block boundaries.
///
/// Note: When reaching an end of the reverse SSA use-def chain, that value
/// is returned regardless of whether it is a memory write or not.
Value findLastPrecedingWrite(Value value) const;
SetVector<Value> findLastPrecedingWrite(Value value) const;

/// Creates a memref allocation.
FailureOr<Value> createAlloc(OpBuilder &b, Location loc, MemRefType type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ runComprehensiveBufferize(ModuleOp moduleOp,

namespace std_ext {

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
void registerModuleBufferizationExternalModels(DialectRegistry &registry);

} // namespace std_ext
} // namespace comprehensive_bufferize
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- StdInterfaceImpl.h - Standard Impl. of BufferizableOpInterface- ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
#define MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H

namespace mlir {

class DialectRegistry;

namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {

void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);

} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

#endif // MLIR_DIALECT_LINALG_COMPREHENSIVEBUFFERIZE_STD_INTERFACE_IMPL_H
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,18 @@ llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
return result;
}

// Find the Value of the last preceding write of a given Value.
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
findLastPrecedingWrite(Value value) const {
SetVector<Value> result =
findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return true;
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
return true;
return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
});

// To simplify the analysis, `scf.if` ops are considered memory writes. There
// are currently no other ops where one OpResult may alias with multiple
// OpOperands. Therefore, this function should return exactly one result at
// the moment.
assert(result.size() == 1 && "expected exactly one result");
return result.front();
// Find the Values of the last preceding write of a given Value.
llvm::SetVector<Value> mlir::linalg::comprehensive_bufferize::
BufferizationState::findLastPrecedingWrite(Value value) const {
return findValueInReverseUseDefChain(value, [&](Value value) {
Operation *op = value.getDefiningOp();
if (!op)
return true;
auto bufferizableOp = options.dynCastBufferizableOp(op);
if (!bufferizableOp)
return true;
return bufferizableOp.isMemoryWrite(value.cast<OpResult>(), *this);
});
}

mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
Expand Down Expand Up @@ -404,15 +396,19 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::getBuffer(
createAlloc(rewriter, loc, operandBuffer, options.createDeallocs);
if (failed(resultBuffer))
return failure();
// Do not copy if the last preceding write of `operand` is an op that does
// Do not copy if the last preceding writes of `operand` are ops that do
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
// Note: If `findLastPrecedingWrite` reaches the end of the reverse SSA
// use-def chain, it returns that value, regardless of whether it is a
// memory write or not.
Value lastWrite = findLastPrecedingWrite(operand);
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
if (!bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(), *this))
return resultBuffer;
SetVector<Value> lastWrites = findLastPrecedingWrite(operand);
if (llvm::none_of(lastWrites, [&](Value lastWrite) {
if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite))
return bufferizableOp.isMemoryWrite(lastWrite.cast<OpResult>(),
*this);
return true;
}))
return resultBuffer;
// Do not copy if the copied data is never read.
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/Linalg/ComprehensiveBufferize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
LinalgInterfaceImpl.cpp
ModuleBufferization.cpp
SCFInterfaceImpl.cpp
StdInterfaceImpl.cpp
TensorInterfaceImpl.cpp
VectorInterfaceImpl.cpp
)
Expand Down Expand Up @@ -61,6 +62,14 @@ add_mlir_dialect_library(MLIRSCFBufferizableOpInterfaceImpl
MLIRSCF
)

add_mlir_dialect_library(MLIRStdBufferizableOpInterfaceImpl
StdInterfaceImpl.cpp

LINK_LIBS PUBLIC
MLIRBufferizableOpInterface
MLIRStandard
)

add_mlir_dialect_library(MLIRTensorBufferizableOpInterfaceImpl
TensorInterfaceImpl.cpp

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,8 @@ static bool hasReadAfterWriteInterference(
for (OpOperand *uRead : usesRead) {
Operation *readingOp = uRead->getOwner();

// Find most recent write of uRead by following the SSA use-def chain. E.g.:
// Find most recent writes of uRead by following the SSA use-def chain.
// E.g.:
//
// %0 = "writing_op"(%t) : tensor<?x32> -> tensor<?xf32>
// %1 = "aliasing_op"(%0) : tensor<?x32> -> tensor<?xf32>
Expand All @@ -228,7 +229,7 @@ static bool hasReadAfterWriteInterference(
// In the above example, if uRead is the OpOperand of reading_op, lastWrite
// is %0. Note that operations that create an alias but do not write (such
// as ExtractSliceOp) are skipped.
Value lastWrite = state.findLastPrecedingWrite(uRead->get());
SetVector<Value> lastWrites = state.findLastPrecedingWrite(uRead->get());

// Look for conflicting memory writes. Potential conflicts are writes to an
// alias that have been decided to bufferize inplace.
Expand Down Expand Up @@ -265,35 +266,38 @@ static bool hasReadAfterWriteInterference(
if (insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp))
continue;

// No conflict if the conflicting write happens before the last
// write.
if (Operation *writingOp = lastWrite.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
// conflictingWritingOp happens before writingOp. No conflict.
continue;
// No conflict if conflictingWritingOp is contained in writingOp.
if (writingOp->isProperAncestor(conflictingWritingOp))
continue;
} else {
auto bbArg = lastWrite.cast<BlockArgument>();
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp))
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
}
// Check all possible last writes.
for (Value lastWrite : lastWrites) {
// No conflict if the conflicting write happens before the last
// write.
if (Operation *writingOp = lastWrite.getDefiningOp()) {
if (happensBefore(conflictingWritingOp, writingOp, domInfo))
// conflictingWritingOp happens before writingOp. No conflict.
continue;
// No conflict if conflictingWritingOp is contained in writingOp.
if (writingOp->isProperAncestor(conflictingWritingOp))
continue;
} else {
auto bbArg = lastWrite.cast<BlockArgument>();
Block *block = bbArg.getOwner();
if (!block->findAncestorOpInBlock(*conflictingWritingOp))
// conflictingWritingOp happens outside of the block. No
// conflict.
continue;
}

// No conflict if the conflicting write and the last write are the same
// use.
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;
// No conflict if the conflicting write and the last write are the same
// use.
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
continue;

// All requirements are met. Conflict found!
// All requirements are met. Conflict found!

if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, lastWrite);
if (options.printConflicts)
annotateConflict(uRead, uConflictingWrite, lastWrite);

return true;
return true;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -938,7 +938,7 @@ struct FuncOpInterface
} // namespace mlir

void mlir::linalg::comprehensive_bufferize::std_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registerModuleBufferizationExternalModels(DialectRegistry &registry) {
registry.addOpInterface<CallOp, std_ext::CallOpInterface>();
registry.addOpInterface<ReturnOp, std_ext::ReturnOpInterface>();
registry.addOpInterface<FuncOp, std_ext::FuncOpInterface>();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//===- StdInterfaceImpl.cpp - Standard Impl. of BufferizableOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"

#include "mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"

namespace mlir {
namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {

/// Bufferization of std.select. Just replace the operands.
struct SelectOpInterface
: public BufferizableOpInterface::ExternalModel<SelectOpInterface,
SelectOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return false;
}

bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return false;
}

OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
const BufferizationState &state) const {
return op->getOpResult(0) /*result*/;
}

SmallVector<OpOperand *>
getAliasingOpOperand(Operation *op, OpResult opResult,
const BufferizationState &state) const {
return {&op->getOpOperand(1) /*true_value*/,
&op->getOpOperand(2) /*false_value*/};
}

LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationState &state) const {
auto selectOp = cast<SelectOp>(op);
// `getBuffer` introduces copies if an OpOperand bufferizes out-of-place.
// TODO: It would be more efficient to copy the result of the `select` op
// instead of its OpOperands. In the worst case, 2 copies are inserted at
// the moment (one for each tensor). When copying the op result, only one
// copy would be needed.
Value trueBuffer =
*state.getBuffer(rewriter, selectOp->getOpOperand(1) /*true_value*/);
Value falseBuffer =
*state.getBuffer(rewriter, selectOp->getOpOperand(2) /*false_value*/);
replaceOpWithNewBufferizedOp<SelectOp>(
rewriter, op, selectOp.getCondition(), trueBuffer, falseBuffer);
return success();
}

BufferRelation bufferRelation(Operation *op, OpResult opResult,
const BufferizationAliasInfo &aliasInfo,
const BufferizationState &state) const {
return BufferRelation::None;
}
};

} // namespace std_ext
} // namespace comprehensive_bufferize
} // namespace linalg
} // namespace mlir

void mlir::linalg::comprehensive_bufferize::std_ext::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addOpInterface<SelectOp, std_ext::SelectOpInterface>();
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
MLIRSCF
MLIRSCFBufferizableOpInterfaceImpl
MLIRSCFTransforms
MLIRStdBufferizableOpInterfaceImpl
MLIRPass
MLIRStandard
MLIRStandardOpsTransforms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/StdInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/TensorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/ComprehensiveBufferize/VectorInterfaceImpl.h"
#include "mlir/Dialect/Linalg/Passes.h"
Expand Down Expand Up @@ -51,6 +52,7 @@ struct LinalgComprehensiveModuleBufferize
bufferization_ext::registerBufferizableOpInterfaceExternalModels(registry);
linalg_ext::registerBufferizableOpInterfaceExternalModels(registry);
scf_ext::registerBufferizableOpInterfaceExternalModels(registry);
std_ext::registerModuleBufferizationExternalModels(registry);
std_ext::registerBufferizableOpInterfaceExternalModels(registry);
tensor_ext::registerBufferizableOpInterfaceExternalModels(registry);
vector_ext::registerBufferizableOpInterfaceExternalModels(registry);
Expand Down
Loading

0 comments on commit 6c654b5

Please sign in to comment.