Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
amend! [mlir][vector] Prevent incorrect vector.transfer_{read|write} …
…hoisting [mlir][vector] Prevent incorrect vector.transfer_{read|write} hoisting Refines how opportunities for hoisting vector.transfer_{read|write} pairs are identified. More specifically, rather than looking for specific MemRef ops that could lead to aliasing, this patch updates the hoisting logic to check whether the underlying Op implements `ViewLikeOpInterface`. Additional condition is added to prevent hoisting when one of the source operands implements `ViewLikeOpInterface`. This was motivated by the following example [1]: ```mlir func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) { %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c4 = arith.constant 4 : index %c20 = arith.constant 20 : index %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32> %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32> %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> scf.for %_ = %c0 to %c20 step %c4 { %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<1xi32>, i32 vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32> } return } ``` Originally, it would be rewritten as follows: ```mlir func.func @no_hoisting_write_to_memref(%arg0: i32, %arg1: vector<1xi32>) { %c0_i32 = arith.constant 0 : i32 %c0 = arith.constant 0 : index %c4 = arith.constant 4 : index %c20 = arith.constant 20 : index %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32> %collapse_shape = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %collapse_shape_0 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> %0 = vector.transfer_read %collapse_shape[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %1 = vector.transfer_read %collapse_shape_0[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> %2 = scf.for %arg2 = %c0 to %c20 step %c4 iter_args(%arg3 = %0) -> (vector<1xi32>) { %3 = vector.outerproduct %arg3, %arg0, %1 {kind = #vector.kind<add>} : vector<1xi32>, i32 scf.yield %3 : vector<1xi32> } vector.transfer_write %2, %collapse_shape[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32> return } ``` This was not safe. While one argument for `vector.outerproduct` was correctly being forwarded via `iter_args` (`%rhs` from the original loop), the other one wasn't (`%acc` from the original loop). [1] Based on iree-org/iree#14994.
- Loading branch information