Skip to content

Commit

Permalink
Disable fusion in dispatch formation
Browse files Browse the repository at this point in the history
Signed-off-by: Ian Wood <ianwood2024@u.northwestern.edu>
  • Loading branch information
IanWood1 committed Dec 28, 2024
1 parent a5190f8 commit da5434d
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 7 deletions.
10 changes: 4 additions & 6 deletions compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,12 +258,10 @@ FailureOr<SmallVector<int64_t>> ScatterOp::getStaticLoopRanges() {
}

/// Returns the indexing maps for the scatter op's operands, enabling
/// elementwise fusion with producers of the `update` and `indices` operands.
///
/// Returns one map per operand, in operand order:
///   - update:  identity map over the update tensor's rank.
///   - indices: identity map over the indices tensor's rank.
///   - output:  a null AffineMap — the output operand deliberately has no
///     affine indexing map (its access is not expressed affinely here), so
///     fusion along the output operand is not attempted.
SmallVector<AffineMap> ScatterOp::getIndexingMapsForOperands() {
  Builder builder(getContext());
  return {builder.getMultiDimIdentityMap(getUpdateType().getRank()),
          builder.getMultiDimIdentityMap(getIndicesType().getRank()),
          /*output=*/AffineMap(nullptr)};
}

SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,8 @@ isFusableWithProducer(OpOperand &operand,
}

// Don't fuse attention with its producer
if (isa<IREE::LinalgExt::AttentionOp>(consumer)) {
// TODO: Enable scatter fusion when supported by backends.
if (isa<IREE::LinalgExt::AttentionOp, IREE::LinalgExt::ScatterOp>(consumer)) {
return false;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,85 @@
// RUN: iree-opt --split-input-file --verify-diagnostics --pass-pipeline="builtin.module(util.func(iree-dispatch-creation-form-dispatch-regions{aggressive-fusion=true}, iree-dispatch-creation-clone-producers-into-dispatch-regions), cse, canonicalize, cse)" %s | FileCheck %s

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Tests dispatch-region formation around iree_linalg_ext.scatter:
// per the CHECK lines below, the producer generics feeding the scatter's
// `update` and `indices` operands end up fused into the scatter's dispatch
// region, while the elementwise consumer of the scatter result is placed in
// a separate dispatch region (scatter -> consumer fusion is not performed).
util.func public @linalgext_scatter_dispatch() -> tensor<8192x16x8x128xf32> {
%0 = tensor.empty() : tensor<4x1xi32>
%1 = tensor.empty() : tensor<4x1xi64>
%2 = tensor.empty() : tensor<4x1x16x8x128xf32>
%3 = tensor.empty() : tensor<4x1x16x8x128xf32>
%4 = tensor.empty() : tensor<8192x16x8x128xf32>
%5 = tensor.empty() : tensor<8192x16x8x128xf32>
// Producer of the scatter's `indices` operand (i64 -> i32 truncation);
// expected to be fused into the scatter's dispatch region.
%6 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%1 : tensor<4x1xi64>) outs(%0 : tensor<4x1xi32>) {
^bb0(%in: i64, %out: i32):
%10 = arith.trunci %in : i64 to i32
linalg.yield %10 : i32
} -> tensor<4x1xi32>

// Producer of the scatter's `update` operand; also expected to be fused
// into the scatter's dispatch region.
%7 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} ins(%2 : tensor<4x1x16x8x128xf32>) outs(%3 : tensor<4x1x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<4x1x16x8x128xf32>

%8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false) ins(%7, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>) outs(%4 : tensor<8192x16x8x128xf32>) {
^bb0(%arg0: f32, %arg1: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8192x16x8x128xf32>

// Don't fuse the scatter with its consumer: this generic must land in its
// own dispatch region (see the trailing flow.dispatch.region CHECK below).
%9 = linalg.generic {indexing_maps = [#map2, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%8 : tensor<8192x16x8x128xf32>) outs(%5 : tensor<8192x16x8x128xf32>) {
^bb0(%in: f32, %out: f32):
%10 = arith.addf %in, %out : f32
linalg.yield %10 : f32
} -> tensor<8192x16x8x128xf32>
util.return %9 : tensor<8192x16x8x128xf32>
}

// CHECK-LABEL: util.func public @linalgext_scatter_dispatch
// CHECK-DAG: %[[INDICES:.+]] = flow.dispatch.region
// CHECK-DAG: %[[UPDATE:.+]] = flow.dispatch.region
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[UPDATE]], %[[INDICES]] : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
// CHECK: flow.return %[[SCATTER_RESULT]]
// CHECK: flow.dispatch.region
// CHECK: %[[GEN2:.+]] = linalg.generic
// CHECK-SAME: ins(%[[INPUT:.+]] : tensor<8192x16x8x128xf32>)
// CHECK: flow.return %[[GEN2]]

// -----

#map = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// Tests that the tensor.extract_slice producing the scatter's `outs` operand
// is cloned into the scatter's dispatch region (per the CHECK lines below,
// the extract_slice appears inside the flow.dispatch.region feeding the
// scatter's outs).
util.func public @linalgext_scatter_clone() -> tensor<8192x16x8x128xf32> {
%6 = tensor.empty() : tensor<4x1xi32>
%2 = tensor.empty() : tensor<4x1x16x8x128xf32>
%4 = tensor.empty() : tensor<10x8192x16x8x128xf32>

// Slice used as the scatter's destination; candidate for cloning into the
// dispatch region rather than being passed in as a region operand.
%outs = tensor.extract_slice %4[0, 0, 0, 0, 0][1, 8192, 16, 8, 128][1, 1, 1, 1, 1] :
tensor<10x8192x16x8x128xf32> to tensor<8192x16x8x128xf32>

%8 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(false)
ins(%2, %6 : tensor<4x1x16x8x128xf32>, tensor<4x1xi32>)
outs(%outs : tensor<8192x16x8x128xf32>) {
^bb0(%arg0: f32, %arg1: f32):
iree_linalg_ext.yield %arg0 : f32
} -> tensor<8192x16x8x128xf32>

util.return %8 : tensor<8192x16x8x128xf32>
}

// CHECK-LABEL: util.func public @linalgext_scatter_clone
// CHECK: %[[RESULT:.+]] = flow.dispatch.region
// CHECK: %[[OUTS:.+]] = tensor.extract_slice
// CHECK: %[[SCATTER_RESULT:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: outs(%[[OUTS]] : tensor<8192x16x8x128xf32>)
// CHECK: flow.return %[[SCATTER_RESULT]]

// -----

util.func public @attention_dispatch(%arg0: tensor<?x?x?xf16>, %arg1: tensor<?x?x?xf16>, %arg2: tensor<?x?x?xf16>, %arg3: f16, %arg4: tensor<?x?x?xf16>, %arg5: tensor<?x?x?xf16>, %arg6: tensor<?x?x?xf16>) -> tensor<?x?x?xf16> {
%0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<?x?x?xf16>) outs(%arg4 : tensor<?x?x?xf16>) {
^bb0(%in: f16, %out: f16):
Expand Down

0 comments on commit da5434d

Please sign in to comment.