[mlir][bufferization] MaterializeInDestinationOp: Support memref destinations #68074

Merged
87 changes: 65 additions & 22 deletions mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -216,33 +216,58 @@ def Bufferization_CloneOp : Bufferization_Op<"clone", [

def Bufferization_MaterializeInDestinationOp
: Bufferization_Op<"materialize_in_destination",
[BufferizableOpInterface, SameOperandsAndResultType,
DestinationStyleOpInterface,
[AllShapesMatch<["source", "dest"]>,
AllElementTypesMatch<["source", "dest"]>,
BufferizableOpInterface, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
DeclareOpInterfaceMethods<SubsetInsertionOpInterface,
["getSourceOperand", "getValuesNeededToBuildSubsetExtraction",
"buildSubsetExtraction", "isEquivalentSubset"]>]> {
"buildSubsetExtraction", "isEquivalentSubset"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface, ["getEffects"]>]> {
let summary = "copy a tensor";

let description = [{
This op indicates that the data of the `source` tensor should materialize
in the future buffer of the `dest` tensors. Both tensors must have the same
shape and element type at runtime.
in `dest`, which can be a tensor or a memref. In case of a tensor, `source`
should materialize in the future buffer of `dest` and the updated
destination tensor is returned. In case of a memref, `source` should
materialize in `dest`, which is already a buffer. The op has no results in
that case.

`source`, `dest` and `result` (if present) must have the same shape and
element type. If the op has a result, the types of `result` and `dest` must
match exactly (e.g., including any tensor encodings).

By default, this op bufferizes to a memcpy from the future buffer of the
`source` tensor to the future buffer of the `dest` tensor. However,
transformations such as "empty tensor elimination" may rewrite IR such that
a computation is performed directly in the future buffer of the `dest`
tensor and no memcpy is needed.

Note: "tensor.insert_slice" could be used for the same purpose, but since
tensor dialect ops only indicate *what* should be computed but not *where*,
it could fold away, causing the computation to materialize in a different
buffer.
`source` tensor to the future buffer of the `dest` tensor or to the `dest`
buffer. However, transformations such as "empty tensor elimination" may
rewrite IR such that a computation is performed directly in `dest` and no
memcpy is needed.

If `dest` is a buffer, the `restrict` and `writable` attributes must be
specified. These attributes have the same meaning as the respective
attributes of `bufferization.to_tensor`. `writable` indicates that the
`dest` buffer is considered writable. It does not make sense to materialize
a computation in a read-only buffer, so `writable` is required. `restrict`
indicates that this op is the only way for the tensor IR to access `dest`
(or an alias thereof). E.g., there must be no other `to_tensor` ops with
`dest` or with an alias of `dest`. Such IR is not supported by
One-Shot Bufferize.

Contributor:
There are deeper safety concerns here, right? We should make it more
explicit, here and below, that using this op improperly will result in
undefined behavior.

I think the current wording here and above is potentially error-prone in
that the user may expect the analysis to be conservative or produce
warnings and/or errors.

In reality, this is a very strict directive that is easy to misuse.

Member Author:
Added this sentence here and at ToTensorOp:
"Ops that have incorrect usage of restrict may bufferize incorrectly."

Note: `restrict` and `writable` could be removed from this op because they
must always be set for memref destinations. This op has these attributes to
make clear the requirements on the `dest` operand in the op assembly format.
Moreover, these requirements may be relaxed at some point in the future.

Note: If `dest` is a tensor, `tensor.insert_slice` could be used for the
same purpose, but since tensor dialect ops only indicate *what* should be
computed but not *where*, it could fold away, causing the computation to
materialize in a different buffer.

Contributor:
Is this true? If one uses tensor.insert_slice, the analysis will insert
copies if required, whereas materialize_in_destination is prescriptive,
IIUC.

Member Author:
You are right, that is the desired behavior. It was actually not like that
for tensor destinations.
Contributor (@nicolasvasilache, Oct 6, 2023):
Another note, maybe: is there any correctness guarantee when e.g. using
this op as the last op in a function?

Member Author:
As long as `restrict` is not used incorrectly, the IR is guaranteed to
bufferize correctly. If the computation cannot materialize in the specified
tensor due to a RaW conflict or a read-only tensor, the IR fails to
bufferize. (Added test cases.)

}];

let arguments = (ins AnyTensor:$source, AnyTensor:$dest);
let results = (outs AnyTensor:$result);
let arguments = (ins AnyTensor:$source, AnyShaped:$dest,
UnitAttr:$restrict, UnitAttr:$writable);
let results = (outs Optional<AnyTensor>:$result);

let extraClassDeclaration = [{
LogicalResult bufferize(RewriterBase &rewriter,
@@ -264,10 +264,23 @@ def Bufferization_MaterializeInDestinationOp
return ::llvm::cast<RankedTensorType>(getResult().getType());
}

MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }
MutableOperandRange getDpsInitsMutable();

bool isWritable(Value value, const AnalysisState &state);
}];

let assemblyFormat = "$source `in` $dest attr-dict `:` type($source)";
let builders = [
// Builder that materializes a source tensor in a tensor destination.
// Asserts that `dest` has tensor type. Infers the result type of this op
// from the destination tensor.
OpBuilder<(ins "Value":$source, "Value":$dest)>
];

let assemblyFormat = [{
$source `in` (`restrict` $restrict^)? (`writable` $writable^)? $dest
attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
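To make the two destination forms concrete, here is a minimal sketch of the
new assembly format (illustrative values and types, mirroring the tests
further down in this PR):

```mlir
// Tensor destination: returns the updated destination tensor.
%r = bufferization.materialize_in_destination %src in %dest
    : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>

// Memref destination: no result; 'restrict' and 'writable' are required.
bufferization.materialize_in_destination %src in restrict writable %buf
    : (tensor<5xf32>, memref<5xf32>) -> ()
```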
@@ -361,10 +361,15 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
thereof) will bufferize out-of-place to prevent emitting any writes to
`memref` during bufferization.

If the given memref does not alias with any other memref passed to another
`to_tensor` op, the `restrict` unit attribute can be set. Only such
operations are supported by One-Shot Bufferize. (Otherwise, potential memref
aliasing relationships would have to be captured in One-Shot Bufferize.)
The `restrict` unit attribute (similar to the C `restrict` keyword)
indicates that the produced tensor result is the only way for the tensor
IR to gain access to the `memref` operand (or an alias thereof). E.g.,
there must be no other `to_tensor` op with the same or with an aliasing
`memref` operand.

Note: Only `to_tensor` ops with the `restrict` unit attribute are supported
by One-Shot Bufferize. Other IR is rejected. (To support `to_tensor`
without `restrict`, One-Shot Bufferize would have to analyze memref IR.)

Contributor:
Same here, can we be more explicit about the user's responsibility / UB vs.
what the analysis will catch?

Member Author (@matthias-springer, Oct 7, 2023):
Added this sentence:
"Ops that have incorrect usage of restrict may bufferize incorrectly."

Example:

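(The concrete example is collapsed in this diff view. As a minimal hedged
sketch of `to_tensor` with `restrict` and `writable` — illustrative types,
not the exact example from the source file:)

```mlir
// %m is not accessed by any other to_tensor op (or via an alias), so
// 'restrict' is legal; 'writable' permits writes to the resulting tensor.
%t = bufferization.to_tensor %m restrict writable : memref<8xf32>
```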
96 changes: 85 additions & 11 deletions mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -542,25 +542,40 @@ bool MaterializeInDestinationOp::bufferizesToMemoryRead(

bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
OpOperand &opOperand, const AnalysisState &state) {
return &opOperand == &getDestMutable();
if (&opOperand == &getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return true;
}
return false;
}

AliasingValueList
MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
const AnalysisState &state) {
if (&opOperand == &getDestMutable())
if (&opOperand == &getDestMutable()) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
}
return {};
}

LogicalResult
MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
const BufferizationOptions &options) {
FailureOr<Value> buffer = getBuffer(rewriter, getDest(), options);
if (failed(buffer))
return failure();
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), *buffer);
replaceOpWithBufferizedValues(rewriter, getOperation(), *buffer);
bool tensorDest = isa<TensorType>(getDest().getType());
Value buffer;
if (tensorDest) {
FailureOr<Value> maybeBuffer = getBuffer(rewriter, getDest(), options);
if (failed(maybeBuffer))
return failure();
buffer = *maybeBuffer;
} else {
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
buffer = getDest();
}
rewriter.create<memref::TensorStoreOp>(getLoc(), getSource(), buffer);
replaceOpWithBufferizedValues(rewriter, getOperation(),
tensorDest ? ValueRange(buffer) : ValueRange());
return success();
}

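For the memref-destination case, this `bufferize` implementation emits a
`memref.tensor_store` into the given buffer, which lowers to a plain copy.
A hedged before/after sketch, mirroring the to-memref test later in this PR
(values illustrative):

```mlir
// Before bufferization:
bufferization.materialize_in_destination %t in restrict writable %m
    : (tensor<5xf32>, memref<5xf32>) -> ()

// After bufferization: the source tensor becomes a buffer and is copied
// into the destination buffer.
%b = bufferization.to_memref %t : memref<5xf32, strided<[?], offset: ?>>
memref.copy %b, %m : memref<5xf32, strided<[?], offset: ?>> to memref<5xf32>
```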
@@ -573,15 +588,29 @@ bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(

LogicalResult MaterializeInDestinationOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
if (getOperation()->getNumResults() == 1) {
assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
reifiedReturnShapes.resize(1,
SmallVector<OpFoldResult>(getType().getRank()));
reifiedReturnShapes[0] =
tensor::getMixedSizes(builder, getLoc(), getDest());
}
return success();
}

Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
Location loc) {
// The subset is the entire destination tensor.
return getDest();
if (isa<TensorType>(getDest().getType())) {
// The subset is the entire destination tensor.
return getDest();
}

// Build a bufferization.to_tensor op.
assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
assert(getRestrict() &&
"expected that ops with memrefs dest have 'restrict'");
return builder.create<ToTensorOp>(loc, getDest(), getRestrict(),
getWritable());
}

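The `to_tensor` op built here is what lets empty tensor elimination work
with a memref destination. A hedged sketch of a plausible intermediate
state (illustrative names; the final bufferized IR, as in the test below,
fills `%m` directly):

```mlir
// Before: %filled is materialized into the buffer %m.
%0 = tensor.empty() : tensor<5xf32>
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
bufferization.materialize_in_destination %filled in restrict writable %m
    : (tensor<5xf32>, memref<5xf32>) -> ()

// After empty tensor elimination (plausible intermediate): the fill
// writes into a tensor view of %m built by buildSubsetExtraction().
%view = bufferization.to_tensor %m restrict writable : memref<5xf32>
%filled2 = linalg.fill ins(%f : f32) outs(%view : tensor<5xf32>) -> tensor<5xf32>
```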
bool MaterializeInDestinationOp::isEquivalentSubset(
@@ -598,6 +627,51 @@ OpOperand &MaterializeInDestinationOp::getSourceOperand() {
return getOperation()->getOpOperand(0) /*source*/;
}

LogicalResult MaterializeInDestinationOp::verify() {
if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
return emitOpError("'dest' must be a tensor or a memref");
if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
if (getOperation()->getNumResults() != 1)
return emitOpError("tensor 'dest' implies exactly one tensor result");
if (destType != getResult().getType())
return emitOpError("result and 'dest' types must match");
}
if (isa<BaseMemRefType>(getDest().getType()) &&
getOperation()->getNumResults() != 0)
return emitOpError("memref 'dest' implies zero results");
if (getRestrict() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'restrict' must be specified if and only if the "
"destination is of memref type");
if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
return emitOpError("'writable' must be specified if and only if the "
"destination is of memref type");
return success();
}

void MaterializeInDestinationOp::build(OpBuilder &builder,
OperationState &state, Value source,
Value dest) {
assert(isa<TensorType>(dest.getType()) && "expected tensor type");
build(builder, state, /*result=*/dest.getType(), source, dest);
}

bool MaterializeInDestinationOp::isWritable(Value value,
const AnalysisState &state) {
return isa<TensorType>(getDest().getType()) ? true : getWritable();
}

MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
return getDestMutable();
}

void MaterializeInDestinationOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (isa<BaseMemRefType>(getDest().getType()))
effects.emplace_back(MemoryEffects::Write::get(), getDest(),
SideEffects::DefaultResource::get());
}

//===----------------------------------------------------------------------===//
// ToTensorOp
//===----------------------------------------------------------------------===//
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -248,8 +248,10 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
LinalgPaddingOptions::CopyBackOp::
BufferizationMaterializeInDestination) {
replacements.push_back(
rewriter.create<bufferization::MaterializeInDestinationOp>(
loc, std::get<0>(it), std::get<1>(it).get()));
rewriter
.create<bufferization::MaterializeInDestinationOp>(
loc, std::get<0>(it), std::get<1>(it).get())
->getResult(0));
} else {
llvm_unreachable("unsupported copy back op");
}
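For context, a hedged sketch of the copy-back IR this branch now builds —
note that the op's result (rather than the op itself) is pushed as the
replacement value (all names and types here are illustrative):

```mlir
// Hypothetical copy-back after running the padded op: extract the
// original extent, then materialize it in the original destination.
%unpadded = tensor.extract_slice %padded_res[0, 0] [%d0, %d1] [1, 1]
    : tensor<8x16xf32> to tensor<?x?xf32>
%copied = bufferization.materialize_in_destination %unpadded in %dest
    : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// %copied (the op's result) is what replaces the padded op's result.
```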
@@ -172,7 +172,7 @@ func.func @materialize_in_destination_aliasing(%t: tensor<?xf32>, %p1: index, %p
%dest = tensor.extract_slice %t[%p2][5][1] : tensor<?xf32> to tensor<5xf32>
// CHECK: bufferization.materialize_in_destination
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
%r = bufferization.materialize_in_destination %src in %dest : tensor<5xf32>
%r = bufferization.materialize_in_destination %src in %dest : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
Contributor:
Not sure this actually reads better. Using only the type of the tensor or
the memref is fully unambiguous, no?

Member Author:
This is because the result is optional. In case of a memref, there is no
result (like linalg.generic). TableGen helpers such as AllTypesMatch etc.
do not work with that. I could write the printer/parser in C++ though; then
we could stay with the original syntax.

return %r : tensor<5xf32>
}

@@ -183,6 +183,6 @@ func.func @materialize_in_destination(%t: tensor<?xf32>, %sz: index) -> tensor<?
%buffer = tensor.empty(%sz) : tensor<?xf32>
// CHECK: bufferization.materialize_in_destination
// CHECK-SAME: {__inplace_operands_attr__ = ["true", "true"]}
%r = bufferization.materialize_in_destination %buffer in %buffer : tensor<?xf32>
%r = bufferization.materialize_in_destination %buffer in %buffer : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %r : tensor<?xf32>
}
@@ -301,12 +301,25 @@ func.func @regression_multiple_insertion_points(%t1: tensor<?x?xf32>) -> tensor<
func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
%0 = tensor.empty() : tensor<5xf32>
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
%1 = bufferization.materialize_in_destination %filled in %t : tensor<5xf32>
%1 = bufferization.materialize_in_destination %filled in %t : (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
return %1 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @materialize_in_destination_buffer(
// CHECK-SAME: %[[m:.*]]: memref<5xf32>,
// CHECK-NEXT: linalg.fill {{.*}} outs(%[[m]]
// CHECK-NEXT: return
func.func @materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32) {
%0 = tensor.empty() : tensor<5xf32>
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
bufferization.materialize_in_destination %filled in restrict writable %m : (tensor<5xf32>, memref<5xf32>) -> ()
return
}

// -----

// CHECK-LABEL: func @linalg_copy(
// CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
// CHECK: linalg.fill {{.*}} outs(%[[m]]
@@ -218,6 +218,20 @@ func.func @tensor_copy(%arg0: tensor<5xf32>) -> tensor<5xf32> {
// CHECK: %[[r:.*]] = bufferization.to_tensor %[[alloc]]
// CHECK: return %[[r]]
%dest = bufferization.alloc_tensor() : tensor<5xf32>
%0 = bufferization.materialize_in_destination %arg0 in %dest : tensor<5xf32>
%0 = bufferization.materialize_in_destination %arg0 in %dest
: (tensor<5xf32>, tensor<5xf32>) -> tensor<5xf32>
return %0 : tensor<5xf32>
}

// -----

// CHECK-LABEL: func @materialize_in_destination_buffer(
// CHECK-SAME: %[[t:.*]]: tensor<5xf32>, %[[m:.*]]: memref<5xf32>)
// CHECK: %[[b:.*]] = bufferization.to_memref %[[t]] : memref<5xf32, strided<[?], offset: ?>>
// CHECK: memref.copy %[[b]], %[[m]]
func.func @materialize_in_destination_buffer(%t: tensor<5xf32>, %m: memref<5xf32>) {
bufferization.materialize_in_destination %t in restrict writable %m
: (tensor<5xf32>, memref<5xf32>) -> ()
return
}

54 changes: 51 additions & 3 deletions mlir/test/Dialect/Bufferization/invalid.mlir
@@ -66,10 +66,58 @@ func.func @invalid_writable_on_op() {

// -----

// expected-note @below{{prior use here}}
func.func @invalid_materialize_in_destination(%arg0: tensor<?xf32>, %arg1: tensor<5xf32>) {
// expected-error @below{{expects different type than prior uses: 'tensor<?xf32>' vs 'tensor<5xf32>'}}
bufferization.materialize_in_destination %arg0 in %arg1 : tensor<?xf32>
// expected-error @below{{failed to verify that all of {source, dest} have same shape}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<5xf32>) -> tensor<5xf32>
}

// -----

func.func @invalid_materialize_in_destination_dest_type(%arg0: tensor<5xf32>, %arg1: vector<5xf32>) {
// expected-error @below{{'dest' must be a tensor or a memref}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<5xf32>, vector<5xf32>) -> ()
}

// -----

func.func @invalid_materialize_in_destination_restrict_missing(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, memref<?xf32>) -> ()
}

// -----

func.func @invalid_materialize_in_destination_result(%arg0: tensor<?xf32>, %arg1: memref<?xf32>) {
// expected-error @below{{memref 'dest' implies zero results}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, memref<?xf32>) -> (tensor<?xf32>)
}

// -----

func.func @invalid_materialize_in_destination_result_missing(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below{{tensor 'dest' implies exactly one tensor result}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> ()
}

// -----

func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below{{'restrict' must be specified if and only if the destination is of memref type}}
bufferization.materialize_in_destination %arg0 in restrict %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}

// -----

func.func @invalid_materialize_in_destination_restrict(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below{{'writable' must be specified if and only if the destination is of memref type}}
bufferization.materialize_in_destination %arg0 in writable %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<?xf32>)
}

// -----

func.func @invalid_materialize_in_destination_result_shape(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) {
// expected-error @below{{result and 'dest' types must match}}
bufferization.materialize_in_destination %arg0 in %arg1 : (tensor<?xf32>, tensor<?xf32>) -> (tensor<6xf32>)
}

// -----