diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 3d238ea50..ddee2f971 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -614,6 +614,11 @@ bool areAsyncDependent(Operation *a, Operation *b) { for (auto dep : dep_b) if (dep == token_a) return true; + // Deep async dependency tracing through air.wait_all. + if (isAsyncDependent(a, b)) + return true; + if (isAsyncDependent(b, a)) + return true; auto chanA = dyn_cast(a); auto chanB = dyn_cast(b); diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/isolate_async_dma_loop_nest.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/isolate_async_dma_loop_nest.mlir index 9e872da99..b66c9309e 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/isolate_async_dma_loop_nest.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/isolate_async_dma_loop_nest.mlir @@ -616,3 +616,65 @@ module { return } } + +// ----- + +// Deep dependency tracing through air.wait_all. + +// CHECK: scf.for +// CHECK: air.channel.get +// CHECK: air.channel.get +// CHECK: scf.for +// CHECK: air.channel.put +// CHECK: air.channel.put +// CHECK: scf.yield +// CHECK: scf.yield + +#map = affine_map<()[s0] -> (s0 * 96)> +#map1 = affine_map<()[s0] -> (s0 * 3)> +module { + air.channel @channel_0 [] + air.channel @channel_1 [] + func.func @func7() { + %c1 = arith.constant 1 : index + %0 = air.launch async (%arg5, %arg6) in (%arg7=%c1, %arg8=%c1) attributes {id = 2 : i32} { + %1 = air.segment @segment_0 async attributes {id = 1 : i32} { + %c96 = arith.constant 96 : index + %c1_0 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c0 = arith.constant 0 : index + %async_token, %results = air.execute -> (memref<288xi8, 1 : i32>) { + %alloc = memref.alloc() : memref<288xi8, 1 : i32> + air.execute_terminator %alloc : memref<288xi8, 1 : i32> + } {id = 1 : i32} + %async_token_1, %results_2 = air.execute -> (memref<9xf32, 1 : i32>) { + %alloc = memref.alloc() : memref<9xf32, 1 : i32> + air.execute_terminator %alloc : memref<9xf32, 1 : i32> + } {id = 2 : i32} + %2 = air.wait_all async [%async_token, %async_token_1] {id = 4 : i32} + %3 = scf.for %arg9 = %c0 to %c3 step %c1_0 iter_args(%arg10 = %2) -> (!air.async.token) { + %4 = air.channel.get async [%arg10] @channel_0[] (%results[] [] []) {id = 1 : i32} : (memref<288xi8, 1 : i32>) + %5 = air.channel.get async [%arg10] @channel_0[] (%results_2[] [] []) {id = 2 : i32} : (memref<9xf32, 1 : i32>) + %6 = air.wait_all async [%4, %5] {id = 2 : i32} + %7 = scf.for %arg11 = %c0 to %c3 step %c1_0 iter_args(%arg12 = %6) -> (!air.async.token) { + %async_token_3, %results_4 = air.execute [%arg12] -> (index) { + %12 = affine.apply #map()[%arg11] + air.execute_terminator %12 : index + } {id = 5 : i32} + %9 = air.channel.put async [%async_token_3] @channel_1[] (%results[%results_4] [%c96] [%c1_0]) {id = 3 : i32} : (memref<288xi8, 1 : i32>) + %async_token_5, %results_6 = air.execute [%arg12] -> (index) { + %12 = affine.apply #map1()[%arg11] + air.execute_terminator %12 : index + } {id = 6 : i32} + %10 = air.channel.put async [%async_token_5] @channel_1[] (%results_2[%results_6] [%c3] [%c1_0]) {id = 4 : i32} : (memref<9xf32, 1 : i32>) + %11 = air.wait_all async [%arg12, %9, %10] {id = 1 : i32} + scf.yield %11 : !air.async.token + } + %8 = air.wait_all async [%arg10, %7] {id = 3 : i32} + scf.yield %8 : !air.async.token + } + } + } + return + } +}