diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 7c6639304d97c5..31cb9c010e00a9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -152,6 +152,14 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) { transferRead.getPermutationMap() != transferWrite.getPermutationMap()) return WalkResult::advance(); + // When the source of transfer_read aliases, the following dominance + // analysis might not be sufficient. + // TODO: There might be other, similar cases missing here (i.e. other + // Memref Ops). + auto source = transferRead.getSource(); + if (source.getDefiningOp()) + return WalkResult::advance(); + // TODO: may want to memoize this information for performance but it // likely gets invalidated often. DominanceInfo dom(loop); diff --git a/mlir/test/Dialect/Linalg/hoisting.mlir b/mlir/test/Dialect/Linalg/hoisting.mlir index e25914620726b9..7d0c3648c344b1 100644 --- a/mlir/test/Dialect/Linalg/hoisting.mlir +++ b/mlir/test/Dialect/Linalg/hoisting.mlir @@ -765,10 +765,10 @@ transform.sequence failures(propagate) { // CHECK-LABEL: func.func @no_hoisting_collapse_shape // CHECK: scf.for {{.*}} { -// CHECK: vector.transfer_write -// CHECK: vector.transfer_read -// CHECK: vector.transfer_write -// CHECK: } +// CHECK: vector.transfer_write {{.*}} : vector<4xi32>, memref<4xi32> +// CHECK-NEXT: vector.transfer_read {{.*}} : memref<1x4x1xi32>, vector<1x4x1xi32> +// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1x4x1xi32>, memref<1x4x1xi32, strided<[20, 1, 1], offset: ?>> +// CHECK-NEXT: } func.func @no_hoisting_collapse_shape(%in_0: memref<1x20x1xi32>, %1: memref<9x1xi32>, %vec: vector<4xi32>) { %c0_i32 = arith.constant 0 : i32 @@ -827,3 +827,48 @@ transform.sequence failures(propagate) { transform.structured.hoist_redundant_vector_transfers %0 : (!transform.any_op) -> !transform.any_op } + +// ----- + +// Regression test - hoisting the following `vector.transfer_{read|write}` pair +// would not be safe: +// %lhs = vector.transfer_read %collapsed_1[%c0] +// vector.transfer_write %op, %collapsed_1[%c0] +// That's because the following `vector.transfer_read` reads from the same +// memory (i.e. `%collapsed_1` and `%collapsed_2` alias): +// %acc = vector.transfer_read %collapsed_2[%c0] + +// CHECK-LABEL: func.func @no_hoisting_write_to_memref +// CHECK: scf.for {{.*}} { +// CHECK: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32> +// CHECK-NEXT: vector.transfer_read {{.*}} : memref<2xi32>, vector<1xi32> +// CHECK-NEXT: vector.outerproduct {{.*}} : vector<1xi32>, i32 +// CHECK-NEXT: vector.transfer_write {{.*}} : vector<1xi32>, memref<2xi32> +// CHECK-NEXT: } + +func.func @no_hoisting_write_to_memref(%rhs: i32, %arg1: vector<1xi32>) { + %c0_i32 = arith.constant 0 : i32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %c20 = arith.constant 20 : index + %alloca = memref.alloca() {alignment = 64 : i64} : memref<1x1x2xi32> + %cast = memref.cast %alloca : memref<1x1x2xi32> to memref<1x1x2xi32> + %collapsed_1 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> + scf.for %_ = %c0 to %c20 step %c4 { + %collapsed_2 = memref.collapse_shape %alloca [[0, 1, 2]] : memref<1x1x2xi32> into memref<2xi32> + %lhs = vector.transfer_read %collapsed_1[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> + %acc = vector.transfer_read %collapsed_2[%c0], %c0_i32 {in_bounds = [true]} : memref<2xi32>, vector<1xi32> + %op = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind} : vector<1xi32>, i32 + vector.transfer_write %op, %collapsed_1[%c0] {in_bounds = [true]} : vector<1xi32>, memref<2xi32> + } + return +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !transform.any_op): + %0 = transform.structured.match ops{["func.func"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.hoist_redundant_vector_transfers %0 + : (!transform.any_op) -> !transform.any_op +}