
Commit

[Flow] Add patterns to convert from tensor.concat to `flow.tensor.update`. (iree-org#19126)

These are in preparation for delaying the decomposition of
`tensor.concat` into `tensor.insert_slice`s. This patch just adds the
patterns to lower a `tensor.concat` along the outer dimension to
`flow.tensor.update`. Future changes will delay the decomposition of
`tensor.concat` so that non-outer-dimension concatenations can be
converted into `tensor.insert_slice`s before dispatch formation, with
each `tensor.insert_slice` fused with its producers.

Towards iree-org#19092
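
For illustration, a rough sketch of the new lowering (hand-written IR with
made-up shapes and value names; the checked-in lit test shows the exact
output): an outer-dim concat such as

    %concat = tensor.concat dim(0) %a, %b
        : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<6x3xf32>

becomes a `tensor.empty` destination plus one `flow.tensor.update` per
input, each writing at the running offset along dimension 0:

    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %empty = tensor.empty() : tensor<6x3xf32>
    %u0 = flow.tensor.update %a, %empty[%c0, %c0]
        : tensor<2x3xf32> -> %empty as tensor<6x3xf32>
    %u1 = flow.tensor.update %b, %u0[%c2, %c0]
        : tensor<4x3xf32> -> %u0 as tensor<6x3xf32>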

---------

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
MaheshRavishankar authored and Groverkss committed Nov 29, 2024
1 parent c39b4e2 commit f51e1da
Showing 6 changed files with 117 additions and 8 deletions.

@@ -25,6 +25,7 @@ iree_compiler_cc_library(
    deps = [
        "//compiler/src/iree/compiler/Dialect/Flow/IR",
        "@llvm-project//llvm:Support",
+       "@llvm-project//mlir:AffineDialect",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:ArithUtils",

@@ -21,6 +21,7 @@ iree_cc_library(
    "Utils.cpp"
  DEPS
    LLVMSupport
+   MLIRAffineDialect
    MLIRAnalysis
    MLIRArithDialect
    MLIRArithUtils

@@ -9,6 +9,7 @@
#include "iree/compiler/Dialect/Flow/Conversion/TensorToFlow/Utils.h"
#include "iree/compiler/Dialect/Flow/IR/FlowDialect.h"
#include "iree/compiler/Dialect/Flow/IR/FlowOps.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

@@ -174,6 +175,74 @@ struct ConvertTensorCastPattern : public OpRewritePattern<tensor::CastOp> {
  }
};

struct ConvertTensorConcatPattern : public OpRewritePattern<tensor::ConcatOp> {
  using OpRewritePattern<tensor::ConcatOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ConcatOp concatOp,
                                PatternRewriter &rewriter) const override {
    // Leave concats that already sit inside dispatch regions/workgroups alone.
    if (concatOp->getParentOfType<IREE::Flow::DispatchRegionOp>() ||
        concatOp->getParentOfType<IREE::Flow::DispatchWorkgroupsOp>()) {
      return failure();
    }
    if (concatOp.getDim() != 0) {
      return rewriter.notifyMatchFailure(
          concatOp, "only outer-dim concat lowering supported");
    }
    assert(cast<RankedTensorType>(concatOp.getInputs().front().getType())
                   .getRank() != 0 &&
           "concat cannot be of zero-rank tensors");

    Location loc = concatOp.getLoc();
    SmallVector<SmallVector<OpFoldResult>> inputShapes;
    inputShapes.reserve(concatOp.getInputs().size());
    // Note the output shape is computed directly without using
    // `reifyResultShapes` since we need the `inputShapes` anyway and using the
    // method would create duplicate `tensor.dim` operations.
    SmallVector<OpFoldResult> outputShape;
    AffineExpr addExpr =
        rewriter.getAffineSymbolExpr(0) + rewriter.getAffineSymbolExpr(1);
    // Running offsets along dim 0, i.e. prefix sums of the input sizes.
    SmallVector<OpFoldResult> concatOffsets;
    concatOffsets.reserve(concatOp.getInputs().size());
    for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
      SmallVector<OpFoldResult> inputShape =
          tensor::getMixedSizes(rewriter, input.getLoc(), input);
      if (index == 0) {
        outputShape = inputShape;
        concatOffsets.push_back(rewriter.getIndexAttr(0));
      } else {
        concatOffsets.push_back(outputShape[0]);
        outputShape[0] = affine::makeComposedFoldedAffineApply(
            rewriter, loc, addExpr, {outputShape[0], inputShape[0]});
      }
      inputShapes.emplace_back(std::move(inputShape));
    }

    Value replacement = rewriter.create<tensor::EmptyOp>(
        loc, outputShape, concatOp.getType().getElementType());

    SmallVector<int64_t> resultStaticDims;
    SmallVector<Value> resultDynamicDims;
    dispatchIndexOpFoldResults(outputShape, resultDynamicDims,
                               resultStaticDims);
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    // Generate the `flow.tensor.update` operations for the concat.
    for (auto [index, input] : llvm::enumerate(concatOp.getInputs())) {
      SmallVector<int64_t> inputStaticShape;
      SmallVector<Value> inputDynamicShape;
      dispatchIndexOpFoldResults(inputShapes[index], inputDynamicShape,
                                 inputStaticShape);
      SmallVector<Value> offsets(inputStaticShape.size(), zero);
      offsets[0] =
          getValueOrCreateConstantIndexOp(rewriter, loc, concatOffsets[index]);
      replacement = rewriter.create<IREE::Flow::TensorUpdateOp>(
          loc, replacement.getType(), replacement, resultDynamicDims, offsets,
          input, inputDynamicShape);
    }
    rewriter.replaceOp(concatOp, replacement);
    return success();
  }
};

struct ConvertTensorFromElementsPattern
    : public OpRewritePattern<tensor::FromElementsOp> {
  using OpRewritePattern<tensor::FromElementsOp>::OpRewritePattern;

@@ -316,14 +385,14 @@ struct ConvertTensorReshapePattern : public OpRewritePattern<TensorReshapeOp> {

void populateTensorToFlowConversionPatterns(MLIRContext *context,
                                            RewritePatternSet &patterns) {
-  patterns
-      .insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
-              ConvertTensorCastPattern, ConvertTensorExtractPattern,
-              ConvertTensorExtractSlicePattern, ConvertTensorInsertSlicePattern,
-              ConvertTensorInsertPattern, ConvertTensorFromElementsPattern,
-              ConvertTensorDialectReshapeOpPattern,
-              ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
-              ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
+  patterns.insert<ConvertLinalgFillPattern, ConvertTensorBitcastPattern,
+                  ConvertTensorCastPattern, ConvertTensorConcatPattern,
+                  ConvertTensorExtractPattern, ConvertTensorExtractSlicePattern,
+                  ConvertTensorInsertSlicePattern, ConvertTensorInsertPattern,
+                  ConvertTensorFromElementsPattern,
+                  ConvertTensorDialectReshapeOpPattern,
+                  ConvertTensorReshapePattern<tensor::CollapseShapeOp>,
+                  ConvertTensorReshapePattern<tensor::ExpandShapeOp>>(context);
}

} // namespace mlir::iree_compiler::IREE::Flow
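
As a quick worked check of the offset arithmetic above, take the
`@mixed_concat` lit test below, whose inputs have outer sizes 2, %d0 (the
dynamic dim 0 of %arg1), and 4. The offsets are prefix sums of the input
sizes along dim 0, and the result size is their total:

    offset[0] = 0
    offset[1] = 0 + 2 = 2
    offset[2] = 2 + %d0        (the affine.apply (s0 + 2) in the CHECKs)
    result[0] = 2 + %d0 + 4    (the affine.apply (s0 + 6) in the CHECKs)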

@@ -18,6 +18,7 @@ iree_lit_test_suite(
        [
            "bitcast.mlir",
            "cast.mlir",
+           "concat.mlir",
            "extract.mlir",
            "extract_slice.mlir",
            "fill.mlir",

@@ -16,6 +16,7 @@ iree_lit_test_suite(
  SRCS
    "bitcast.mlir"
    "cast.mlir"
+   "concat.mlir"
    "extract.mlir"
    "extract_slice.mlir"
    "fill.mlir"

@@ -0,0 +1,36 @@
// RUN: iree-opt --iree-flow-convert-to-flow --split-input-file --mlir-print-local-scope %s | FileCheck %s

func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
  %0 = tensor.concat dim(0) %arg0, %arg1, %arg2 : (tensor<2x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @mixed_concat
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[ARG0_D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
// CHECK-DAG: %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
// CHECK: %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%[[ARG1_D0]]]
// CHECK: %[[ARG2_D1:.+]] = tensor.dim %[[ARG2]], %[[C1]]
// CHECK: %[[RESULT_D0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 6)>()[%[[ARG1_D0]]]
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[RESULT_D0]], %[[ARG0_D1]])
// CHECK: %[[UPDATE0:.+]] = flow.tensor.update %[[ARG0]], %[[EMPTY]][%[[C0]], %[[C0]]]
// CHECK-SAME: : tensor<2x?xf32>{%[[ARG0_D1]]} -> %[[EMPTY]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
// CHECK: %[[UPDATE1:.+]] = flow.tensor.update %[[ARG1]], %[[UPDATE0]][%[[C2]], %[[C0]]]
// CHECK-SAME: : tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}
// CHECK: %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]]
// CHECK-SAME: : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]}

// -----

func.func @dont_lower_non_outer_dim_concat(%arg0: tensor<4x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> {
  %0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
// CHECK-LABEL: func @dont_lower_non_outer_dim_concat
// CHECK: %[[CONCAT:.+]] = tensor.concat
// CHECK: return %[[CONCAT]]
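
One behavior the new tests do not exercise directly: the first check in
`matchAndRewrite` makes the pattern bail on concats that already sit inside
a dispatch region or workgroups op. A hypothetical sketch (shapes, names,
and the exact `flow.dispatch.region` syntax are approximated here, not taken
from the commit):

    %0 = flow.dispatch.region -> (tensor<6x3xf32>) {
      // The pattern returns failure() here because of the enclosing region,
      // so this concat is left untouched.
      %1 = tensor.concat dim(0) %a, %b
          : (tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<6x3xf32>
      flow.return %1 : tensor<6x3xf32>
    }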
