Skip to content

Commit

Permalink
Add more strict conditions when checking for channel fusion (Xilinx#586)
Browse files Browse the repository at this point in the history
* Add more strict conditions when checking for channel fusion

* Test
  • Loading branch information
erwei-xilinx authored May 28, 2024
1 parent e844dc2 commit c5c9de8
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 7 deletions.
58 changes: 51 additions & 7 deletions mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3257,6 +3257,20 @@ class AIRFuseChannels
return true;
return false;
}
// Check if two ssa value lists are identical, element by element.
// Two elements are considered identical when they are the very same SSA
// value, or when getConstantIntValue yields a constant for both of them and
// the two constants are equal. Any other pairing (two distinct non-constant
// values, or a constant paired with a non-constant) is reported as a
// mismatch: equality cannot be established for such a pair, and callers use
// this result to decide whether accesses may be fused.
// Returns false if the lists differ in length; two empty lists are identical.
bool areTheSameSSAValueLists(SmallVector<Value> a, SmallVector<Value> b) {
  if (a.size() != b.size())
    return false;
  for (unsigned i = 0; i < a.size(); i++) {
    // The same SSA value trivially matches itself, constant or not.
    if (a[i] == b[i])
      continue;
    auto constAElem = getConstantIntValue(a[i]);
    auto constBElem = getConstantIntValue(b[i]);
    // Distinct SSA values, at least one of which is not a known constant.
    if (!constAElem || !constBElem)
      return false;
    // Unequal constant values
    if (*constAElem != *constBElem)
      return false;
  }
  return true;
}
// Check if two air.channels are mergeable in time, by fusing them into a
// shared scf.for loop. Returns a tuple of a bool indicating whether they are
// mergeable, and a string indicating whether to fuse into the for loop's
// lower bound (LB) or upper bound (UB).
Expand All @@ -3280,26 +3294,56 @@ class AIRFuseChannels
return notMergeable;
if (a_gets.size() != 1)
return notMergeable;
// Check for identical src and dst memref
// Check for identical src and dst memrefs, offset, size and stride lists
Value aMemref = a_puts[0].getMemref();
SmallVector<Value> aOffsets = a_puts[0].getOffsets();
SmallVector<Value> aSizes = a_puts[0].getSizes();
SmallVector<Value> aStrides = a_puts[0].getStrides();
for (unsigned i = 1; i < a_puts.size(); i++)
if (aMemref != a_puts[i].getMemref())
if ((!areTheSameMemref(aMemref, a_puts[i].getMemref())) ||
(!areTheSameSSAValueLists(aOffsets, a_puts[i].getOffsets())) ||
(!areTheSameSSAValueLists(aSizes, a_puts[i].getSizes())) ||
(!areTheSameSSAValueLists(aStrides, a_puts[i].getStrides())))
return notMergeable; // Inconsistent memory use for all puts
Value bMemref = b_puts[0].getMemref();
SmallVector<Value> bOffsets = b_puts[0].getOffsets();
SmallVector<Value> bSizes = b_puts[0].getSizes();
SmallVector<Value> bStrides = b_puts[0].getStrides();
for (unsigned i = 1; i < b_puts.size(); i++)
if (bMemref != b_puts[i].getMemref())
if ((!areTheSameMemref(bMemref, b_puts[i].getMemref())) ||
(!areTheSameSSAValueLists(bOffsets, b_puts[i].getOffsets())) ||
(!areTheSameSSAValueLists(bSizes, b_puts[i].getSizes())) ||
(!areTheSameSSAValueLists(bStrides, b_puts[i].getStrides())))
return notMergeable; // Inconsistent memory use for all puts
if (!areTheSameMemref(aMemref, bMemref))
if ((!areTheSameMemref(aMemref, bMemref)) ||
(!areTheSameSSAValueLists(aOffsets, bOffsets)) ||
(!areTheSameSSAValueLists(aSizes, bSizes)) ||
(!areTheSameSSAValueLists(aStrides, bStrides)))
return notMergeable;
aMemref = a_gets[0].getMemref();
aOffsets = a_gets[0].getOffsets();
aSizes = a_gets[0].getSizes();
aStrides = a_gets[0].getStrides();
for (unsigned i = 1; i < a_gets.size(); i++)
if (aMemref != a_gets[i].getMemref())
if ((!areTheSameMemref(aMemref, a_gets[i].getMemref())) ||
(!areTheSameSSAValueLists(aOffsets, a_gets[i].getOffsets())) ||
(!areTheSameSSAValueLists(aSizes, a_gets[i].getSizes())) ||
(!areTheSameSSAValueLists(aStrides, a_gets[i].getStrides())))
return notMergeable; // Inconsistent memory use for all gets
bMemref = b_gets[0].getMemref();
bOffsets = b_gets[0].getOffsets();
bSizes = b_gets[0].getSizes();
bStrides = b_gets[0].getStrides();
for (unsigned i = 1; i < b_gets.size(); i++)
if (bMemref != b_gets[i].getMemref())
if ((!areTheSameMemref(bMemref, b_gets[i].getMemref())) ||
(!areTheSameSSAValueLists(bOffsets, b_gets[i].getOffsets())) ||
(!areTheSameSSAValueLists(bSizes, b_gets[i].getSizes())) ||
(!areTheSameSSAValueLists(bStrides, b_gets[i].getStrides())))
return notMergeable; // Inconsistent memory use for all gets
if (!areTheSameMemref(aMemref, bMemref))
if ((!areTheSameMemref(aMemref, bMemref)) ||
(!areTheSameSSAValueLists(aOffsets, bOffsets)) ||
(!areTheSameSSAValueLists(aSizes, bSizes)) ||
(!areTheSameSSAValueLists(aStrides, bStrides)))
return notMergeable;
for (unsigned i = 0; i < a_puts.size(); i++) {
auto a_put_loop_nest = getParentLoopNest(a_puts[i].getOperation());
Expand Down
117 changes: 117 additions & 0 deletions mlir/test/Transform/AIRDependencyScheduleOpt/fuse_channels.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -783,3 +783,120 @@ module {
return
}
}

// -----

// Merging air.channels into both scf.for op's LB and UB (L2->L1, with broadcast).

// CHECK-LABEL: func7
// CHECK: air.segment @segment_0
// CHECK: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// CHECK-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// CHECK-NEXT: }
// CHECK: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// CHECK-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// CHECK-NEXT: }
// CHECK: air.segment_terminator
// AGGRESSIVE-LABEL: func7
// AGGRESSIVE: air.segment @segment_0
// AGGRESSIVE: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// AGGRESSIVE-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// AGGRESSIVE-NEXT: }
// AGGRESSIVE: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// AGGRESSIVE-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// AGGRESSIVE-NEXT: }
// AGGRESSIVE: air.segment_terminator
// AGGL1-LABEL: func7
// AGGL1: air.segment @segment_0
// AGGL1: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// AGGL1-NEXT: air.channel.put{{.*}}@channel_6{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// AGGL1-NEXT: }
// AGGL1: scf.for %{{.*}} = %c0{{.*}}to %c64{{.*}}step %c1{{.*}}{
// AGGL1-NEXT: air.channel.put{{.*}}@channel_7{{.*}} : (memref<1x1x32x64xi32, 1 : i32>)
// AGGL1-NEXT: }
// AGGL1: air.segment_terminator

// Integer set over two symbols; it holds when 0 <= s0 <= 1 and s1 == 0.
#set = affine_set<()[s0, s1] : (s0 >= 0, -s0 + 1 >= 0, s1 == 0)>
module {
// Six channels, all declared with size [1, 1] and broadcast_shape [2, 1].
air.channel @channel_11 [1, 1] {broadcast_shape = [2, 1]}
air.channel @channel_10 [1, 1] {broadcast_shape = [2, 1]}
air.channel @channel_7 [1, 1] {broadcast_shape = [2, 1]}
air.channel @channel_6 [1, 1] {broadcast_shape = [2, 1]}
air.channel @channel_3 [1, 1] {broadcast_shape = [2, 1]}
air.channel @channel_2 [1, 1] {broadcast_shape = [2, 1]}
// Test input for the func7 case. The three function arguments are not
// referenced anywhere in the body below.
func.func @func7(%arg0: memref<2048x2048xi32>, %arg1: memref<2048x2048xi32>, %arg2: memref<2048x2048xi32>) {
%c32 = arith.constant 32 : index
%0 = air.launch async (%arg3, %arg4) in (%arg5=%c32, %arg6=%c32) attributes {id = 1 : i32} {
%1 = air.segment @segment_0 async attributes {id = 2 : i32} {
%c512 = arith.constant 512 : index
%c8 = arith.constant 8 : index
%c4 = arith.constant 4 : index
%c32_0 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%c0 = arith.constant 0 : index
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
%c63 = arith.constant 63 : index
// Buffer in memory space 2; passed to every herd below as the destination of
// the channel gets.
%async_token, %results = air.execute -> (memref<1x1x8x4x8x4xi32, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x8x4x8x4xi32, 2 : i32>
air.execute_terminator %alloc : memref<1x1x8x4x8x4xi32, 2 : i32>
}
// Buffer in memory space 1; the source memref of every channel put below.
%async_token_1, %results_2 = air.execute -> (memref<1x1x32x64xi32, 1 : i32>) {
%alloc = memref.alloc() : memref<1x1x32x64xi32, 1 : i32>
air.execute_terminator %alloc : memref<1x1x32x64xi32, 1 : i32>
}
// Puts placed before the loops. The channel_2 and channel_3 puts have the
// same sizes and strides and differ only in the last offset (%c0 vs %c32_0,
// i.e. 0 vs 32).
%2 = air.channel.put async [%async_token_1] @channel_2[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 12 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
%3 = air.channel.put async [%async_token_1] @channel_3[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 13 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
// 2x2 herd: tiles whose ids satisfy #set get from channel_2, all other tiles
// get from channel_3.
%4 = air.herd @herd_0 async [%async_token_1] tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 3 : i32} {
%9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
%10 = air.channel.get async @channel_2[%arg7, %arg8] (%arg11[] [] []) {id = 16 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
} else {
%10 = air.channel.get async @channel_3[%arg7, %arg8] (%arg11[] [] []) {id = 17 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
}
air.herd_terminator
}
// Two separate loops from 1 to 63 (step 1), each wrapping a single put. The
// channel_6 put uses the same offsets, sizes and strides as the channel_2
// put, and the channel_7 put the same as the channel_3 put. The expectations
// preceding this module keep channel_6 and channel_7 in two distinct loops,
// each running from 0 to 64 -- presumably because their last offsets differ
// (0 vs 32), so the two channels must not be fused with each other.
scf.for %arg7 = %c1 to %c63 step %c1 {
%9 = air.channel.put async [%async_token_1] @channel_6[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 22 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
}
scf.for %arg7 = %c1 to %c63 step %c1 {
%9 = air.channel.put async [%async_token_1] @channel_7[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 23 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
}
// Herd holding the matching gets for channel_6 / channel_7, inside a loop
// with the same 1-to-63 bounds as the put loops above.
%5 = air.herd @herd_0 async [%4] tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 4 : i32} {
%c1_4 = arith.constant 1 : index
%c63_5 = arith.constant 63 : index
scf.for %arg12 = %c1_4 to %c63_5 step %c1_4 {
%9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
%10 = air.channel.get async @channel_6[%arg7, %arg8] (%arg11[] [] []) {id = 26 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
} else {
%10 = air.channel.get async @channel_7[%arg7, %arg8] (%arg11[] [] []) {id = 27 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
}
}
air.herd_terminator
}
// Puts placed after the loops, on channel_10 / channel_11, again with the
// same access patterns as the channel_2 / channel_3 puts respectively.
%6 = air.channel.put async [%5] @channel_10[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 32 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
%7 = air.channel.put async [%5] @channel_11[] (%results_2[%c0, %c0, %c0, %c0, %c0, %c32_0] [%c1, %c1, %c8, %c4, %c8, %c4] [%c2048, %c2048, %c4, %c512, %c64, %c1]) {id = 33 : i32} : (memref<1x1x32x64xi32, 1 : i32>)
// Herd holding the matching gets for channel_10 / channel_11.
%8 = air.herd @herd_0 async tile (%arg7, %arg8) in (%arg9=%c2, %arg10=%c2) args(%arg11=%results) : memref<1x1x8x4x8x4xi32, 2 : i32> attributes {id = 5 : i32} {
%9 = affine.if #set()[%arg7, %arg8] -> !air.async.token {
%10 = air.channel.get async @channel_10[%arg7, %arg8] (%arg11[] [] []) {id = 37 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
} else {
%10 = air.channel.get async @channel_11[%arg7, %arg8] (%arg11[] [] []) {id = 38 : i32} : (memref<1x1x8x4x8x4xi32, 2 : i32>)
affine.yield %10 : !air.async.token
}
air.herd_terminator
}
// Release the memory space 1 buffer once the last herd has finished.
%async_token_3 = air.execute [%8] {
memref.dealloc %results_2 : memref<1x1x32x64xi32, 1 : i32>
}
air.segment_terminator
}
air.launch_terminator
}
return
}
}

0 comments on commit c5c9de8

Please sign in to comment.