Skip to content

Commit

Permalink
Make loop splitting respect dependency that goes through air.wait_all (
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Nov 22, 2024
1 parent 13fd3c6 commit 8ed48b9
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ bool areAsyncDependent(Operation *a, Operation *b) {
for (auto dep : dep_b)
if (dep == token_a)
return true;
// Deep async dependency tracing through air.wait_all.
if (isAsyncDependent(a, b))
return true;
if (isAsyncDependent(b, a))
return true;

auto chanA = dyn_cast<air::ChannelInterface>(a);
auto chanB = dyn_cast<air::ChannelInterface>(b);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -616,3 +616,65 @@ module {
return
}
}

// -----

// Deep dependency tracing through air.wait_all.

// CHECK: scf.for
// CHECK: air.channel.get
// CHECK: air.channel.get
// CHECK: scf.for
// CHECK: air.channel.put
// CHECK: air.channel.put
// CHECK: scf.yield
// CHECK: scf.yield

#map = affine_map<()[s0] -> (s0 * 96)>
#map1 = affine_map<()[s0] -> (s0 * 3)>
module {
air.channel @channel_0 []
air.channel @channel_1 []
func.func @func7() {
%c1 = arith.constant 1 : index
%0 = air.launch async (%arg5, %arg6) in (%arg7=%c1, %arg8=%c1) attributes {id = 2 : i32} {
%1 = air.segment @segment_0 async attributes {id = 1 : i32} {
%c96 = arith.constant 96 : index
%c1_0 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c0 = arith.constant 0 : index
%async_token, %results = air.execute -> (memref<288xi8, 1 : i32>) {
%alloc = memref.alloc() : memref<288xi8, 1 : i32>
air.execute_terminator %alloc : memref<288xi8, 1 : i32>
} {id = 1 : i32}
%async_token_1, %results_2 = air.execute -> (memref<9xf32, 1 : i32>) {
%alloc = memref.alloc() : memref<9xf32, 1 : i32>
air.execute_terminator %alloc : memref<9xf32, 1 : i32>
} {id = 2 : i32}
%2 = air.wait_all async [%async_token, %async_token_1] {id = 4 : i32}
%3 = scf.for %arg9 = %c0 to %c3 step %c1_0 iter_args(%arg10 = %2) -> (!air.async.token) {
%4 = air.channel.get async [%arg10] @channel_0[] (%results[] [] []) {id = 1 : i32} : (memref<288xi8, 1 : i32>)
%5 = air.channel.get async [%arg10] @channel_0[] (%results_2[] [] []) {id = 2 : i32} : (memref<9xf32, 1 : i32>)
%6 = air.wait_all async [%4, %5] {id = 2 : i32}
%7 = scf.for %arg11 = %c0 to %c3 step %c1_0 iter_args(%arg12 = %6) -> (!air.async.token) {
%async_token_3, %results_4 = air.execute [%arg12] -> (index) {
%12 = affine.apply #map()[%arg11]
air.execute_terminator %12 : index
} {id = 5 : i32}
%9 = air.channel.put async [%async_token_3] @channel_1[] (%results[%results_4] [%c96] [%c1_0]) {id = 3 : i32} : (memref<288xi8, 1 : i32>)
%async_token_5, %results_6 = air.execute [%arg12] -> (index) {
%12 = affine.apply #map1()[%arg11]
air.execute_terminator %12 : index
} {id = 6 : i32}
%10 = air.channel.put async [%async_token_5] @channel_1[] (%results_2[%results_6] [%c3] [%c1_0]) {id = 4 : i32} : (memref<9xf32, 1 : i32>)
%11 = air.wait_all async [%arg12, %9, %10] {id = 1 : i32}
scf.yield %11 : !air.async.token
}
%8 = air.wait_all async [%arg10, %7] {id = 3 : i32}
scf.yield %8 : !air.async.token
}
}
}
return
}
}

0 comments on commit 8ed48b9

Please sign in to comment.