Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][IR] Change MutableArrayRange to enumerate OpOperand & #66622

Merged

Conversation

matthias-springer
Copy link
Member

In line with #66515, change MutableArrayRange::begin/end to enumerate OpOperand & instead of Value. Also remove ForOp::getIterOpOperands/setIterArg, which are now redundant.

Note: MutableOperandRange cannot be made a derived class of indexed_accessor_range_base (like OperandRange), because MutableOperandRange::assign can change the number of operands in the range.

In line with llvm#66515, change `MutableArrayRange::begin`/`end` to enumerate `OpOperand &` instead of `Value`. Also a remove `ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.

Note: `MutableOperandRange` cannot be made a derived class of `indexed_accessor_range_base` (like `OperandRange`), because `MutableOperandRange::assign` can change the number of operands in the range.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir mlir:bufferization Bufferization infrastructure mlir:scf mlir:cf labels Sep 18, 2023
@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Changes

In line with #66515, change MutableArrayRange::begin/end to enumerate OpOperand & instead of Value. Also remove ForOp::getIterOpOperands/setIterArg, which are now redundant.

Note: MutableOperandRange cannot be made a derived class of indexed_accessor_range_base (like OperandRange), because MutableOperandRange::assign can change the number of operands in the range.

Full diff: https://github.com/llvm/llvm-project/pull/66622.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (-7)
  • (modified) mlir/include/mlir/IR/ValueRange.h (+3-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+5-1)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+5-5)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+7-9)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+8)
  • (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+20-8)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..8a9ce949a750d43 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
-    MutableArrayRef<OpOperand> getIterOpOperands() {
-      return
-        getOperation()->getOpOperands().drop_front(getNumControlOperands());
-    }
 
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
-    void setIterArg(unsigned iterArgNum, Value iterArgValue) {
-      getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
-    }
 
     /// Number of induction variables, always 1 for scf::ForOp.
     unsigned getNumInductionVars() { return 1; }
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
   /// Returns the OpOperand at the given index.
   OpOperand &operator[](unsigned index) const;
 
-  OperandRange::iterator begin() const {
-    return static_cast<OperandRange>(*this).begin();
-  }
-
-  OperandRange::iterator end() const {
-    return static_cast<OperandRange>(*this).end();
-  }
+  /// Iterators enumerate OpOperands.
+  MutableArrayRef<OpOperand>::iterator begin() const;
+  MutableArrayRef<OpOperand>::iterator end() const;
 
 private:
   /// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
 
 static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
 
+static bool isMemrefOperand(OpOperand &operand) {
+  return isMemref(operand.get());
+}
+
 //===----------------------------------------------------------------------===//
 // Backedges analysis
 //===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
   // Add an additional operand for every MemRef for the ownership indicator.
   if (!funcWithoutDynamicOwnership) {
-    unsigned numMemRefs = llvm::count_if(operands, isMemref);
+    unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
     SmallVector<Value> newOperands{OperandRange(operands)};
     auto ownershipValues =
         deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
         mapping[retained] = ownership;
       }
       SmallVector<Value> replacements, ownerships;
-      for (Value operand : destOperands) {
-        replacements.push_back(operand);
-        if (isMemref(operand)) {
-          assert(mapping.contains(operand) &&
+      for (OpOperand &operand : destOperands) {
+        replacements.push_back(operand.get());
+        if (isMemref(operand.get())) {
+          assert(mapping.contains(operand.get()) &&
                  "Should be contained at this point");
-          ownerships.push_back(mapping[operand]);
+          ownerships.push_back(mapping[operand.get()]);
         }
       }
       replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
   assert(operand.get().getType() != replacement.getType() &&
          "Expected a different type");
   SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
       newIterOperands.push_back(replacement);
       continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
       OpOperand &iterOpOperand = std::get<0>(it);
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), options);
+        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), options);
+        getBuffers(rewriter, whileOp.getInitsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..ceec0756e421ffd 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                       MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
-  auto [fusableProducer, destinationIterArg] =
+  auto [fusableProducer, destinationInitArg] =
       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
                                         loops);
   if (!fusableProducer)
@@ -575,17 +575,15 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
   // Update to use that when it does become available.
   scf::ForOp outerMostLoop = loops.front();
-  std::optional<unsigned> iterArgNumber;
-  if (destinationIterArg) {
-    iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
-  }
-  if (iterArgNumber) {
+  if (destinationInitArg &&
+      (*destinationInitArg)->getOwner() == outerMostLoop) {
+    std::optional<unsigned> iterArgNumber =
+        outerMostLoop.getIterArgNumberForOpOperand(**destinationInitArg);
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
-      outerMostLoop.setIterArg(iterArgNumber.value(),
-                               dstOp.getTiedOpOperand(fusableProducer)->get());
+      (*destinationInitArg)
+          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
     }
     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
   return owner->getOpOperand(start + index);
 }
 
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+  return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+  return owner->getOpOperands().slice(start, length).end();
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
   return succOps.getMutableForwardedOperands();
 }
 
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+                                         unsigned successorIndex) {
+  return getMutableSuccessorOperands(block, successorIndex);
+}
+
 /// Appends all the block arguments from `other` to the block arguments of
 /// `block`, copying their types and locations.
 static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
 
   /// Returns the arguments of this edge that are passed to the block arguments
   /// of the successor.
-  MutableOperandRange getSuccessorOperands() const {
-    return getMutableSuccessorOperands(fromBlock, successorIndex);
+  MutableOperandRange getMutableSuccessorOperands() const {
+    return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+  }
+
+  /// Returns the arguments of this edge that are passed to the block arguments
+  /// of the successor.
+  OperandRange getSuccessorOperands() const {
+    return ::getSuccessorOperands(fromBlock, successorIndex);
   }
 };
 
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
     assert(result != blockArgMapping.end() &&
            "Edge was not originally passed to `create` method.");
 
-    MutableOperandRange successorOperands = edge.getSuccessorOperands();
+    MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
 
     // Extra arguments are always appended at the end of the block arguments.
     unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
   // invalidated when mutating the operands through a different
   // `MutableOperandRange` of the same operation.
   SmallVector<Value> loopHeaderSuccessorOperands =
-      llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+      llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
 
   // Add all values used in the next iteration to the exit block. Replace
   // any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
 
           loopHeaderSuccessorOperands.push_back(argument);
           for (Edge edge : successorEdges(latch))
-            edge.getSuccessorOperands().append(argument);
+            edge.getMutableSuccessorOperands().append(argument);
         }
 
         use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
   if (regionEntry->getNumSuccessors() == 1) {
     // Single successor we can just splice together.
     Block *successor = regionEntry->getSuccessor(0);
-    for (auto &&[oldValue, newValue] :
-         llvm::zip(successor->getArguments(),
-                   getMutableSuccessorOperands(regionEntry, 0)))
+    for (auto &&[oldValue, newValue] : llvm::zip(
+             successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
       oldValue.replaceAllUsesWith(newValue);
     regionEntry->getTerminator()->erase();
 

@matthias-springer matthias-springer merged commit 6923a31 into llvm:main Sep 19, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
…m#66622)

In line with llvm#66515, change `MutableArrayRange::begin`/`end` to
enumerate `OpOperand &` instead of `Value`. Also remove
`ForOp::getIterOpOperands`/`setIterArg`, which are now redundant.

Note: `MutableOperandRange` cannot be made a derived class of
`indexed_accessor_range_base` (like `OperandRange`), because
`MutableOperandRange::assign` can change the number of operands in the
range.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:cf mlir:core MLIR Core Infrastructure mlir:scf mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants