diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp index aa50084f71cf..fcc9d56f89fa 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/FormScalarDispatches.cpp @@ -6,6 +6,7 @@ #include "iree/compiler/Dialect/Flow/Transforms/Passes.h" #include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h" +#include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Analysis/TopologicalSortUtils.h" @@ -14,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Iterators.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/RegionUtils.h" #define DEBUG_TYPE "iree-flow-form-scalar-dispatches" @@ -189,18 +191,41 @@ void FormScalarDispatchesPass::runOnOperation() { Block *currBlock = op->getBlock(); Operation *prevOp = op; bool didHorizontalFusion = false; + llvm::SetVector<Operation *> ineligibleRoots; while (prevOp != &currBlock->front()) { prevOp = prevOp->getPrevNode(); + // If this operation is used by an operation we previously visited, but we + // couldn't fuse it, stop.
+ if (ineligibleRoots.contains(prevOp)) { + break; + } + if (opToRootMap.count(prevOp)) { continue; } if (!isSliceRoot(scalarWorkloadLimit, prevOp)) { - if (isClonableIntoDispatchOp(prevOp)) { + if (fusedOpsSet.contains(prevOp)) { continue; } - break; + // If this op is not being fused, any operation that defines values + // used by this op cannot be horizontally fused. + // Insert all operations into the set that define op's operands or + // define values used inside of op's regions. + mlir::visitUsedValuesDefinedAbove( + prevOp->getRegions(), [&](OpOperand *operand) { + if (auto definingOp = operand->get().getDefiningOp()) { + ineligibleRoots.insert(definingOp); + } + }); + + for (Value val : prevOp->getOperands()) { + if (auto definingOp = val.getDefiningOp()) { + ineligibleRoots.insert(definingOp); + } + } + continue; } didHorizontalFusion = true; diff --git a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir index c54d8f7b9a8f..f91d7c5a127b 100644 --- a/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir +++ b/compiler/src/iree/compiler/Dialect/Flow/Transforms/test/form_scalar_dispatches.mlir @@ -157,3 +157,70 @@ util.func public @interleaving( // CHECK-SAME: outs(%[[EMPTY1]] : // CHECK: flow.return %[[GENERIC3]], %[[GENERIC2]] // CHECK: util.return %[[DISPATCH1]]#0, %[[DISPATCH1]]#1 + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_in_chain(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, 
%extracted : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1 : tensor<i32>, tensor<i32> +} + +// CHECK-LABEL: util.func public @clonable_op_in_chain( +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %extracted : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32 +} + +// CHECK-LABEL: util.func public @clonable_op_used_after +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region + +// ----- + +#map = affine_map<() -> ()> +util.func public @clonable_op_only_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) { + %0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %2 = arith.select %arg1, %arg7, %arg0 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + %extracted = tensor.extract %0[] : tensor<i32> + %1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) { + ^bb0(%out: i32): + %extracted2 = tensor.extract %arg2[] : tensor<i32> + %2 = arith.select %arg1, %arg7, %extracted2 : i32 + linalg.yield %2 : i32 + } -> tensor<i32> + util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32 +} + +// CHECK-LABEL: util.func public @clonable_op_only_used_after +// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region +// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region