
[mlir][Tensor] Move concat operation decomposition as a method of the concat operation. #116004

Merged
merged 1 commit into from
Nov 13, 2024

Conversation

MaheshRavishankar
Contributor

Currently the decomposition is implemented inside a rewrite pattern and cannot be used without a pattern rewriter. Move it into a method on the operation so it can be used outside of pattern rewrites.

@llvmbot

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

Currently the decomposition is implemented inside a rewrite pattern and cannot be used without a pattern rewriter. Move it into a method on the operation so it can be used outside of pattern rewrites.
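To illustrate what the decomposition produces, here is a hedged sketch in MLIR (the shapes and value names are illustrative, not taken from the PR's tests):

```mlir
// Before: a rank-1 concat along dimension 0.
%0 = tensor.concat dim(0) %a, %b
       : (tensor<2xf32>, tensor<3xf32>) -> tensor<5xf32>

// After decomposition: an empty destination plus a chain of
// insert_slice ops, with offsets along the concat dimension
// accumulated input by input (0, then 0 + 2).
%empty = tensor.empty() : tensor<5xf32>
%s0 = tensor.insert_slice %a into %empty[0] [2] [1]
        : tensor<2xf32> into tensor<5xf32>
%s1 = tensor.insert_slice %b into %s0[2] [3] [1]
        : tensor<3xf32> into tensor<5xf32>
```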


Full diff: https://github.com/llvm/llvm-project/pull/116004.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+3)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+45)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp (+6-47)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..b73da8bb6af59c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
     int64_t getRank() {
       return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
     }
+
+    // Method to decompose the operation into a sequence of insert_slices.
+    FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8e0d0104397468..dd6c7ebf1d0919 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,51 @@ LogicalResult ConcatOp::verify() {
   return success();
 }
 
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+  size_t numInputs = getInputs().size();
+  uint64_t concatDim = getDim();
+
+  SmallVector<SmallVector<OpFoldResult>> inputShapes;
+  inputShapes.reserve(numInputs);
+  SmallVector<OpFoldResult> concatOffsets;
+  concatOffsets.reserve(numInputs);
+  SmallVector<OpFoldResult> outputShape;
+
+  AffineExpr addExpr =
+      builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+  OpFoldResult zero = builder.getIndexAttr(0);
+  Location loc = getLoc();
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    SmallVector<OpFoldResult> inputShape =
+        tensor::getMixedSizes(builder, input.getLoc(), input);
+    if (index == 0) {
+      outputShape = inputShape;
+      concatOffsets.push_back(zero);
+    } else {
+      concatOffsets.push_back(outputShape[concatDim]);
+      outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+          builder, loc, addExpr,
+          {outputShape[concatDim], inputShape[concatDim]});
+    }
+    inputShapes.emplace_back(std::move(inputShape));
+  }
+
+  Value replacement = builder.create<tensor::EmptyOp>(
+      loc, outputShape, getType().getElementType());
+
+  int64_t rank = getType().getRank();
+  OpFoldResult one = builder.getIndexAttr(1);
+  SmallVector<OpFoldResult> strides(rank, one);
+  SmallVector<OpFoldResult> offsets(rank, zero);
+  for (auto [index, input] : llvm::enumerate(getInputs())) {
+    offsets[concatDim] = concatOffsets[index];
+    auto insertSlice = builder.create<tensor::InsertSliceOp>(
+        loc, input, replacement, offsets, inputShapes[index], strides);
+    replacement = insertSlice.getResult();
+  }
+  return SmallVector<Value>{replacement};
+}
+
 LogicalResult
 ConcatOp::reifyResultShapes(OpBuilder &builder,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 7c8403c9609d84..a2a860fcb38abb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
 
   LogicalResult matchAndRewrite(ConcatOp concatOp,
                                 PatternRewriter &rewriter) const override {
-    Location loc = concatOp.getLoc();
-    FailureOr<Value> dest =
-        tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
-    if (failed(dest))
-      return failure();
-
-    auto empty = dest->getDefiningOp<tensor::EmptyOp>();
-    if (!empty)
-      return failure();
-
-    int64_t dim = concatOp.getDim();
-    Value dimValue =
-        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-
-    int64_t rank = concatOp.getResultType().getRank();
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-
-    // Compute the partial sums for the slice offsets.
-    AffineExpr sum = rewriter.getAffineDimExpr(0);
-    SmallVector<AffineExpr> partialSums = {sum};
-    SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
-    for (auto [idx, input] :
-         llvm::enumerate(concatOp.getInputs().drop_back())) {
-      sum = sum + rewriter.getAffineDimExpr(idx + 1);
-      partialSums.push_back(sum);
-      offsetStrides.push_back(
-          rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+    FailureOr<SmallVector<Value>> decomposed =
+        concatOp.decomposeOperation(rewriter);
+    if (failed(decomposed)) {
+      return rewriter.notifyMatchFailure(
+          concatOp, "failed to get the decomposed insert slices");
     }
-    auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
-                                        partialSums, rewriter.getContext());
-    SmallVector<OpFoldResult> dimOffsets =
-        affine::makeComposedFoldedMultiResultAffineApply(
-            rewriter, loc, partialSumMap, offsetStrides);
-
-    // Construct the chain of insert_slice ops into the destination.
-    Value result = *dest;
-    for (auto [input, offset] :
-         llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
-      SmallVector<OpFoldResult> sizes =
-          tensor::getMixedSizes(rewriter, loc, input);
-      offsets[dim] = offset;
-      result = rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, input, result, offsets, sizes, strides);
-    }
-
-    rewriter.replaceOpWithNewOp<tensor::CastOp>(
-        concatOp, concatOp.getResultType(), result);
+    rewriter.replaceOp(concatOp, decomposed.value()[0]);
     return success();
   }
 };
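With the decomposition exposed as a method, it can be driven from plain builder-based code rather than only from a rewrite pattern. A minimal sketch (assumes a valid `tensor::ConcatOp` named `concatOp` inside a function that returns `LogicalResult`; not a compilable translation unit on its own):

```cpp
// Sketch: invoking ConcatOp::decomposeOperation outside a pattern rewrite.
// The builder is positioned at the op so the new ops are inserted before it.
OpBuilder builder(concatOp);
FailureOr<SmallVector<Value>> decomposed =
    concatOp.decomposeOperation(builder);
if (failed(decomposed))
  return failure();
// The single returned value is the tail of the insert_slice chain;
// swap it in for the concat result and erase the original op.
concatOp.getResult().replaceAllUsesWith(decomposed.value()[0]);
concatOp->erase();
return success();
```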

… concat operation.

Currently the implementation is within a pattern that cannot be used
without a pattern rewriter. Move the decomposition as a method of the
operation to make it usable outside of pattern rewrites.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
@MaheshRavishankar MaheshRavishankar changed the title [mlir][Tensor] NFC: Move concat operation decomposition as a method of the concat operation. [mlir][Tensor] Move concat operation decomposition as a method of the concat operation. Nov 13, 2024
@MaheshRavishankar
Copy link
Contributor Author

Dropped the "NFC" because the lit tests changed, and I'd rather keep it this way.

Comment on lines +660 to +662
if (replacement.getType() != getType()) {
replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
}
Contributor

I think the main difference is that getOrCreateDestination "infers" the static shape when possible. We could get rid of the tensor.cast ops if we used that method. Why do you like the current way better?

Contributor Author

Because the earlier implementation relied on that method creating a tensor.empty, which is a weird dependence on an implementation detail of that method.

Contributor

I see, good point. Have you considered using ConcatOp::reifyResultShapes to create the outputShape for the tensor.empty op? Though that way we might create more operations and pay the cost. (I don't have a preference; I just want to make sure the idea was evaluated.)

The current implementation looks okay to me because the root issue is that the op does not infer static shapes when possible. We'll end up with these tensor.cast ops even once shape inference is implemented.

Contributor Author

(Sorry, my phone was misbehaving and I was afk.) That's exactly right. reifyResultShapes would create a lot of operations. We need proper cast propagation to recover the static information, and that is outside the scope of this decomposition.
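For reference, a hedged sketch of the tensor.cast behavior being discussed: when the declared result type is more static than what the decomposition can infer from its inputs, a cast reconciles the two (shapes and value names illustrative):

```mlir
// Declared result type is static, but the inputs are dynamic.
%0 = tensor.concat dim(0) %a, %b
       : (tensor<?xf32>, tensor<?xf32>) -> tensor<5xf32>

// The decomposition builds a dynamically shaped chain (%d0, %d1 are
// tensor.dim values of the inputs, %sum their accumulated size) and
// then casts back to the original static type.
%empty = tensor.empty(%sum) : tensor<?xf32>
%s0 = tensor.insert_slice %a into %empty[0] [%d0] [1]
        : tensor<?xf32> into tensor<?xf32>
%s1 = tensor.insert_slice %b into %s0[%d0] [%d1] [1]
        : tensor<?xf32> into tensor<?xf32>
%cast = tensor.cast %s1 : tensor<?xf32> to tensor<5xf32>
```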

@MaheshRavishankar MaheshRavishankar merged commit de6d48d into llvm:main Nov 13, 2024
8 checks passed
@MaheshRavishankar MaheshRavishankar deleted the decompose_concat branch November 13, 2024 21:46
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Nov 14, 2024
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>