diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 122f735628521..abd996bdbaf85 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -12,6 +12,7 @@ #include #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/Utils/Utils.h" @@ -28,6 +29,7 @@ namespace mlir { namespace bufferization { +class AllocTensorOp; class OneShotAnalysisState; } // namespace bufferization @@ -110,6 +112,18 @@ Value bufferizeToAllocation(RewriterBase &rewriter, vector::MaskOp maskOp, Attribute memorySpace = {}, Operation *insertionPoint = nullptr); +/// Materialize a buffer allocation for the given bufferization.alloc_tensor op +/// and lower the op to memref.alloc + memref.tensor_store. +/// +/// In addition to rewriting the IR, this function returns the newly allocated +/// buffer. The `insertionPoint` parameter can be used to specify a custom +/// insertion point for the buffer allocation. +Value bufferizeToAllocation(RewriterBase &rewriter, + const BufferizeToAllocationOptions &options, + bufferization::AllocTensorOp allocTensorOp, + Attribute memorySpace = {}, + Operation *insertionPoint = nullptr); + /// Bufferize the given op with tensor semantics and materialize the result in /// a newly allocated buffer. /// diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index f7340844f7e19..311540fde512b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -317,6 +317,27 @@ Value linalg::bufferizeToAllocation( return alloc; } +Value linalg::bufferizeToAllocation( + RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options, + bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace, + Operation *insertionPoint) { + Location loc = allocTensorOp.getLoc(); + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp); + bufferization::BufferizationOptions bufferizationOptions; + + // Create buffer allocation. + Value alloc = createAllocationForTensor( + rewriter, loc, allocTensorOp.getResult(), options, memorySpace); + + // Create bufferization.to_tensor with "restrict" and "writable". The returned + // tensor is a new buffer allocation, so it does not alias with any buffer. + Value toTensorOp = rewriter.create( + loc, alloc, /*restrict=*/true, /*writable=*/true); + rewriter.replaceOp(allocTensorOp, toTensorOp); + return alloc; +} + /// Lower tensor.from_elements to a sequence of chained tensor.insert. FailureOr mlir::linalg::rewriteInDestinationPassingStyle( RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) { @@ -454,6 +475,8 @@ Value linalg::bufferizeToAllocation( return bufferizeToAllocation(rewriter, options, padOp, memorySpace); if (auto maskOp = dyn_cast(op)) return bufferizeToAllocation(rewriter, options, maskOp, memorySpace); + if (auto allocTensorOp = dyn_cast(op)) + return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace); // Only bufferizable ops are supported. auto bufferizableOp = dyn_cast(op); diff --git a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir index 8d52d9900a793..3c50a9e72d9d9 100644 --- a/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir @@ -215,3 +215,26 @@ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, % } return } + +// ----- + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1 + : (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor"> + %2, %new = transform.structured.bufferize_to_allocation %alloc_tensor + {alloc_op = "memref.alloca"} + : !transform.op<"bufferization.alloc_tensor"> + transform.yield + } +} + +// Expect `bufferization.bufferize_to_allocation` to create an alloc. +// CHECK-LABEL: func.func @empty_to_tensor_alloc() +func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> { + // CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32> + // CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32> + // CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32> + %0 = bufferization.alloc_tensor() : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +}