From afea3533c06c7bf70676b8a8fb8d07031275c87c Mon Sep 17 00:00:00 2001 From: Ian Wood <75152913+IanWood1@users.noreply.github.com> Date: Fri, 26 Jul 2024 10:59:02 -0700 Subject: [PATCH] [Flow] Fix dominance error in `FormScalarDispatches` (#17785) When performing horizontal fusion, ops that were clonable (but not used by the fusion group) were ignored. If these ops were dependent on values produced by 'root ops', then the root op would get moved into the region. Closes https://github.com/iree-org/iree/issues/17759 Signed-off-by: Ian Wood --- .../Flow/Transforms/FormScalarDispatches.cpp | 29 +++++++- .../test/form_scalar_dispatches.mlir | 67 +++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp index aa50084f71cf..fcc9d56f89fa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -14,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Iterators.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #define DEBUG_TYPE "iree-flow-form-scalar-dispatches" @@ -189,18 +191,41 @@ void FormScalarDispatchesPass::runOnOperation() { Block *currBlock = op->getBlock(); Operation *prevOp = op; bool didHorizontalFusion = false; + llvm::SetVector<Operation *> ineligibleRoots; while (prevOp != &currBlock->front()) { prevOp = prevOp->getPrevNode(); + // If this operation is used by an operation we previously visited, but we + // couldn't fuse it, 
stop. + if (ineligibleRoots.contains(prevOp)) { + break; + } + if (opToRootMap.count(prevOp)) { continue; } if (!isSliceRoot(scalarWorkloadLimit, prevOp)) { - if (isClonableIntoDispatchOp(prevOp)) { + if (fusedOpsSet.contains(prevOp)) { continue; } - break; + // If this op is not being fused, any operations that define values + // used by this op cannot be horizontally fused + // Insert all operations into the set that define op's operands or + // define values used inside of op's regions + mlir::visitUsedValuesDefinedAbove( + prevOp->getRegions(), [&](OpOperand *operand) { + if (auto definingOp = operand->get().getDefiningOp()) { + ineligibleRoots.insert(definingOp); + } + }); + + for (Value val : prevOp->getOperands()) { + if (auto definingOp = val.getDefiningOp()) { + ineligibleRoots.insert(definingOp); + } + } + continue; } didHorizontalFusion = true; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir index c54d8f7b9a8f..f91d7c5a127b 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir @@ -157,3 +157,70 @@ util.func public @interleaving( // CHECK-SAME: outs(%[[EMPTY1]] : // CHECK: flow.return %[[GENERIC3]], %[[GENERIC2]] // CHECK: util.return %[[DISPATCH1]]#0, %[[DISPATCH1]]#1 + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_in_chain(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, 
%arg7, %extracted : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1 : tensor<i32>, tensor<i32> +} + +// CHECK-LABEL: util.func public @clonable_op_in_chain( +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %extracted : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32 +} + +// CHECK-LABEL: util.func public @clonable_op_used_after +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_only_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %extracted2 = tensor.extract %arg2[] : tensor<i32> + %2 = arith.select %arg1, %arg7, %extracted2 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32 +} + +// CHECK-LABEL: util.func public @clonable_op_only_used_after +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region