diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h index 51d4e366e4970..4e550fe3e3a60 100644 --- a/mlir/include/mlir/IR/Value.h +++ b/mlir/include/mlir/IR/Value.h @@ -268,6 +268,9 @@ class OpOperand : public IROperand { /// Return which operand this is in the OpOperand list of the Operation. unsigned getOperandNumber(); + /// Set the current value being used by this operand. + void assign(Value value) { set(value); } + private: /// Keep the constructor private and accessible to the OperandStorage class /// only to avoid hard-to-debug typo/programming mistakes. diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h index 9c11178f9cd9c..ed69e5824f70b 100644 --- a/mlir/include/mlir/IR/ValueRange.h +++ b/mlir/include/mlir/IR/ValueRange.h @@ -126,6 +126,9 @@ class MutableOperandRange { ArrayRef operandSegments = std::nullopt); MutableOperandRange(Operation *owner); + /// Construct a new mutable range for the given OpOperand. + MutableOperandRange(OpOperand &opOperand); + /// Slice this range into a sub range, with the additional operand segment. MutableOperandRange slice(unsigned subStart, unsigned subLen, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index d2823d17c99c8..01cbacc96fd42 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -537,18 +537,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, bool MaterializeInDestinationOp::bufferizesToMemoryRead( OpOperand &opOperand, const AnalysisState &state) { - return &opOperand == &getSourceMutable()[0]; + return &opOperand == &getSourceMutable(); } bool MaterializeInDestinationOp::bufferizesToMemoryWrite( OpOperand &opOperand, const AnalysisState &state) { - return &opOperand == &getDestMutable()[0]; + return &opOperand == &getDestMutable(); } AliasingValueList MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getDestMutable()[0]) + if (&opOperand == &getDestMutable()) return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; return {}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 17346607fa9cd..069c613cc246d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion reshapeOp, "failed preconditions of fusion with producer generic op"); } - if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) { + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable())) { return rewriter.notifyMatchFailure(reshapeOp, "fusion blocked by control function"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 1ecfe7906c571..96d6169111b38 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -526,7 +526,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, // 1. Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) auto [fusableProducer, destinationInitArg] = - getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0], + getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable(), loops); if (!fusableProducer) return std::nullopt; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index b08283f007078..c0daebefb0ad5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -635,11 +635,11 @@ struct InsertSliceOpInterface RankedTensorType destType = insertSliceOp.getDestType(); // The source is always read. - if (&opOperand == &insertSliceOp.getSourceMutable()[0]) + if (&opOperand == &insertSliceOp.getSourceMutable()) return true; // For the destination, it depends... - assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest"); + assert(&opOperand == &insertSliceOp.getDestMutable() && "expected dest"); // Dest is not read if it is entirely overwritten. E.g.: // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> @@ -839,7 +839,7 @@ struct ReshapeOpInterface bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto reshapeOp = cast(op); - return &opOperand == &reshapeOp.getShapeMutable()[0]; + return &opOperand == &reshapeOp.getShapeMutable(); } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -916,7 +916,7 @@ struct ParallelInsertSliceOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { auto parallelInsertSliceOp = cast(op); - return &opOperand == ¶llelInsertSliceOp.getDestMutable()[0]; + return &opOperand == ¶llelInsertSliceOp.getDestMutable(); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp index dff9f64169d49..f4f46d54d78e5 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.cpp @@ -63,7 +63,7 @@ struct InsertSliceOpInterface : public SubsetInsertionOpInterface::ExternalModel { OpOperand &getSourceOperand(Operation *op) const { - return op->getOpOperand(0); + return cast(op).getSourceMutable(); } bool @@ -91,11 +91,11 @@ struct ParallelInsertSliceOpInterface : public SubsetInsertionOpInterface::ExternalModel< ParallelInsertSliceOpInterface, tensor::ParallelInsertSliceOp> { OpOperand &getSourceOperand(Operation *op) const { - return op->getOpOperand(0); + return cast(op).getSourceMutable(); } OpOperand &getDestinationOperand(Operation *op) const { - return op->getOpOperand(1); + return cast(op).getDestMutable(); } bool diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 5aac72fcb062c..6726b49dd3d31 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -437,6 +437,12 @@ MutableOperandRange::MutableOperandRange( MutableOperandRange::MutableOperandRange(Operation *owner) : MutableOperandRange(owner, /*start=*/0, owner->getNumOperands()) {} +/// Construct a new mutable range for the given OpOperand. +MutableOperandRange::MutableOperandRange(OpOperand &opOperand) + : MutableOperandRange(opOperand.getOwner(), + /*start=*/opOperand.getOperandNumber(), + /*length=*/1) {} + /// Slice this range into a sub range, with the additional operand segment. MutableOperandRange MutableOperandRange::slice(unsigned subStart, unsigned subLen, diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp index 3f1bb667a0aec..8aa7774d2bf00 100644 --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -998,7 +998,7 @@ void LoopBlockOp::getSuccessorRegions( OperandRange LoopBlockOp::getEntrySuccessorOperands(RegionBranchPoint point) { assert(point == getBody()); - return getInitMutable(); + return MutableOperandRange(getInitMutable()); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index 2afde1abdb726..85d68cc42ccbd 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -97,7 +97,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // CHECK: ::mlir::Operation::operand_range getODSOperands(unsigned index); // CHECK: ::mlir::TypedValue<::mlir::IntegerType> getA(); // CHECK: ::mlir::Operation::operand_range getB(); -// CHECK: ::mlir::MutableOperandRange getAMutable(); +// CHECK: ::mlir::OpOperand &getAMutable(); // CHECK: ::mlir::MutableOperandRange getBMutable(); // CHECK: ::mlir::Operation::result_range getODSResults(unsigned index); // CHECK: ::mlir::TypedValue<::mlir::IntegerType> getR(); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 5f6a4e3bc52a8..7029c0eac15c3 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2071,14 +2071,26 @@ void OpEmitter::genNamedOperandSetters() { continue; std::string name = op.getGetterName(operand.name); - auto *m = opClass.addMethod(operand.isVariadicOfVariadic() - ? "::mlir::MutableOperandRangeRange" - : "::mlir::MutableOperandRange", - name + "Mutable"); + StringRef returnType; + if (operand.isVariadicOfVariadic()) { + returnType = "::mlir::MutableOperandRangeRange"; + } else if (operand.isVariableLength()) { + returnType = "::mlir::MutableOperandRange"; + } else { + returnType = "::mlir::OpOperand &"; + } + auto *m = opClass.addMethod(returnType, name + "Mutable"); ERROR_IF_PRUNED(m, name, op); auto &body = m->body(); - body << " auto range = getODSOperandIndexAndLength(" << i << ");\n" - << " auto mutableRange = " + body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"; + + if (!operand.isVariadicOfVariadic() && !operand.isVariableLength()) { + // In case of a single operand, return a single OpOperand. + body << " return getOperation()->getOpOperand(range.first);\n"; + continue; + } + + body << " auto mutableRange = " "::mlir::MutableOperandRange(getOperation(), " "range.first, range.second"; if (attrSizedOperands) {