forked from iree-org/iree
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Flow] Add patterns to convert from
tensor.concat
to `flow.tensor.u…
…pdate`. (iree-org#19126) These are in preparation to delay to decomposition of `tensor.concat` into `tensor.insert_slice`s. This patch just adds the patterns to lower a `tensor.concat` along the outer dimension to `flow.tensor.update`. Future changes will delay the decomposition of `tensor.concat` to allow for non-outer dimension concatenation to be conveted into `tensor.insert_slice`s before dispatch formation with the `tensor.insert_slice` fused into its producers. Towards iree-org#19092 --------- Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
- Loading branch information
1 parent
c39b4e2
commit f51e1da
Showing
6 changed files
with
117 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
compiler/src/iree/compiler/Dialect/Flow/Conversion/TensorToFlow/test/concat.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
// RUN: iree-opt --iree-flow-convert-to-flow --split-input-file --mlir-print-local-scope %s | FileCheck %s | ||
|
||
func.func @mixed_concat(%arg0: tensor<2x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> { | ||
%0 = tensor.concat dim(0) %arg0, %arg1, %arg2 : (tensor<2x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32> | ||
return %0 : tensor<?x?xf32> | ||
} | ||
// CHECK-LABEL: func @mixed_concat | ||
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x?xf32> | ||
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32> | ||
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x?xf32> | ||
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index | ||
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index | ||
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index | ||
// CHECK-DAG: %[[ARG0_D1:.+]] = tensor.dim %[[ARG0]], %[[C1]] | ||
// CHECK-DAG: %[[ARG1_D0:.+]] = tensor.dim %[[ARG1]], %[[C0]] | ||
// CHECK-DAG: %[[ARG1_D1:.+]] = tensor.dim %[[ARG1]], %[[C1]] | ||
// CHECK: %[[OFFSET0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 2)>()[%[[ARG1_D0]]] | ||
// CHECK: %[[ARG2_D1:.+]] = tensor.dim %[[ARG2]], %[[C1]] | ||
// CHECK: %[[RESULT_D0:.+]] = affine.apply affine_map<()[s0] -> (s0 + 6)>()[%[[ARG1_D0]]] | ||
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[RESULT_D0]], %[[ARG0_D1]]) | ||
// CHECK: %[[UPDATE0:.+]] = flow.tensor.update %[[ARG0]], %[[EMPTY]][%[[C0]], %[[C0]]] | ||
// CHECK-SAME: : tensor<2x?xf32>{%[[ARG0_D1]]} -> %[[EMPTY]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]} | ||
// CHECK: %[[UPDATE1:.+]] = flow.tensor.update %[[ARG1]], %[[UPDATE0]][%[[C2]], %[[C0]]] | ||
// CHECK-SAME: : tensor<?x?xf32>{%[[ARG1_D0]], %[[ARG1_D1]]} -> %[[UPDATE0]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]} | ||
// CHECK: %[[UPDATE2:.+]] = flow.tensor.update %[[ARG2]], %[[UPDATE1]][%[[OFFSET0]], %[[C0]]] | ||
// CHECK-SAME: : tensor<4x?xf32>{%[[ARG2_D1]]} -> %[[UPDATE1]] as tensor<?x?xf32>{%[[RESULT_D0]], %[[ARG0_D1]]} | ||
|
||
// ----- | ||
|
||
func.func @dont_lower_non_outer_dim_concat(%arg0: tensor<4x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : tensor<4x?xf32>) -> tensor<?x?xf32> { | ||
%0 = tensor.concat dim(1) %arg0, %arg1, %arg2 : (tensor<4x?xf32>, tensor<?x?xf32>, tensor<4x?xf32>) -> tensor<?x?xf32> | ||
return %0 : tensor<?x?xf32> | ||
} | ||
// CHECK-LABEL: func @dont_lower_non_outer_dim_concat | ||
// CHECK: %[[CONCAT:.+]] = tensor.concat | ||
// CHECK: return %[[CONCAT]] |