[mlir][Tensor] Move concat operation decomposition as a method of the concat operation. #116004
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes: Currently the implementation lives inside a pattern and cannot be used without a pattern rewriter. Move the decomposition into a method on the operation to make it usable outside of pattern rewrites.

Full diff: https://github.com/llvm/llvm-project/pull/116004.diff (3 files affected)
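For illustration, here is a hand-written sketch (not taken from the patch or its tests) of what the decomposition produces: the tensor.concat is replaced by a tensor.empty destination and a chain of tensor.insert_slice ops, where each input's offset along the concat dim is the running sum of the preceding input sizes:

```mlir
// Before: concatenate two tensors along dim 1.
%0 = tensor.concat dim(1) %a, %b
       : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>

// After decomposition (sketch):
%empty = tensor.empty() : tensor<2x7xf32>
%s0 = tensor.insert_slice %a into %empty[0, 0] [2, 3] [1, 1]
        : tensor<2x3xf32> into tensor<2x7xf32>
%s1 = tensor.insert_slice %b into %s0[0, 3] [2, 4] [1, 1]
        : tensor<2x4xf32> into tensor<2x7xf32>
```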
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 3170115883e2be..b73da8bb6af59c 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -178,6 +178,9 @@ def Tensor_ConcatOp : Tensor_Op<"concat",
int64_t getRank() {
return ::llvm::cast<RankedTensorType>(getResult().getType()).getRank();
}
+
+ // Method to decompose the operation into a sequence of insert_slices.
+ FailureOr<SmallVector<Value>> decomposeOperation(OpBuilder &builder);
}];
let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 8e0d0104397468..dd6c7ebf1d0919 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -615,6 +615,51 @@ LogicalResult ConcatOp::verify() {
return success();
}
+FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
+ size_t numInputs = getInputs().size();
+ uint64_t concatDim = getDim();
+
+ SmallVector<SmallVector<OpFoldResult>> inputShapes;
+ inputShapes.reserve(numInputs);
+ SmallVector<OpFoldResult> concatOffsets;
+ concatOffsets.reserve(numInputs);
+ SmallVector<OpFoldResult> outputShape;
+
+ AffineExpr addExpr =
+ builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
+ OpFoldResult zero = builder.getIndexAttr(0);
+ Location loc = getLoc();
+ for (auto [index, input] : llvm::enumerate(getInputs())) {
+ SmallVector<OpFoldResult> inputShape =
+ tensor::getMixedSizes(builder, input.getLoc(), input);
+ if (index == 0) {
+ outputShape = inputShape;
+ concatOffsets.push_back(zero);
+ } else {
+ concatOffsets.push_back(outputShape[concatDim]);
+ outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
+ builder, loc, addExpr,
+ {outputShape[concatDim], inputShape[concatDim]});
+ }
+ inputShapes.emplace_back(std::move(inputShape));
+ }
+
+ Value replacement = builder.create<tensor::EmptyOp>(
+ loc, outputShape, getType().getElementType());
+
+ int64_t rank = getType().getRank();
+ OpFoldResult one = builder.getIndexAttr(1);
+ SmallVector<OpFoldResult> strides(rank, one);
+ SmallVector<OpFoldResult> offsets(rank, zero);
+ for (auto [index, input] : llvm::enumerate(getInputs())) {
+ offsets[concatDim] = concatOffsets[index];
+ auto insertSlice = builder.create<tensor::InsertSliceOp>(
+ loc, input, replacement, offsets, inputShapes[index], strides);
+ replacement = insertSlice.getResult();
+ }
+ return SmallVector<Value>{replacement};
+}
+
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
index 7c8403c9609d84..a2a860fcb38abb 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ConcatOpPatterns.cpp
@@ -33,54 +33,13 @@ struct DecomposeTensorConcatOp : public OpRewritePattern<ConcatOp> {
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
- Location loc = concatOp.getLoc();
- FailureOr<Value> dest =
- tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
- if (failed(dest))
- return failure();
-
- auto empty = dest->getDefiningOp<tensor::EmptyOp>();
- if (!empty)
- return failure();
-
- int64_t dim = concatOp.getDim();
- Value dimValue =
- rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
-
- int64_t rank = concatOp.getResultType().getRank();
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-
- // Compute the partial sums for the slice offsets.
- AffineExpr sum = rewriter.getAffineDimExpr(0);
- SmallVector<AffineExpr> partialSums = {sum};
- SmallVector<OpFoldResult> offsetStrides = {rewriter.getIndexAttr(0)};
- for (auto [idx, input] :
- llvm::enumerate(concatOp.getInputs().drop_back())) {
- sum = sum + rewriter.getAffineDimExpr(idx + 1);
- partialSums.push_back(sum);
- offsetStrides.push_back(
- rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
+ FailureOr<SmallVector<Value>> decomposed =
+ concatOp.decomposeOperation(rewriter);
+ if (failed(decomposed)) {
+ return rewriter.notifyMatchFailure(
+ concatOp, "failed to get the decomposed insert slices");
}
- auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
- partialSums, rewriter.getContext());
- SmallVector<OpFoldResult> dimOffsets =
- affine::makeComposedFoldedMultiResultAffineApply(
- rewriter, loc, partialSumMap, offsetStrides);
-
- // Construct the chain of insert_slice ops into the destination.
- Value result = *dest;
- for (auto [input, offset] :
- llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
- SmallVector<OpFoldResult> sizes =
- tensor::getMixedSizes(rewriter, loc, input);
- offsets[dim] = offset;
- result = rewriter.createOrFold<tensor::InsertSliceOp>(
- loc, input, result, offsets, sizes, strides);
- }
-
- rewriter.replaceOpWithNewOp<tensor::CastOp>(
- concatOp, concatOp.getResultType(), result);
+ rewriter.replaceOp(concatOp, decomposed.value()[0]);
return success();
}
};
Dropped the "NFC" because the lit test changed, and I'd rather keep it this way.
if (replacement.getType() != getType()) {
  replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
}
I think the main difference is that getOrCreateDestination "infers" the static shape when possible. We can get rid of the tensor.cast ops if we use that method. Why do you like the current way better?
Because the earlier implementation relied on that method creating a tensor.empty, which is a weird dependence on an implementation detail of that method.
I see, good point. Have you considered using ConcatOp::reifyResultShapes to create the outputShape for the tensor.empty op? Though that way we might create more operations and pay the cost. (I don't have a preference, just want to make sure the idea is evaluated.)

The current implementation looks okay to me because the root issue is that the op does not infer static shapes when possible. We'll end up with these tensor.cast ops even if the shape inference is implemented.
(Sorry, my phone was misbehaving and I was afk.) That's exactly right. reifyResultShapes would create a lot of operations. We need proper cast propagation to resolve the static information, and that is outside the scope of this decomposition.
Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>