diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
index 986b8b4eb..c9e2f7aa3 100644
--- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
+++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -3257,6 +3257,30 @@ class AIRFuseChannels
       return true;
     return false;
   }
+  // Check if two ssa value lists are identical. Two elements match when both
+  // fold to the same constant integer or, failing that, when they are the
+  // very same SSA value. Any other pairing cannot be proven equal, so the
+  // lists are conservatively reported as different.
+  bool areTheSameSSAValueLists(SmallVector<Value> a, SmallVector<Value> b) {
+    if (a.size() != b.size())
+      return false;
+    for (unsigned i = 0; i < a.size(); i++) {
+      auto constAElem = getConstantIntValue(a[i]);
+      auto constBElem = getConstantIntValue(b[i]);
+      if (constAElem && constBElem) {
+        // Both constant: compare by value, because equal constants may be
+        // produced by distinct constant ops (e.g. %c32 and %c32_0).
+        if (*constAElem != *constBElem)
+          return false;
+        continue;
+      }
+      // At least one element is not a constant: only the same SSA value is
+      // known to be identical.
+      if (a[i] != b[i])
+        return false;
+    }
+    return true;
+  }
   // Check of two air.channels are mergeable in time, by fusing into a shared
   // scf.for loop. Returns a tuple of bool of whether mergeable, and string of
   // fusing into for loop lower bound (LB) or upper bound (UB).
@@ -3280,26 +3304,56 @@ class AIRFuseChannels
       return notMergeable;
     if (a_gets.size() != 1)
       return notMergeable;
-    // Check for identical src and dst memref
+    // Check for identical src and dst memrefs, offset, size and stride lists
     Value aMemref = a_puts[0].getMemref();
+    SmallVector<Value> aOffsets = a_puts[0].getOffsets();
+    SmallVector<Value> aSizes = a_puts[0].getSizes();
+    SmallVector<Value> aStrides = a_puts[0].getStrides();
     for (unsigned i = 1; i < a_puts.size(); i++)
-      if (aMemref != a_puts[i].getMemref())
+      if ((!areTheSameMemref(aMemref, a_puts[i].getMemref())) ||
+          (!areTheSameSSAValueLists(aOffsets, a_puts[i].getOffsets())) ||
+          (!areTheSameSSAValueLists(aSizes, a_puts[i].getSizes())) ||
+          (!areTheSameSSAValueLists(aStrides, a_puts[i].getStrides())))
         return notMergeable; // Inconsistent memory use for all puts
     Value bMemref = b_puts[0].getMemref();
+    SmallVector<Value> bOffsets = b_puts[0].getOffsets();
+    SmallVector<Value> bSizes = b_puts[0].getSizes();
+    SmallVector<Value> bStrides = b_puts[0].getStrides();
     for (unsigned i = 1; i < b_puts.size(); i++)
-      if (bMemref != b_puts[i].getMemref())
+      if ((!areTheSameMemref(bMemref, b_puts[i].getMemref())) ||
+          (!areTheSameSSAValueLists(bOffsets, b_puts[i].getOffsets())) ||
+          (!areTheSameSSAValueLists(bSizes, b_puts[i].getSizes())) ||
+          (!areTheSameSSAValueLists(bStrides, b_puts[i].getStrides())))
         return notMergeable; // Inconsistent memory use for all puts
-    if (!areTheSameMemref(aMemref, bMemref))
+    if ((!areTheSameMemref(aMemref, bMemref)) ||
+        (!areTheSameSSAValueLists(aOffsets, bOffsets)) ||
+        (!areTheSameSSAValueLists(aSizes, bSizes)) ||
+        (!areTheSameSSAValueLists(aStrides, bStrides)))
       return notMergeable;
     aMemref = a_gets[0].getMemref();
+    aOffsets = a_gets[0].getOffsets();
+    aSizes = a_gets[0].getSizes();
+    aStrides = a_gets[0].getStrides();
     for (unsigned i = 1; i < a_gets.size(); i++)
-      if (aMemref != a_gets[i].getMemref())
+      if ((!areTheSameMemref(aMemref, a_gets[i].getMemref())) ||
+          (!areTheSameSSAValueLists(aOffsets, a_gets[i].getOffsets())) ||
+          (!areTheSameSSAValueLists(aSizes, a_gets[i].getSizes())) ||
+          (!areTheSameSSAValueLists(aStrides, a_gets[i].getStrides())))
         return notMergeable; // Inconsistent memory use for all gets
     bMemref = b_gets[0].getMemref();
+    bOffsets = b_gets[0].getOffsets();
+    bSizes = b_gets[0].getSizes();
+    bStrides = b_gets[0].getStrides();
     for (unsigned i = 1; i < b_gets.size(); i++)
-      if (bMemref != b_gets[i].getMemref())
+      if ((!areTheSameMemref(bMemref, b_gets[i].getMemref())) ||
+          (!areTheSameSSAValueLists(bOffsets, b_gets[i].getOffsets())) ||
+          (!areTheSameSSAValueLists(bSizes, b_gets[i].getSizes())) ||
+          (!areTheSameSSAValueLists(bStrides, b_gets[i].getStrides())))
         return notMergeable; // Inconsistent memory use for all gets
-    if (!areTheSameMemref(aMemref, bMemref))
+    if ((!areTheSameMemref(aMemref, bMemref)) ||
+        (!areTheSameSSAValueLists(aOffsets, bOffsets)) ||
+        (!areTheSameSSAValueLists(aSizes, bSizes)) ||
+        (!areTheSameSSAValueLists(aStrides, bStrides)))
       return notMergeable;
     for (unsigned i = 0; i < a_puts.size(); i++) {
       auto a_put_loop_nest = getParentLoopNest(a_puts[i].getOperation());
diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
index 3a1bfe1d6..40a1acc70 100644
--- a/mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
+++ b/mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
@@ -783,3 +783,120 @@ module {
     return
   }
 }
+
+// -----
+
+// Merging air.channels into both scf.for op's LB and UB (L2->L1, with broadcast).
+
+// CHECK-LABEL: func7
+// CHECK: air.segment @segment_0
+// CHECK: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// CHECK-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// CHECK-NEXT: }
+// CHECK: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// CHECK-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// CHECK-NEXT: }
+// CHECK: air.segment_terminator
+// AGGRESSIVE-LABEL: func7
+// AGGRESSIVE: air.segment @segment_0
+// AGGRESSIVE: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// AGGRESSIVE-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// AGGRESSIVE-NEXT: }
+// AGGRESSIVE: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// AGGRESSIVE-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// AGGRESSIVE-NEXT: }
+// AGGRESSIVE: air.segment_terminator
+// AGGL1-LABEL: func7
+// AGGL1: air.segment @segment_0
+// AGGL1: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// AGGL1-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// AGGL1-NEXT: }
+// AGGL1: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
+// AGGL1-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
+// AGGL1-NEXT: }
+// AGGL1: air.segment_terminator
+
+#set = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)>
+module {
+  air.channel @channel_11 [1, 1] {broadcast_shape = [2, 1]}
+  air.channel @channel_10 [1, 1] {broadcast_shape = [2, 1]}
+  air.channel @channel_7 [1, 1] {broadcast_shape = [2, 1]}
+  air.channel @channel_6 [1, 1] {broadcast_shape = [2, 1]}
+  air.channel @channel_3 [1, 1] {broadcast_shape = [2, 1]}
+  air.channel @channel_2 [1, 1] {broadcast_shape = [2, 1]}
+  func.func @func7(%arg0: memref<2048x2048xi32>, %arg1: memref<2048x2048xi32>, %arg2: memref<2048x2048xi32>) {
+    %c32 = arith.constant 32 : index
+    %0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) attributes {id = 1 : i32} {
+      %1 = air.segment @segment_0 async attributes {id = 2 : i32} {
+        %c512 = arith.constant 512 : index
+        %c8 = arith.constant 8 : index
+        %c4 = arith.constant 4 : index
+        %c32_0 = arith.constant 32 : index
+        %c64 = arith.constant 64 : index
+        %c2048 = arith.constant 2048 : index
+        %c0 = arith.constant 0 : index
+        %c2 = arith.constant 2 : index
+        %c1 = arith.constant 1 : index
+        %c63 = arith.constant 63 : index
+        %async_token, %results = air.execute -> (memref<1x1x8x4x8x4xi32, 2 : i32>) {
+          %alloc = memref.alloc() : memref<1x1x8x4x8x4xi32, 2 : i32>
+          air.execute_terminator %alloc : memref<1x1x8x4x8x4xi32, 2 : i32>
+        }
+        %async_token_1, %results_2 = air.execute -> (memref<1x1x32x64xi32, 1 : i32>) {
+          %alloc = memref.alloc() : memref<1x1x32x64xi32, 1 : i32>
+          air.execute_terminator %alloc : memref<1x1x32x64xi32, 1 : i32>
+        }
+        %2 = air.channel.put async [%async_token_1] @channel_2[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 12 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        %3 = air.channel.put async [%async_token_1] @channel_3[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 13 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        %4 = air.herd @herd_0 async [%async_token_1] tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 3 : i32} {
+          %9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
+            %10 = air.channel.get async @channel_2[%arg7, %arg8] (%arg11[] [] []) {id = 16 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+            affine.yield %10 : !air.async.token
+          } else {
+            %10 = air.channel.get async @channel_3[%arg7, %arg8] (%arg11[] [] []) {id = 17 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+            affine.yield %10 : !air.async.token
+          }
+          air.herd_terminator
+        }
+        scf.for %arg7 = %c1 to %c63 step %c1 {
+          %9 = air.channel.put async [%async_token_1] @channel_6[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 22 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        }
+        scf.for %arg7 = %c1 to %c63 step %c1 {
+          %9 = air.channel.put async [%async_token_1] @channel_7[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 23 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        }
+        %5 = air.herd @herd_0 async [%4] tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 4 : i32} {
+          %c1_4 = arith.constant 1 : index
+          %c63_5 = arith.constant 63 : index
+          scf.for %arg12 = %c1_4 to %c63_5 step %c1_4 {
+            %9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
+              %10 = air.channel.get async @channel_6[%arg7, %arg8] (%arg11[] [] []) {id = 26 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+              affine.yield %10 : !air.async.token
+            } else {
+              %10 = air.channel.get async @channel_7[%arg7, %arg8] (%arg11[] [] []) {id = 27 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+              affine.yield %10 : !air.async.token
+            }
+          }
+          air.herd_terminator
+        }
+        %6 = air.channel.put async [%5] @channel_10[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 32 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        %7 = air.channel.put async [%5] @channel_11[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 33 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
+        %8 = air.herd @herd_0 async tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 5 : i32} {
+          %9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
+            %10 = air.channel.get async @channel_10[%arg7, %arg8] (%arg11[] [] []) {id = 37 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+            affine.yield %10 : !air.async.token
+          } else {
+            %10 = air.channel.get async @channel_11[%arg7, %arg8] (%arg11[] [] []) {id = 38 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
+            affine.yield %10 : !air.async.token
+          }
+          air.herd_terminator
+        }
+        %async_token_3 = air.execute [%8] {
+          memref.dealloc %results_2 : memref<1x1x32x64xi32, 1 : i32>
+        }
+        air.segment_terminator
+      }
+      air.launch_terminator
+    }
+    return
+  }
+}