diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
index c1529a36465ac..6245f88db3d19 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h
@@ -712,7 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value,
 /// This is the default implementation of
 /// BufferizableOpInterface::getBufferType. Should not be called from other
 /// places.
-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 defaultGetBufferType(Value value, const BufferizationOptions &options,
                      const BufferizationState &state,
                      SmallVector<Value> &invocationStack);
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
index cafe05fe5f189..246ae77f327cf 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -525,7 +525,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
           Note: This interface method should never be called directly from
           user code. Always use `bufferization::getBufferType`.
         }],
-        /*retType=*/"::mlir::FailureOr<::mlir::BaseMemRefType>",
+        /*retType=*/"::mlir::FailureOr<::mlir::bufferization::BufferLikeType>",
        /*methodName=*/"getBufferType",
        /*args=*/(ins "::mlir::Value":$value,
                      "const ::mlir::bufferization::BufferizationOptions &":$options,
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
index 32c53ea9c494a..f175b15c8770f 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td
@@ -111,7 +111,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
     AliasingValueList getAliasingValues(
        OpOperand &opOperand, const AnalysisState &state);

-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
        Value value, const BufferizationOptions &options,
        const BufferizationState &state,
        SmallVector<Value> &invocationStack);
@@ -478,10 +478,10 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [

    bool isWritable(Value value, const AnalysisState &state);

-    FailureOr<BaseMemRefType> getBufferType(
+    FailureOr<BufferLikeType> getBufferType(
        Value value, const BufferizationOptions &options,
        const BufferizationState &state,
        SmallVector<Value> &invocationStack) {
-      return ::llvm::cast<BaseMemRefType>(getBuffer().getType());
+      return getBuffer().getType();
    }
  }];
diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
index cbb6054fcf886..da7fee4b4a220 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h
@@ -13,6 +13,7 @@
 // Bufferization Type Interfaces
 //===----------------------------------------------------------------------===//

+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/Types.h"

diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
index f56c10555f02c..e8a81c74bd77a 100644
--- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h
@@ -32,7 +32,7 @@ template <typename ConcreteModel, typename ConcreteOp>
 struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     : public BufferizableOpInterface::ExternalModel<ConcreteModel,
                                                     ConcreteOp> {
-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -110,7 +110,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel
     if (!bufferType)
       return op->emitOpError("could not infer buffer type of block argument");

-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }

 protected:
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
index 85d1b5ac73bf4..afee162053bea 100644
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -181,7 +181,7 @@ struct SelectOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -196,17 +196,17 @@ struct SelectOpInterface
     if (failed(trueType) || failed(falseType))
       return failure();
     if (*trueType == *falseType)
-      return *trueType;
+      return cast<BufferLikeType>(*trueType);
     if (trueType->getMemorySpace() != falseType->getMemorySpace())
       return op->emitError("inconsistent memory space on true/false operands");

     // If the buffers have different types, they differ only in their layout
     // map.
     auto memrefType = llvm::cast<MemRefType>(*trueType);
-    return getMemRefTypeWithFullyDynamicLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
        RankedTensorType::get(memrefType.getShape(),
                              memrefType.getElementType()),
-        memrefType.getMemorySpace());
+        memrefType.getMemorySpace()));
   }
 };
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
index 2ab182c9b7b2e..55784ac20d353 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -945,7 +945,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands(
   return AliasingOpOperandList(std::move(result));
 }

-FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
+FailureOr<BufferLikeType> bufferization::detail::defaultGetBufferType(
     Value value, const BufferizationOptions &options,
     const BufferizationState &bufferizationState,
     SmallVector<Value> &invocationStack) {
@@ -953,8 +953,10 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   auto tensorType = cast<TensorType>(value.getType());

   // No further analysis is possible for a block argument.
-  if (llvm::isa<BlockArgument>(value))
-    return bufferization::getMemRefType(tensorType, options);
+  if (llvm::isa<BlockArgument>(value)) {
+    return cast<BufferLikeType>(
+        bufferization::getMemRefType(tensorType, options));
+  }

   // Value is an OpResult.
   Operation *op = getOwnerOfValue(value);
@@ -966,8 +968,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
     // If the OpResult has an equivalent OpOperand, both OpResult and
     // OpOperand bufferize to the exact same buffer type.
     Value equivalentOperand = aliases.getAliases().front().opOperand->get();
-    return asMemRefType(getBufferType(equivalentOperand, options,
-                                      bufferizationState, invocationStack));
+    return getBufferType(equivalentOperand, options, bufferizationState,
+                         invocationStack);
   }

   // If we do not know the memory space and there is no default memory space,
@@ -977,7 +979,8 @@ FailureOr<BaseMemRefType> bufferization::detail::defaultGetBufferType(
   if (!memSpace.has_value())
     return op->emitError("could not infer memory space");

-  return getMemRefType(tensorType, options, /*layout=*/{}, *memSpace);
+  return cast<BufferLikeType>(
+      getMemRefType(tensorType, options, /*layout=*/{}, *memSpace));
 }

 bool bufferization::detail::defaultIsRepetitiveRegion(
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 9bd87d66c7d36..66949c96798de 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -222,7 +222,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
   return {};
 }

-FailureOr<BaseMemRefType>
+FailureOr<BufferLikeType>
 AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
                              const BufferizationState &state,
                              SmallVector<Value> &invocationStack) {
@@ -245,7 +245,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
      return getOperation()->emitError("could not infer memory space");
   }

-  return getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace);
+  return cast<BufferLikeType>(
+      getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
 }

 LogicalResult AllocTensorOp::verify() {
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
index 453ed43bcadd2..bd2aebca68079 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp
@@ -211,7 +211,7 @@ struct CallOpInterface
     return result;
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -229,12 +229,13 @@ struct CallOpInterface
     Type resultType =
        funcType.getResult(cast<OpResult>(value).getResultNumber());
     if (auto bufferizedType = dyn_cast<BaseMemRefType>(resultType))
-      return bufferizedType;
+      return cast<BufferLikeType>(bufferizedType);

     // Otherwise, call the type converter to compute the bufferized type.
     auto tensorType = cast<TensorType>(resultType);
-    return options.functionArgTypeConverterFn(
-        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
+    return cast<BufferLikeType>(options.functionArgTypeConverterFn(
+        tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp,
+        options));
   }

   /// All function arguments are writable. It is the responsibility of the
@@ -396,7 +397,7 @@ struct FuncOpInterface
     return getAliasingBranchOpOperands(op, cast<BlockArgument>(value), state);
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -405,8 +406,8 @@ struct FuncOpInterface

     // Function arguments are special.
     if (bbArg.getOwner() == &funcOp.getBody().front())
-      return getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(),
-                                          options);
+      return cast<BufferLikeType>(
+          getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options));

     return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel::
        getBufferType(op, value, options, state, invocationStack);
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 58562536be61f..d36d91249ed36 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -274,7 +274,7 @@ struct IfOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -313,15 +313,15 @@ struct IfOpInterface

     // Best case: Both branches have the exact same buffer type.
     if (thenBufferType == elseBufferType)
-      return thenBufferType;
+      return cast<BufferLikeType>(thenBufferType);

     // Memory space mismatch.
     if (thenBufferType.getMemorySpace() != elseBufferType.getMemorySpace())
       return op->emitError("inconsistent memory space on then/else branches");

     // Layout maps are different: Promote to fully dynamic layout map.
-    return getMemRefTypeWithFullyDynamicLayout(
-        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace());
+    return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+        cast<TensorType>(opResult.getType()), thenBufferType.getMemorySpace()));
   }
 };
@@ -392,7 +392,7 @@ struct IndexSwitchOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -436,7 +436,7 @@ struct IndexSwitchOpInterface
          cast<TensorType>(value.getType()), bufferType.getMemorySpace());
     }

-    return bufferType;
+    return cast<BufferLikeType>(bufferType);
   }
 };
@@ -522,13 +522,13 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs,
 /// If both buffer types are equal, no casts are needed the computed buffer type
 /// can be used directly. Otherwise, the buffer types can only differ in their
 /// layout map and a cast must be inserted.
-static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
+static FailureOr<BufferLikeType> computeLoopRegionIterArgBufferType(
     Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue,
     const BufferizationOptions &options, const BufferizationState &state,
     SmallVector<Value> &invocationStack) {
   // Determine the buffer type of the init_arg.
-  auto initArgBufferType = bufferization::detail::asMemRefType(
-      bufferization::getBufferType(initArg, options, state, invocationStack));
+  auto initArgBufferType =
+      bufferization::getBufferType(initArg, options, state, invocationStack);
   if (failed(initArgBufferType))
     return failure();
@@ -547,16 +547,15 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
   }

   // Compute the buffer type of the yielded value.
-  BaseMemRefType yieldedValueBufferType;
+  BufferLikeType yieldedValueBufferType;
   if (isa<BaseMemRefType>(yieldedValue.getType())) {
     // scf.yield was already bufferized.
-    yieldedValueBufferType = cast<BaseMemRefType>(yieldedValue.getType());
+    yieldedValueBufferType = cast<BufferLikeType>(yieldedValue.getType());
   } else {
     // Note: This typically triggers a recursive call for the buffer type of
     // the iter_arg.
-    auto maybeBufferType =
-        bufferization::detail::asMemRefType(bufferization::getBufferType(
-            yieldedValue, options, state, invocationStack));
+    auto maybeBufferType = bufferization::getBufferType(yieldedValue, options,
+                                                        state, invocationStack);
     if (failed(maybeBufferType))
       return failure();
     yieldedValueBufferType = *maybeBufferType;
@@ -584,8 +583,8 @@ static FailureOr<BaseMemRefType> computeLoopRegionIterArgBufferType(
           "expected same shape");
   }
 #endif // NDEBUG
-  return getMemRefTypeWithFullyDynamicLayout(
-      iterTensorType, yieldedBufferType.getMemorySpace());
+  return cast<BufferLikeType>(getMemRefTypeWithFullyDynamicLayout(
+      iterTensorType, yieldedBufferType.getMemorySpace()));
 }

 /// Return `true` if the given loop may have 0 iterations.
@@ -708,7 +707,7 @@ struct ForOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -719,12 +718,8 @@ struct ForOpInterface
     if (auto opResult = dyn_cast<OpResult>(value)) {
       // The type of an OpResult must match the corresponding iter_arg type.
       BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult);
-      auto bufferType =
-          bufferization::getBufferType(bbArg, options, state, invocationStack);
-      if (failed(bufferType))
-        return failure();
-      assert(isa<BaseMemRefType>(*bufferType) && "expected memref type");
-      return cast<BaseMemRefType>(*bufferType);
+      return bufferization::getBufferType(bbArg, options, state,
+                                          invocationStack);
     }

     // Compute result/argument number.
@@ -1047,7 +1042,7 @@ struct WhileOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1081,10 +1076,10 @@ struct WhileOpInterface
     Value conditionYieldedVal = whileOp.getConditionOp().getArgs()[resultNum];
     if (!isa<TensorType>(conditionYieldedVal.getType())) {
       // scf.condition was already bufferized.
-      return cast<BaseMemRefType>(conditionYieldedVal.getType());
+      return cast<BufferLikeType>(conditionYieldedVal.getType());
     }
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
-        conditionYieldedVal, options, state, invocationStack));
+    return bufferization::getBufferType(conditionYieldedVal, options, state,
+                                        invocationStack);
   }

   /// Assert that yielded values of an scf.while op are equivalent to their
@@ -1303,7 +1298,7 @@ struct ForallOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -1312,15 +1307,15 @@ struct ForallOpInterface
     if (auto bbArg = dyn_cast<BlockArgument>(value))
       // A tensor block argument has the same bufferized type as the
       // corresponding output operand.
-      return bufferization::detail::asMemRefType(
-          bufferization::getBufferType(forallOp.getTiedOpOperand(bbArg)->get(),
-                                       options, state, invocationStack));
+      return bufferization::getBufferType(
+          forallOp.getTiedOpOperand(bbArg)->get(), options, state,
+          invocationStack);

     // The bufferized result type is the same as the bufferized type of the
     // corresponding output operand.
-    return bufferization::detail::asMemRefType(bufferization::getBufferType(
+    return bufferization::getBufferType(
        forallOp.getOutputs()[cast<OpResult>(value).getResultNumber()], options,
-        state, invocationStack));
+        state, invocationStack);
   }

   bool isRepetitiveRegion(Operation *op, unsigned index) const {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 729c048db4560..829b2ab92ac24 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -49,7 +49,7 @@ struct CastOpInterface
     return {{op->getResult(0), BufferRelation::Equivalent}};
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -68,20 +68,22 @@ struct CastOpInterface
     if (isa<UnrankedTensorType>(castOp.getSource().getType())) {
       // When casting to a ranked tensor, we cannot infer any static offset or
       // strides from the source. Assume fully dynamic.
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }

     // Case 2: Casting to an unranked tensor type
     if (isa<UnrankedTensorType>(castOp.getType())) {
-      return getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace);
+      return cast<BufferLikeType>(
+          getMemRefTypeWithFullyDynamicLayout(castOp.getType(), memorySpace));
     }

     // Case 3: Ranked tensor -> ranked tensor. The offsets and strides do not
     // change.
     auto rankedResultType = cast<RankedTensorType>(castOp.getType());
-    return MemRefType::get(
+    return cast<BufferLikeType>(MemRefType::get(
        rankedResultType.getShape(), rankedResultType.getElementType(),
-        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace);
+        llvm::cast<MemRefType>(*maybeSrcBufferType).getLayout(), memorySpace));
   }

   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -141,7 +143,7 @@ struct CollapseShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -157,12 +159,13 @@ struct CollapseShapeOpInterface
     if (!canBeCollapsed) {
       // If dims cannot be collapsed, this op bufferizes to a new allocation.
       RankedTensorType tensorResultType = collapseShapeOp.getResultType();
-      return bufferization::getMemRefTypeWithStaticIdentityLayout(
-          tensorResultType, srcBufferType.getMemorySpace());
+      return cast<BufferLikeType>(
+          bufferization::getMemRefTypeWithStaticIdentityLayout(
+              tensorResultType, srcBufferType.getMemorySpace()));
     }

-    return memref::CollapseShapeOp::computeCollapsedType(
-        srcBufferType, collapseShapeOp.getReassociationIndices());
+    return cast<BufferLikeType>(memref::CollapseShapeOp::computeCollapsedType(
+        srcBufferType, collapseShapeOp.getReassociationIndices()));
   }

   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -319,7 +322,7 @@ struct ExpandShapeOpInterface
     return {{op->getOpResult(0), BufferRelation::Equivalent}};
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -334,7 +337,7 @@ struct ExpandShapeOpInterface
        expandShapeOp.getReassociationIndices());
     if (failed(maybeResultType))
       return failure();
-    return *maybeResultType;
+    return cast<BufferLikeType>(*maybeResultType);
   }

   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -404,7 +407,7 @@ struct ExtractSliceOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -417,10 +420,10 @@ struct ExtractSliceOpInterface
     SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
-    return memref::SubViewOp::inferRankReducedResultType(
+    return cast<BufferLikeType>(memref::SubViewOp::inferRankReducedResultType(
        extractSliceOp.getType().getShape(),
        llvm::cast<MemRefType>(*srcMemrefType), mixedOffsets, mixedSizes,
-        mixedStrides);
+        mixedStrides));
   }
 };
@@ -501,8 +504,8 @@ struct FromElementsOpInterface
        /*copy=*/false);
     if (failed(tensorAlloc))
       return failure();
-    FailureOr<BaseMemRefType> memrefType = bufferization::detail::asMemRefType(
-        bufferization::getBufferType(*tensorAlloc, options, state));
+    FailureOr<BufferLikeType> memrefType =
+        bufferization::getBufferType(*tensorAlloc, options, state);
     if (failed(memrefType))
       return failure();
     Value buffer = rewriter.create<bufferization::ToBufferOp>(
@@ -753,7 +756,7 @@ struct PadOpInterface
     return {};
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -765,9 +768,10 @@ struct PadOpInterface
     if (failed(maybeSrcBufferType))
       return failure();
     MemRefLayoutAttrInterface layout;
-    return MemRefType::get(padOp.getResultType().getShape(),
-                           padOp.getResultType().getElementType(), layout,
-                           maybeSrcBufferType->getMemorySpace());
+    return cast<BufferLikeType>(
+        MemRefType::get(padOp.getResultType().getShape(),
+                        padOp.getResultType().getElementType(), layout,
+                        maybeSrcBufferType->getMemorySpace()));
   }

   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
@@ -927,7 +931,7 @@ struct ReshapeOpInterface
     return success();
   }

-  FailureOr<BaseMemRefType>
+  FailureOr<BufferLikeType>
   getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                 const BufferizationState &state,
                 SmallVector<Value> &invocationStack) const {
@@ -937,9 +941,9 @@ struct ReshapeOpInterface
        reshapeOp.getSource(), options, state, invocationStack);
     if (failed(maybeSourceBufferType))
       return failure();
-    return getMemRefTypeWithStaticIdentityLayout(
+    return cast<BufferLikeType>(getMemRefTypeWithStaticIdentityLayout(
        reshapeOp.getResult().getType(),
-        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace());
+        cast<BaseMemRefType>(maybeSourceBufferType.value()).getMemorySpace()));
   }
 };
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
index da3c26ce36ba5..8031732011839 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir
@@ -272,10 +272,10 @@ func.func @materialize_in_dest_raw(%f: f32, %f2: f32, %idx: index) -> (tensor<5x

 // -----

-// CHECK-LABEL: func.func @test_dialect_op(
+// CHECK: func.func @custom_op(
 // CHECK-SAME: %[[ARG:.*]]: !test.test_tensor<[32, 64], f64>
 // CHECK-SAME: ) -> !test.test_tensor<[32, 128], f64> {
-func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
+func.func @custom_op(%arg: !test.test_tensor<[32, 64], f64>)
    -> !test.test_tensor<[32, 128], f64> {
   // CHECK: %[[MEMREF:.*]] = bufferization.to_buffer %[[ARG]]
   // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
@@ -288,3 +288,22 @@ func.func @test_dialect_op(%arg: !test.test_tensor<[32, 64], f64>)
   // CHECK: return %[[OUT]]
   return %out : !test.test_tensor<[32, 128], f64>
 }
+
+// -----
+
+// CHECK: func.func @custom_origin_op()
+// CHECK-SAME: -> !test.test_tensor<[42], f64> {
+func.func @custom_origin_op() -> !test.test_tensor<[42], f64> {
+  // CHECK: %[[MEMREF:.*]] = "test.create_memref_op"() : ()
+  // CHECK-SAME: -> !test.test_memref<[21], f64>
+  // CHECK: %[[DUMMY:.*]] = "test.dummy_memref_op"(%[[MEMREF]])
+  // CHECK-SAME: : (!test.test_memref<[21], f64>)
+  // CHECK-SAME: -> !test.test_memref<[42], f64>
+  %in = "test.create_tensor_op"() : () -> !test.test_tensor<[21], f64>
+  %out = "test.dummy_tensor_op"(%in) : (!test.test_tensor<[21], f64>)
+      -> !test.test_tensor<[42], f64>
+
+  // CHECK: %[[OUT:.*]] = bufferization.to_tensor %[[DUMMY]]
+  // CHECK: return %[[OUT]]
+  return %out : !test.test_tensor<[42], f64>
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
index 78e44c6ec7a9b..b64d3b7230b36 100644
--- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
+++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp
@@ -1410,3 +1410,37 @@ ::mlir::LogicalResult test::TestDummyTensorOp::bufferize(

   return mlir::success();
 }
+
+::mlir::LogicalResult test::TestCreateTensorOp::bufferize(
+    ::mlir::RewriterBase &rewriter,
+    const ::mlir::bufferization::BufferizationOptions &options,
+    ::mlir::bufferization::BufferizationState &state) {
+  // Note: mlir::bufferization::getBufferType() would internally call
+  // TestCreateTensorOp::getBufferType()
+  const auto bufferizedOutType =
+      mlir::bufferization::getBufferType(getOutput(), options, state);
+  if (mlir::failed(bufferizedOutType))
+    return failure();
+
+  // replace op with memref analogy
+  auto createMemrefOp =
+      rewriter.create<test::TestCreateMemrefOp>(getLoc(), *bufferizedOutType);
+
+  mlir::bufferization::replaceOpWithBufferizedValues(
+      rewriter, getOperation(), createMemrefOp.getResult());
+
+  return mlir::success();
+}
+
+mlir::FailureOr<mlir::bufferization::BufferLikeType>
+test::TestCreateTensorOp::getBufferType(
+    mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+    const mlir::bufferization::BufferizationState &,
+    llvm::SmallVector<::mlir::Value> &) {
+  const auto type = dyn_cast<test::TestTensorType>(value.getType());
+  if (type == nullptr)
+    return failure();
+
+  return cast<mlir::bufferization::BufferLikeType>(test::TestMemrefType::get(
+      getContext(), type.getShape(), type.getElementType(), nullptr));
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 79bcd9c2e0a9a..2a4de535b0841 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -3606,4 +3606,57 @@ def TestDummyMemrefOp : TEST_Op<"dummy_memref_op", []> {
   );
 }
+
+def TestCreateTensorOp : TEST_Op<"create_tensor_op", [BufferizableOpInterface]> {
+  let arguments = (ins);
+  let results = (outs Arg<TestTensorType>:$output);
+  let extraClassDeclaration = [{
+    // BufferizableOpInterface
+    bool bufferizesToMemoryRead(mlir::OpOperand&,
+                                const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToMemoryWrite(mlir::OpOperand&,
+                                 const mlir::bufferization::AnalysisState&);
+
+    bool bufferizesToAllocation(mlir::Value value);
+
+    mlir::bufferization::AliasingValueList getAliasingValues(mlir::OpOperand&,
+        const mlir::bufferization::AnalysisState&);
+
+    mlir::LogicalResult bufferize(
+        mlir::RewriterBase& rewriter,
+        const mlir::bufferization::BufferizationOptions& options,
+        mlir::bufferization::BufferizationState &state);
+
+    mlir::FailureOr<mlir::bufferization::BufferLikeType> getBufferType(
+        mlir::Value value, const mlir::bufferization::BufferizationOptions &,
+        const mlir::bufferization::BufferizationState &,
+        llvm::SmallVector<::mlir::Value> &);
+  }];
+
+  let extraClassDefinition = [{
+    bool test::TestCreateTensorOp::bufferizesToMemoryRead(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToMemoryWrite(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return true;
+    }
+    bool test::TestCreateTensorOp::bufferizesToAllocation(mlir::Value value) {
+      return false;
+    }
+
+    ::mlir::bufferization::AliasingValueList
+    test::TestCreateTensorOp::getAliasingValues(::mlir::OpOperand&,
+        const ::mlir::bufferization::AnalysisState&) {
+      return {};
+    }
+  }];
+}
+
+def TestCreateMemrefOp : TEST_Op<"create_memref_op"> {
+  let arguments = (ins);
+  let results = (outs Arg<TestMemrefType>:$output);
+}
+
 #endif // TEST_OPS
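
Note for downstream dialects (not part of the patch itself): every `getBufferType` override touched above follows the same two-step migration. The return type changes from `FailureOr<BaseMemRefType>` to `FailureOr<BufferLikeType>`, and any concrete memref result is wrapped in `cast<BufferLikeType>(...)`. The sketch below illustrates that pattern for a hypothetical out-of-tree external model; the template name `MyOpModel` and the identity-layout choice are illustrative assumptions, not code from this patch.

```cpp
// Hypothetical sketch (not part of this patch): migrating an out-of-tree
// BufferizableOpInterface external model to the BufferLikeType-based API.
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h"

namespace example {
using namespace mlir;
using namespace mlir::bufferization;

template <typename ConcreteOp>
struct MyOpModel
    : public BufferizableOpInterface::ExternalModel<MyOpModel<ConcreteOp>,
                                                    ConcreteOp> {
  // Before this change the return type was FailureOr<BaseMemRefType>.
  FailureOr<BufferLikeType>
  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
                const BufferizationState &state,
                SmallVector<Value> &invocationStack) const {
    auto tensorType = cast<TensorType>(value.getType());
    // Helpers such as getMemRefTypeWithStaticIdentityLayout() still return a
    // BaseMemRefType; wrap the result in a cast to the BufferLikeType
    // interface, mirroring the changes above.
    return cast<BufferLikeType>(
        getMemRefTypeWithStaticIdentityLayout(tensorType));
  }
};
} // namespace example
```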