Skip to content

Commit

Permalink
[Flow] Fix dominance error in FormScalarDispatches (#17785)
Browse files Browse the repository at this point in the history
When preforming horizontal fusion, ops that were clonable (but not used
by the fusion group) were ignored. If these ops were dependent on values
produced by 'root ops', then the root op would get moved into the
region.


Closes #17759

Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 authored Jul 26, 2024
1 parent 1ee68bd commit afea353
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Dialect/Flow/Transforms/Passes.h"
#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
Expand All @@ -14,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/RegionUtils.h"

#define DEBUG_TYPE "iree-flow-form-scalar-dispatches"

Expand Down Expand Up @@ -189,18 +191,41 @@ void FormScalarDispatchesPass::runOnOperation() {
Block *currBlock = op->getBlock();
Operation *prevOp = op;
bool didHorizontalFusion = false;
llvm::SetVector<Operation *> ineligibleRoots;
while (prevOp != &currBlock->front()) {
prevOp = prevOp->getPrevNode();

// If this operation is used by a operation we previously visited, but we
// couldn't fuse it, stop.
if (ineligibleRoots.contains(prevOp)) {
break;
}

if (opToRootMap.count(prevOp)) {
continue;
}

if (!isSliceRoot(scalarWorkloadLimit, prevOp)) {
if (isClonableIntoDispatchOp(prevOp)) {
if (fusedOpsSet.contains(prevOp)) {
continue;
}
break;
// If this op is not being fused, any operations that defines values
// used by this op cannot be horizontally fused
// Insert all operations into the set that define op's operands or
// define values used inside of op's regions
mlir::visitUsedValuesDefinedAbove(
prevOp->getRegions(), [&](OpOperand *operand) {
if (auto definingOp = operand->get().getDefiningOp()) {
ineligibleRoots.insert(definingOp);
}
});

for (Value val : prevOp->getOperands()) {
if (auto definingOp = val.getDefiningOp()) {
ineligibleRoots.insert(definingOp);
}
}
continue;
}

didHorizontalFusion = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,70 @@ util.func public @interleaving(
// CHECK-SAME: outs(%[[EMPTY1]] :
// CHECK: flow.return %[[GENERIC3]], %[[GENERIC2]]
// CHECK: util.return %[[DISPATCH1]]#0, %[[DISPATCH1]]#1

// -----

#map = affine_map<() -> ()>
util.func public @clonable_op_in_chain(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>) {
%0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%2 = arith.select %arg1, %arg7, %arg0 : i32
linalg.yield %2 : i32
} -> tensor<i32>
%extracted = tensor.extract %0[] : tensor<i32>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%2 = arith.select %arg1, %arg7, %extracted : i32
linalg.yield %2 : i32
} -> tensor<i32>
util.return %0, %1 : tensor<i32>, tensor<i32>
}

// CHECK-LABEL: util.func public @clonable_op_in_chain(
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region

// -----

#map = affine_map<() -> ()>
util.func public @clonable_op_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) {
%0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%2 = arith.select %arg1, %arg7, %arg0 : i32
linalg.yield %2 : i32
} -> tensor<i32>
%extracted = tensor.extract %0[] : tensor<i32>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%2 = arith.select %arg1, %arg7, %extracted : i32
linalg.yield %2 : i32
} -> tensor<i32>
util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32
}

// CHECK-LABEL: util.func public @clonable_op_used_after
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region

// -----

#map = affine_map<() -> ()>
util.func public @clonable_op_only_used_after(%arg0: i32, %arg1: i1, %arg2: tensor<i32>, %arg7: i32) -> (tensor<i32>, tensor<i32>, i32) {
%0 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%2 = arith.select %arg1, %arg7, %arg0 : i32
linalg.yield %2 : i32
} -> tensor<i32>
%extracted = tensor.extract %0[] : tensor<i32>
%1 = linalg.generic {indexing_maps = [#map], iterator_types = []} outs(%arg2 : tensor<i32>) {
^bb0(%out: i32):
%extracted2 = tensor.extract %arg2[] : tensor<i32>
%2 = arith.select %arg1, %arg7, %extracted2 : i32
linalg.yield %2 : i32
} -> tensor<i32>
util.return %0, %1, %extracted : tensor<i32>, tensor<i32>, i32
}

// CHECK-LABEL: util.func public @clonable_op_only_used_after
// CHECK: %[[DISPATCH0:.+]] = flow.dispatch.region
// CHECK: %[[DISPATCH1:.+]] = flow.dispatch.region

0 comments on commit afea353

Please sign in to comment.