Skip to content

Commit

Permalink
Updated i8 conv test with for loop inside of ukernel (Xilinx#747)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Oct 21, 2024
1 parent f2a164a commit 7372bd7
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 73 deletions.
90 changes: 42 additions & 48 deletions test/xrt/14_conv2d_i8_extern_vec/aie.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,63 +20,56 @@
################################################

air_tiled_ir_string = """
#map = affine_map<()[s0] -> (s0 * 4)>
#map1 = affine_map<()[s0] -> (s0 * 8)>
module {
func.func @conv_static_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_i32(%0 : memref<2x14x14x32xi8>, %1 : memref<3x3x32x64xi8>, %2 : memref<2x12x12x64xi32>) {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
func.func private @conv(memref<1x3x4x6x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, 2 : i32>, memref<1x1x4x1x8xi32, 2 : i32>) attributes {link_with = "conv.o", llvm.emit_c_interface}
func.func @conv_2d_nhwc_hwcf_q_dispatch_0_conv_2d_nhwc_hwcf_2x12x12x64x3x3x32_i8xi8xi32(%0 : memref<2x14x14x32xi8>, %1 : memref<3x3x32x64xi8>, %2 : memref<2x12x12x64xi32>) {
%c8 = arith.constant 8 : index
%c32 = arith.constant 32 : index
%c64 = arith.constant 64 : index
%c12 = arith.constant 12 : index
%c2 = arith.constant 2 : index
%c0 = arith.constant 0 : index
%c0_i32 = arith.constant 0 : i32
%c3 = arith.constant 3 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
scf.parallel (%arg0, %arg1, %arg2, %arg3) = (%c0, %c0, %c0, %c0) to (%c2, %c3, %c3, %c8) step (%c1, %c1, %c1, %c1) {
%3 = affine.apply #map()[%arg1]
%4 = affine.apply #map()[%arg2]
%5 = affine.apply #map1()[%arg3]
%subview = memref.subview %0[%arg0, %3, %4, 0] [1, 6, 6, 32] [1, 1, 1, 1] : memref<2x14x14x32xi8> to memref<1x6x6x32xi8, strided<[6272, 448, 32, 1], offset: ?>>
%subview_0 = memref.subview %1[0, 0, 0, %5] [3, 3, 32, 8] [1, 1, 1, 1] : memref<3x3x32x64xi8> to memref<3x3x32x8xi8, strided<[6144, 2048, 64, 1], offset: ?>>
%subview_1 = memref.subview %2[%arg0, %3, %4, %5] [1, 4, 4, 8] [1, 1, 1, 1] : memref<2x12x12x64xi32> to memref<1x4x4x8xi32, strided<[9216, 768, 64, 1], offset: ?>>
%alloc = memref.alloc() : memref<1x6x6x32xi8, 1>
memref.copy %subview, %alloc : memref<1x6x6x32xi8, strided<[6272, 448, 32, 1], offset: ?>> to memref<1x6x6x32xi8, 1>
%alloc_2 = memref.alloc() : memref<3x3x32x8xi8, 1>
memref.copy %subview_0, %alloc_2 : memref<3x3x32x8xi8, strided<[6144, 2048, 64, 1], offset: ?>> to memref<3x3x32x8xi8, 1>
%alloc_3 = memref.alloc() : memref<1x4x4x8xi32, 1>
%alloc = memref.alloc() : memref<1x1x4x1x8xi32, 2 : i32>
%alloc_0 = memref.alloc() : memref<3x3x4x1x8x8xi8, 2 : i32>
%alloc_1 = memref.alloc() : memref<1x3x4x6x8xi8, 2 : i32>
%alloc_2 = memref.alloc() : memref<1x4x4x8xi32, 1 : i32>
%alloc_3 = memref.alloc() : memref<3x3x32x8xi8, 1 : i32>
%alloc_4 = memref.alloc() : memref<1x6x6x32xi8, 1 : i32>
scf.parallel (%arg0, %arg1, %arg2, %arg3) = (%c0, %c0, %c0, %c0) to (%c2, %c12, %c12, %c64) step (%c1, %c4, %c4, %c8) {
%subview = memref.subview %0[%arg0, %arg1, %arg2, 0] [1, 6, 6, 32] [1, 1, 1, 1] : memref<2x14x14x32xi8> to memref<1x6x6x32xi8, strided<[6272, 448, 32, 1], offset: ?>>
%subview_5 = memref.subview %1[0, 0, 0, %arg3] [3, 3, 32, 8] [1, 1, 1, 1] : memref<3x3x32x64xi8> to memref<3x3x32x8xi8, strided<[6144, 2048, 64, 1], offset: ?>>
%subview_6 = memref.subview %2[%arg0, %arg1, %arg2, %arg3] [1, 4, 4, 8] [1, 1, 1, 1] : memref<2x12x12x64xi32> to memref<1x4x4x8xi32, strided<[9216, 768, 64, 1], offset: ?>>
memref.copy %subview, %alloc_4 : memref<1x6x6x32xi8, strided<[6272, 448, 32, 1], offset: ?>> to memref<1x6x6x32xi8, 1 : i32>
memref.copy %subview_5, %alloc_3 : memref<3x3x32x8xi8, strided<[6144, 2048, 64, 1], offset: ?>> to memref<3x3x32x8xi8, 1 : i32>
scf.parallel (%arg4) = (%c0) to (%c4) step (%c1) {
%subview_4 = memref.subview %alloc[0, %arg4, 0, 0] [1, 3, 6, 32] [1, 1, 1, 1] : memref<1x6x6x32xi8, 1> to memref<1x3x6x32xi8, strided<[1152, 192, 32, 1], offset: ?>, 1>
%subview_5 = memref.subview %alloc_3[0, %arg4, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x4x4x8xi32, 1> to memref<1x1x4x8xi32, strided<[128, 32, 8, 1], offset: ?>, 1>
%alloc_6 = memref.alloc() : memref<1x1x4x8xi32, 2>
linalg.fill ins(%c0_i32 : i32) outs(%alloc_6 : memref<1x1x4x8xi32, 2>)
%subview_7 = memref.subview %alloc_6[0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x1x4x8xi32, 2> to memref<1x4x8xi32, strided<[32, 8, 1]>, 2>
scf.for %arg5 = %c0 to %c3 step %c1 {
scf.for %arg6 = %c0 to %c3 step %c1 {
scf.for %arg7 = %c0 to %c32 step %c8 {
%subview_8 = memref.subview %subview_4[0, %arg5, %arg6, %arg7] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x3x6x32xi8, strided<[1152, 192, 32, 1], offset: ?>, 1> to memref<1x1x4x8xi8, strided<[1152, 192, 32, 1], offset: ?>, 1>
%subview_9 = memref.subview %alloc_2[%arg5, %arg6, %arg7, 0] [1, 1, 8, 8] [1, 1, 1, 1] : memref<3x3x32x8xi8, 1> to memref<1x1x8x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1>
%subview_10 = memref.subview %subview_8[0, 0, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x1x4x8xi8, strided<[1152, 192, 32, 1], offset: ?>, 1> to memref<1x4x8xi8, strided<[1152, 32, 1], offset: ?>, 1>
%subview_11 = memref.subview %subview_9[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : memref<1x1x8x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1> to memref<1x8x8xi8, strided<[768, 8, 1], offset: ?>, 1>
%alloc_12 = memref.alloc() : memref<1x4x8xi8, 2>
memref.copy %subview_10, %alloc_12 : memref<1x4x8xi8, strided<[1152, 32, 1], offset: ?>, 1> to memref<1x4x8xi8, 2>
%alloc_13 = memref.alloc() : memref<1x8x8xi8, 2>
memref.copy %subview_11, %alloc_13 : memref<1x8x8xi8, strided<[768, 8, 1], offset: ?>, 1> to memref<1x8x8xi8, 2>
linalg.conv_1d_nwc_wcf {dilations = dense<1> : vector<1xi64>, strides = dense<1> : vector<1xi64>} ins(%alloc_12, %alloc_13 : memref<1x4x8xi8, 2>, memref<1x8x8xi8, 2>) outs(%subview_7 : memref<1x4x8xi32, strided<[32, 8, 1]>, 2>)
memref.dealloc %alloc_12 : memref<1x4x8xi8, 2>
memref.dealloc %alloc_13 : memref<1x8x8xi8, 2>
}
}
}
memref.copy %alloc_6, %subview_5 : memref<1x1x4x8xi32, 2> to memref<1x1x4x8xi32, strided<[128, 32, 8, 1], offset: ?>, 1>
memref.dealloc %alloc_6 : memref<1x1x4x8xi32, 2>
%subview_7 = memref.subview %alloc_4[0, %arg4, 0, 0] [1, 3, 6, 32] [1, 1, 1, 1] : memref<1x6x6x32xi8, 1 : i32> to memref<1x3x6x32xi8, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32>
%cast = memref.cast %alloc_3 : memref<3x3x32x8xi8, 1 : i32> to memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32>
%subview_8 = memref.subview %alloc_2[0, %arg4, 0, 0] [1, 1, 4, 8] [1, 1, 1, 1] : memref<1x4x4x8xi32, 1 : i32> to memref<1x1x4x8xi32, strided<[128, 32, 8, 1], offset: ?>, 1 : i32>
%expand_shape = memref.expand_shape %subview_7 [[0], [1], [2], [3, 4]] output_shape [1, 3, 6, 4, 8] : memref<1x3x6x32xi8, strided<[1152, 192, 32, 1], offset: ?>, 1 : i32> into memref<1x3x6x4x8xi8, strided<[1152, 192, 32, 8, 1], offset: ?>, 1 : i32>
%transpose = memref.transpose %expand_shape (d0, d1, d2, d3, d4) -> (d0, d1, d3, d2, d4) : memref<1x3x6x4x8xi8, strided<[1152, 192, 32, 8, 1], offset: ?>, 1 : i32> to memref<1x3x4x6x8xi8, strided<[1152, 192, 8, 32, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_1[] [] [], %transpose[] [] []) : (memref<1x3x4x6x8xi8, 2 : i32>, memref<1x3x4x6x8xi8, strided<[1152, 192, 8, 32, 1], offset: ?>, 1 : i32>)
%expand_shape_9 = memref.expand_shape %cast [[0], [1], [2, 3], [4, 5]] output_shape [3, 3, 4, 8, 1, 8] : memref<3x3x32x8xi8, strided<[768, 256, 8, 1], offset: ?>, 1 : i32> into memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
%transpose_10 = memref.transpose %expand_shape_9 (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4, d3, d5) : memref<3x3x4x8x1x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32> to memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>
air.dma_memcpy_nd (%alloc_0[] [] [], %transpose_10[] [] []) : (memref<3x3x4x1x8x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, strided<[768, 256, 64, 8, 8, 1], offset: ?>, 1 : i32>)
linalg.fill ins(%c0_i32 : i32) outs(%alloc : memref<1x1x4x1x8xi32, 2 : i32>)
func.call @conv(%alloc_1, %alloc_0, %alloc) : (memref<1x3x4x6x8xi8, 2 : i32>, memref<3x3x4x1x8x8xi8, 2 : i32>, memref<1x1x4x1x8xi32, 2 : i32>) -> ()
%subview_11 = memref.subview %alloc[0, 0, 0, 0, 0] [1, 1, 4, 1, 8] [1, 1, 1, 1, 1] : memref<1x1x4x1x8xi32, 2 : i32> to memref<1x1x4x8xi32, 2 : i32>
%transpose_12 = memref.transpose %subview_11 (d0, d1, d2, d3) -> (d0, d1, d2, d3) : memref<1x1x4x8xi32, 2 : i32> to memref<1x1x4x8xi32, strided<[32, 32, 8, 1]>, 2 : i32>
air.dma_memcpy_nd (%subview_8[] [] [], %transpose_12[] [] []) : (memref<1x1x4x8xi32, strided<[128, 32, 8, 1], offset: ?>, 1 : i32>, memref<1x1x4x8xi32, strided<[32, 32, 8, 1]>, 2 : i32>)
scf.reduce
}
memref.copy %alloc_3, %subview_1 : memref<1x4x4x8xi32, 1> to memref<1x4x4x8xi32, strided<[9216, 768, 64, 1], offset: ?>>
memref.dealloc %alloc : memref<1x6x6x32xi8, 1>
memref.dealloc %alloc_2 : memref<3x3x32x8xi8, 1>
memref.dealloc %alloc_3 : memref<1x4x4x8xi32, 1>
memref.copy %alloc_2, %subview_6 : memref<1x4x4x8xi32, 1 : i32> to memref<1x4x4x8xi32, strided<[9216, 768, 64, 1], offset: ?>>
scf.reduce
}
memref.dealloc %alloc_4 : memref<1x6x6x32xi8, 1 : i32>
memref.dealloc %alloc_3 : memref<3x3x32x8xi8, 1 : i32>
memref.dealloc %alloc_2 : memref<1x4x4x8xi32, 1 : i32>
memref.dealloc %alloc_1 : memref<1x3x4x6x8xi8, 2 : i32>
memref.dealloc %alloc_0 : memref<3x3x4x1x8x8xi8, 2 : i32>
memref.dealloc %alloc : memref<1x1x4x1x8xi32, 2 : i32>
return
}
}
Expand Down Expand Up @@ -123,6 +116,7 @@
"air-dependency-canonicalize",
"canonicalize",
"cse",
"func.func(air-split-l2-memref)",
"air-isolate-async-dma-loop-nests",
"func.func(air-loop-fusion)",
"air-label-scf-for-to-ping-pong",
Expand Down
128 changes: 103 additions & 25 deletions test/xrt/14_conv2d_i8_extern_vec/conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,57 +24,135 @@

#include "zero.cc"

template <typename T_in, typename T_out, unsigned rowA, unsigned colA,
unsigned colB, unsigned r, unsigned s, unsigned t>
template <typename T_in, typename T_out, unsigned kerCol, unsigned kerRow,
unsigned chanTile, unsigned imgWidth, unsigned r, unsigned s,
unsigned t>
void conv_vectorized(const T_in *__restrict pA, const T_in *__restrict pB,
T_out *__restrict pC) {
using MMUL = aie::mmul<r, s, t, T_in, T_in>;

static_assert(r == 4);
static_assert(s == 8);
static_assert(t == 8);
static_assert(MMUL::size_A == 32);
static_assert(MMUL::size_B == 64);

static_assert(imgWidth == r + kerRow - 1); // stride 1

event0();

aie::vector<T_out, MMUL::size_C> acc_C00 = aie::load_v<MMUL::size_C>(pC);
MMUL C00(acc_C00);
const T_in *__restrict pA1 = pA;
const T_in *__restrict pB1 = pB;
for (unsigned i = 0; i < colA; i += 1)
chess_prepare_for_pipelining chess_loop_range(18, ) {
aie::vector<T_in, MMUL::size_A> A0 = aie::load_v<MMUL::size_A>(pA1);
aie::vector<T_in, MMUL::size_B> B0 = aie::load_v<MMUL::size_B>(pB1);
C00.mac(A0, B0);
pA1 += MMUL::size_A;
pB1 += MMUL::size_B;

for (unsigned i = 0; i < kerCol; i++)
chess_loop_range(3, ) {

for (unsigned z = 0; z < chanTile; z += 4)
chess_prepare_for_pipelining chess_loop_range(4, ) {

const T_in *__restrict pA1 =
pA + i * chanTile * imgWidth * s + z * imgWidth * s;
const T_in *__restrict pB1 =
pB + i * MMUL::size_B * chanTile * kerRow + z * MMUL::size_B;

aie::vector<T_in, 64> A0_0 = aie::load_v<64>(pA1);
aie::vector<T_in, 64> A0_1 = aie::load_v<64>(pA1 + 64);
aie::vector<T_in, 64> A0_2 = aie::load_v<64>(pA1 + 128);

// z = 0
aie::vector<T_in, 32> A0 = extract_v32int8(A0_0, 0);
aie::vector<T_in, MMUL::size_B> B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_0, A0_1, 8), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_0, A0_1, 16), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

// z = 1
pB1 = pB + i * MMUL::size_B * chanTile * kerRow + z * MMUL::size_B +
MMUL::size_B;
A0 = extract_v32int8(shift(A0_0, A0_1, 48), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_0, A0_1, 56), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(A0_1, 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

// z = 2
pB1 = pB + i * MMUL::size_B * chanTile * kerRow + z * MMUL::size_B +
2 * MMUL::size_B;
A0 = extract_v32int8(shift(A0_1, A0_2, 32), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_1, A0_2, 40), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_1, A0_2, 48), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

// z = 3
pB1 = pB + i * MMUL::size_B * chanTile * kerRow + z * MMUL::size_B +
3 * MMUL::size_B;
A0 = extract_v32int8(shift(A0_2, undef_v64int8(), 16), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_2, undef_v64int8(), 24), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);

A0 = extract_v32int8(shift(A0_2, undef_v64int8(), 32), 0);
B0 = aie::load_v<MMUL::size_B>(pB1);
pB1 += MMUL::size_B * chanTile;
C00.mac(A0, B0);
}
}

aie::store_v(pC, C00.template to_vector<T_out>());

event1();
}

template <unsigned m, unsigned k, unsigned n>
void conv_vectorized_4x8x8_i8_i32(const int8 *__restrict pA,
const int8 *__restrict pB,
int *__restrict pC) {
template <unsigned kerCol, unsigned kerRow, unsigned inputChan,
unsigned inputWidth>
void conv_vectorized_3x3x32x6_i8_i32(const int8 *__restrict pA,
const int8 *__restrict pB,
int *__restrict pC) {
constexpr int r = 4;
constexpr int s = 8;
constexpr int t = 8;
static_assert(m % r == 0);
static_assert(k % s == 0);
static_assert(n % t == 0);
return conv_vectorized<int8, int, m / r, k / s, n / t, r, s, t>(pA, pB, pC);
static_assert(inputChan % s == 0);
return conv_vectorized<int8, int, kerCol, kerRow, inputChan / s, inputWidth,
r, s, t>(pA, pB, pC);
}

extern "C" {

void linalg_conv_1d_nwc_wcf_view1x4x8xi8as2_view1x8x8xi8as2_view1x4x8xi32as2(
int8 *a_in, int8 *b_in, int *c_out) {
conv_vectorized_4x8x8_i8_i32<4, 8, 8>(a_in, b_in, c_out);
void conv(int8 *a_in, int8 *b_in, int *c_out) {
conv_vectorized_3x3x32x6_i8_i32<3, 3, 32, 6>(a_in, b_in, c_out);
}
void linalg_fill_i32_view1x1x4x8xi32as2(int *c_out) {
void linalg_fill_i32_view1x1x4x1x8xi32as2(int *c_out) {
zero_vectorized<int, 4, 8, 32>(c_out);
}

Expand Down

0 comments on commit 7372bd7

Please sign in to comment.