diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h index 187185b47b666..f1a1f1841f179 100644 --- a/mlir/include/mlir/IR/ValueRange.h +++ b/mlir/include/mlir/IR/ValueRange.h @@ -162,10 +162,8 @@ class MutableOperandRange { /// elements attribute, which contains the sizes of the sub ranges. MutableOperandRangeRange split(NamedAttribute segmentSizes) const; - /// Returns the value at the given index. - Value operator[](unsigned index) const { - return operator OperandRange()[index]; - } + /// Returns the OpOperand at the given index. + OpOperand &operator[](unsigned index) const; OperandRange::iterator begin() const { return static_cast(*this).begin(); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h index 006aedced839f..7f6967f11444f 100644 --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -76,7 +76,7 @@ class SuccessorOperands { Value operator[](unsigned index) const { if (isOperandProduced(index)) return Value(); - return forwardedOperands[index - producedOperandCount]; + return forwardedOperands[index - producedOperandCount].get(); } /// Get the range of operands that are simply forwarded to the successor. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index e5016c9568046..3a30f1a1405ec 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -549,22 +549,18 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, bool MaterializeInDestinationOp::bufferizesToMemoryRead( OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getOperation()->getOpOperand(0) /*source*/) - return true; - return false; + return &opOperand == &getSourceMutable()[0]; } bool MaterializeInDestinationOp::bufferizesToMemoryWrite( OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/) - return true; - return false; + return &opOperand == &getDestMutable()[0]; } AliasingValueList MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand, const AnalysisState &state) { - if (&opOperand == &getOperation()->getOpOperand(1) /*dest*/) + if (&opOperand == &getDestMutable()[0]) return {{getOperation()->getResult(0), BufferRelation::Equivalent}}; return {}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 581e7b0a8ea86..6a01c24f02699 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -949,7 +949,7 @@ struct FoldReshapeWithGenericOpByExpansion reshapeOp, "failed preconditions of fusion with producer generic op"); } - if (!controlFoldingReshapes(&reshapeOp->getOpOperand(0))) { + if (!controlFoldingReshapes(&reshapeOp.getSrcMutable()[0])) { return rewriter.notifyMatchFailure(reshapeOp, "fusion blocked by control function"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 597676a017bf4..1ce25565edcaf 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -509,7 +509,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, // 1. Get the producer of the source (potentially walking through // `iter_args` of nested `scf.for`) auto [fusableProducer, destinationIterArg] = - getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), + getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0], loops); if (!fusableProducer) return std::nullopt; diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index ecca4dd3394e0..ef4352cf0c659 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -644,11 +644,11 @@ struct InsertSliceOpInterface RankedTensorType destType = insertSliceOp.getDestType(); // The source is always read. - if (&opOperand == &op->getOpOperand(0) /*src*/) + if (&opOperand == &insertSliceOp.getSourceMutable()[0]) return true; // For the destination, it depends... - assert(&opOperand == &insertSliceOp->getOpOperand(1) && "expected dest"); + assert(&opOperand == &insertSliceOp.getDestMutable()[0] && "expected dest"); // Dest is not read if it is entirely overwritten. E.g.: // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> @@ -851,9 +851,8 @@ struct ReshapeOpInterface tensor::ReshapeOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - if (&opOperand == &op->getOpOperand(1) /* shape */) - return true; - return false; + auto reshapeOp = cast(op); + return &opOperand == &reshapeOp.getShapeMutable()[0]; } bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, @@ -915,7 +914,8 @@ struct ParallelInsertSliceOpInterface bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { - return &opOperand == &op->getOpOperand(1) /*dest*/; + auto parallelInsertSliceOp = cast(op); + return &opOperand == ¶llelInsertSliceOp.getDestMutable()[0]; } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index 65b123e10ddbc..7b17e231ce106 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -517,6 +517,11 @@ void MutableOperandRange::updateLength(unsigned newLength) { } } +OpOperand &MutableOperandRange::operator[](unsigned index) const { + assert(index < length && "index is out of bounds"); + return owner->getOpOperand(start + index); +} + //===----------------------------------------------------------------------===// // MutableOperandRangeRange diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp index 84f23584e9f30..9aab89ed75536 100644 --- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp +++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp @@ -277,7 +277,8 @@ class EdgeMultiplexer { if (index >= result->second && index < result->second + edge.getSuccessor()->getNumArguments()) { // Original block arguments to the entry block. - newSuccOperands[index] = successorOperands[index - result->second]; + newSuccOperands[index] = + successorOperands[index - result->second].get(); continue; }