Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WAsm SIMD QS8 GEMM/IGEMM microkernels using ExtMul and ExtAddPair instructions #1271

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,14 @@ WASMSIMD_UKERNELS = [
"src/qs8-gavgpool/gen/7x-minmax-wasmsimd-c8-acc2.c",
"src/qs8-gavgpool/gen/7x-minmax-wasmsimd-c16-acc2.c",
"src/qs8-gavgpool/gen/7x-minmax-wasmsimd-c24-acc2.c",
"src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-gemm/gen/4x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-gemm/gen/4x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld64.c",
"src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-ld128.c",
"src/qs8-gemm/gen/1x4c8-xw-minmax-wasmsimd.c",
Expand All @@ -1387,6 +1395,14 @@ WASMSIMD_UKERNELS = [
"src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld64.c",
"src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-ld128.c",
"src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c",
"src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-igemm/gen/4x4c8-minmax-wasmsimd-extmul-widen.c",
"src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-igemm/gen/4x4c8-minmax-wasmsimd-extmul-extaddpair.c",
"src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld64.c",
"src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-ld128.c",
"src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld64.c",
Expand Down
12 changes: 12 additions & 0 deletions scripts/generate-qs8-gemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ tools/xngen src/qs8-gemm/MRx4c8-wasmsimd.c.in -D MR=1 -D VARIANT=EXTENDED -o src
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd.c.in -D MR=2 -D VARIANT=EXTENDED -o src/qs8-gemm/gen/2x4c8-xw-minmax-wasmsimd.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd.c.in -D MR=3 -D VARIANT=EXTENDED -o src/qs8-gemm/gen/3x4c8-xw-minmax-wasmsimd.c

### C8 ExtMul+Widen micro-kernels
# Each command runs tools/xngen on the extmul-widen GEMM template with one
# value of MR (1 through 4, passed via -D) and writes the result to the
# matching <MR>x4c8 file under src/qs8-gemm/gen/ (the -o argument).
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=1 -o src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=2 -o src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=3 -o src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=4 -o src/qs8-gemm/gen/4x4c8-minmax-wasmsimd-extmul-widen.c

### C8 ExtMul+ExtAddPair micro-kernels
# Same scheme for the extmul-extaddpair GEMM template: one output file per
# MR in 1..4. Unlike the earlier wasmsimd commands in this script, no
# VARIANT define is passed to these templates.
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=1 -o src/qs8-gemm/gen/1x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=2 -o src/qs8-gemm/gen/2x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=3 -o src/qs8-gemm/gen/3x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-gemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=4 -o src/qs8-gemm/gen/4x4c8-minmax-wasmsimd-extmul-extaddpair.c

################################### ARM NEON ##################################
tools/xngen src/qs8-gemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-gemm/gen/1x8-minmax-neon-mlal-lane.c
tools/xngen src/qs8-gemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-gemm/gen/2x8-minmax-neon-mlal-lane.c
Expand Down
12 changes: 12 additions & 0 deletions scripts/generate-qs8-igemm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=1 -D VARIANT=LD128 -o src/q
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=2 -D VARIANT=LD128 -o src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-ld128.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd.c.in -D MR=3 -D VARIANT=LD128 -o src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-ld128.c

### C8 ExtMul+Widen micro-kernels
# Each command runs tools/xngen on the extmul-widen IGEMM template with one
# value of MR (1 through 4, passed via -D) and writes the result to the
# matching <MR>x4c8 file under src/qs8-igemm/gen/ (the -o argument).
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=1 -o src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=2 -o src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=3 -o src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-extmul-widen.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-widen.c.in -D MR=4 -o src/qs8-igemm/gen/4x4c8-minmax-wasmsimd-extmul-widen.c

### C8 ExtMul+ExtAddPair micro-kernels
# Same scheme for the extmul-extaddpair IGEMM template: one output file per
# MR in 1..4, with MR as the only define passed.
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=1 -o src/qs8-igemm/gen/1x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=2 -o src/qs8-igemm/gen/2x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=3 -o src/qs8-igemm/gen/3x4c8-minmax-wasmsimd-extmul-extaddpair.c
tools/xngen src/qs8-igemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in -D MR=4 -o src/qs8-igemm/gen/4x4c8-minmax-wasmsimd-extmul-extaddpair.c

################################### ARM NEON ##################################
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=1 -D NR=8 -o src/qs8-igemm/gen/1x8-minmax-neon-mlal-lane.c
tools/xngen src/qs8-igemm/neon-mlal-lane.c.in -D MR=2 -D NR=8 -o src/qs8-igemm/gen/2x8-minmax-neon-mlal-lane.c
Expand Down
10 changes: 5 additions & 5 deletions src/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -2097,11 +2097,11 @@ static void init(void) {
#ifndef XNN_NO_QS8_OPERATORS
init_flags |= XNN_INIT_FLAG_QS8;

xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_ukernel_3x4c8__wasmsimd_ld64);
xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_ukernel_3x4c8__wasmsimd_ld64);
xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_ld64);
xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_ld64);
xnn_params.qs8.gemm.mr = 3;
xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_ukernel_4x4c8__wasmsimd_extmul_extaddpair);
xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_ukernel_4x4c8__wasmsimd_extmul_extaddpair);
xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_ukernel_1x4c8__wasmsimd_extmul_extaddpair);
xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_ukernel_1x4c8__wasmsimd_extmul_extaddpair);
xnn_params.qs8.gemm.mr = 4;
xnn_params.qs8.gemm.nr = 4;
xnn_params.qs8.gemm.log2_kr = 3;

Expand Down
160 changes: 160 additions & 0 deletions src/qs8-gemm/MRx4c8-wasmsimd-extmul-extaddpair.c.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert MR <= 4
#include <assert.h>

#include <wasm_simd128.h>

#include <xnnpack/gemm.h>


// Template for a signed 8-bit GEMM micro-kernel that produces an
// (up to MR rows) x (4 columns) tile of int8 outputs per outer-loop pass,
// using WebAssembly SIMD extended-multiply and pairwise-extended-add builtins.
//
// Parameters, as used by the body below:
//   mr        - number of valid rows; asserted to be in 1..MR.
//   nc        - number of output columns still to produce; consumed 4 at a
//               time, with a 1..3 column tail handled at the end.
//   kc        - reduction length in bytes per row; the inner loop consumes
//               8 bytes of each row per iteration until k >= kc.
//   a         - first input row; row M starts a_stride bytes after row M-1.
//   w         - packed weights: for every group of 4 columns, four 32-bit
//               words (used as the initial accumulator values, presumably
//               biases) followed by 32 bytes (8 int8 values for each of the
//               4 columns) per inner-loop iteration.
//   c         - first output row; row M starts cm_stride bytes after row M-1,
//               and each row pointer advances by cn_stride after a full
//               4-column store.
//   params    - requantization and clamping constants read from
//               params->wasmsimd (multiplier, rounding, remainder_mask,
//               remainder_threshold, shift, output_zero_point, output_min,
//               output_max).
//
// NOTE(review): the inner loop always loads 8 bytes per row and 32 bytes of w,
// so when kc is not a multiple of 8 the last iteration reads past kc bytes of
// each row; this presumably relies on the caller/packing padding K to a
// multiple of 8 - confirm against the packing code.
void xnn_qs8_gemm_minmax_ukernel_${MR}x4c8__wasmsimd_extmul_extaddpair(
size_t mr,
size_t nc,
size_t kc,
const int8_t* restrict a,
size_t a_stride,
const void* restrict w,
int8_t* restrict c,
size_t cm_stride,
size_t cn_stride,
const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
assert(mr != 0);
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
assert(kc % sizeof(int8_t) == 0);
assert(a != NULL);
assert(w != NULL);
assert(c != NULL);

// Set up one input pointer and one output pointer per row. A row whose index
// is not below mr is redirected to the previous row's pointers, so the
// unrolled code below re-reads and re-writes that previous row instead of
// touching memory beyond the mr valid rows. The three condition forms
// (mr <= M, mr != M+1, mr < M+1) are emitted depending on M and MR.
const int8_t* a0 = a;
int8_t* c0 = c;
$for M in range(1, MR):
const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$elif M + 1 == MR:
if XNN_UNPREDICTABLE(mr != ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}
$else:
if XNN_UNPREDICTABLE(mr < ${M+1}) {
a${M} = a${M-1};
c${M} = c${M-1};
}

// All-zero vector (the bit pattern of +0.0 is all zeros); reused below both
// as the base for the initial accumulators and as the integer zero in the
// signed comparisons.
const v128_t vzero = wasm_f64x2_splat(0.0);
do {
// Initialise one accumulator per output column: lane 0 receives the 32-bit
// word w[N] (read through a float pointer, so the float type only carries
// the 32 bits), lanes 1..3 are zero. Rows 1..MR-1 start from copies of row
// 0's accumulators. w is then advanced past the four 32-bit words.
$for N in range(4):
v128_t vacc0x${N} = wasm_f32x4_replace_lane(vzero, 0, ((const float*) w)[${N}]);
$for M in range(1, MR):
$for N in range(4):
v128_t vacc${M}x${N} = vacc0x${N};
w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));

// Reduction loop; k counts the input bytes consumed per row in this pass and
// is used afterwards to rewind the row pointers.
size_t k = 0;
while (k < kc) {
// Load the next 8 int8 values of every row with the load64_zero builtin
// (64-bit load into the low half of the vector, upper half zeroed per the
// instruction's definition) and advance the row pointer by 8 bytes.
$for M in range(MR):
const v128_t vxa${M} = __builtin_wasm_load64_zero((long long*) a${M});
a${M} += 8;

// For each of the 4 columns: load that column's 8 int8 weights from byte
// offset N * 8 of w, multiply them with every row's 8 inputs into eight
// 16-bit products (extmul_low), sum adjacent product pairs into four 32-bit
// values (extadd_pairwise) and add those to the row/column accumulator.
$for N in range(4):
$if N == 0:
const v128_t vxb${N} = __builtin_wasm_load64_zero((long long*) w);
$else:
const v128_t vxb${N} = __builtin_wasm_load64_zero((long long*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t)));

$for M in range(MR):
const v128_t vprod${M}x${N} = __builtin_wasm_extmul_low_i8x16_s_i16x8(vxa${M}, vxb${N});
vacc${M}x${N} = wasm_i32x4_add(vacc${M}x${N}, __builtin_wasm_extadd_pairwise_i16x8_s_i32x4(vprod${M}x${N}));

// Step past the 4 x 8 weight bytes just consumed and count 8 input bytes.
w = (const void*) ((uintptr_t) w + 32 * sizeof(int8_t));
k += 8 * sizeof(int8_t);
}

// Horizontal reduction, stage 1: interleave-and-add the accumulators of
// columns 0 and 2 (and of columns 1 and 3), halving the number of partial
// sums per column from four to two.
$for M in range(MR):
const v128_t vacc${M}x02 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x0, vacc${M}x2, 2, 6, 3, 7));
const v128_t vacc${M}x13 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x1, vacc${M}x3, 2, 6, 3, 7));

// Stage 2: the same interleave-and-add on the two intermediate vectors; lane
// n of vacc<M>x0123 now holds the full sum of all four lanes of vacc<M>x<n>.
$for M in range(MR):
v128_t vacc${M}x0123 = wasm_i32x4_add(wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 0, 4, 1, 5), wasm_v32x4_shuffle(vacc${M}x02, vacc${M}x13, 2, 6, 3, 7));

// Requantization. First build a per-lane sign mask (all ones where the
// accumulator is negative) ...
$for M in range(MR):
const v128_t vsign${M}x0123 = wasm_i32x4_lt(vacc${M}x0123, vzero);

// ... and interleave accumulator lanes 0 and 1 with their sign words, which
// yields those two values sign-extended to 64 bits.
$for M in range(MR):
const v128_t vacc${M}x01 = wasm_v32x4_shuffle(vacc${M}x0123, vsign${M}x0123, 0, 4, 1, 5);

// 64-bit multiply by params->wasmsimd.multiplier plus params->wasmsimd.rounding
// for lanes 0 and 1; lanes 2 and 3 are sign-extended the same way here and
// multiplied in the next loop.
const v128_t vmultiplier = wasm_v128_load(params->wasmsimd.multiplier);
const v128_t vrounding = wasm_v128_load(params->wasmsimd.rounding);
$for M in range(MR):
const v128_t vprod${M}x01 = wasm_i64x2_add(wasm_i64x2_mul(vacc${M}x01, vmultiplier), vrounding);
const v128_t vacc${M}x23 = wasm_v32x4_shuffle(vacc${M}x0123, vsign${M}x0123, 2, 6, 3, 7);

$for M in range(MR):
const v128_t vprod${M}x23 = wasm_i64x2_add(wasm_i64x2_mul(vacc${M}x23, vmultiplier), vrounding);

// Keep only the upper 32 bits of each of the four 64-bit results (odd 32-bit
// lanes 1, 3, 5, 7 of the two product vectors).
$for M in range(MR):
const v128_t vq31prod${M}x0123 = wasm_v32x4_shuffle(vprod${M}x01, vprod${M}x23, 1, 3, 5, 7);

// Remainder of the upcoming right shift: the bits selected by remainder_mask,
// minus 1 for negative values (the comparison yields -1 in those lanes).
const v128_t vremainder_mask = wasm_v128_load(params->wasmsimd.remainder_mask);
$for M in range(MR):
const v128_t vrem${M}x0123 = wasm_i32x4_add(wasm_v128_and(vq31prod${M}x0123, vremainder_mask), wasm_i32x4_lt(vq31prod${M}x0123, vzero));

// Signed right shift by params->wasmsimd.shift, then add 1 in lanes whose
// remainder exceeds remainder_threshold (subtracting the -1 comparison mask).
const v128_t vthreshold = wasm_v128_load(params->wasmsimd.remainder_threshold);
const int32_t vshift = params->wasmsimd.shift;
$for M in range(MR):
vacc${M}x0123 = wasm_i32x4_sub(wasm_i32x4_shr(vq31prod${M}x0123, vshift), wasm_i32x4_gt(vrem${M}x0123, vthreshold));

// Narrow two rows at a time from 32-bit to 16-bit lanes and add the output
// zero point with 16-bit saturation. When MR is odd, the last row is paired
// with itself (min(M+1, MR-1)).
const v128_t voutput_zero_point = wasm_v128_load(params->wasmsimd.output_zero_point);
$for M in range(0, MR, 2):
v128_t vacc${M}${min(M+1, MR-1)}x0123 = wasm_i16x8_add_saturate(wasm_i16x8_narrow_i32x4(vacc${M}x0123, vacc${min(M+1, MR-1)}x0123), voutput_zero_point);

// Narrow to 8-bit lanes. In vout, 32-bit lane M holds the four int8 outputs of
// row M; for MR <= 2 the upper half merely duplicates the lower half.
$if MR > 2:
v128_t vout = wasm_i8x16_narrow_i16x8(vacc0${min(1, MR-1)}x0123, vacc${min(2, MR-1)}${min(3, MR-1)}x0123);
$else:
v128_t vout = wasm_i8x16_narrow_i16x8(vacc0${min(1, MR-1)}x0123, vacc0${min(1, MR-1)}x0123);

// Clamp every output byte to [output_min, output_max].
const v128_t voutput_min = wasm_v128_load(params->wasmsimd.output_min);
vout = wasm_i8x16_max(vout, voutput_min);

const v128_t voutput_max = wasm_v128_load(params->wasmsimd.output_max);
vout = wasm_i8x16_min(vout, voutput_max);

// Full tile: write 4 bytes per row in a single 32-bit store (lane M of vout,
// moved through the float type), rewind each row pointer by the k bytes
// consumed so the next 4 columns re-read the same rows, advance each output
// pointer by cn_stride and account for the 4 columns written.
if (nc >= 4) {
$for M in range(MR):
*((float*) c${M}) = (float) wasm_f32x4_extract_lane(vout, ${M});

$for M in range(MR):
a${M} = (const int8_t*) ((uintptr_t) a${M} - k);

$for M in range(MR):
c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

nc -= 4;
} else {
// Tail of 1..3 columns. If bit 1 of nc is set, store the low two bytes of
// each row's 32-bit lane (16-bit lane 2 * M) and advance the output pointer
// by 2; vout is also shifted right by 16 within every 32-bit lane, which
// moves each row's third output into byte 4 * M.
if (nc & 2) {
$for M in range(MR):
*((uint16_t*) c${M}) = (uint16_t) wasm_i16x8_extract_lane(vout, ${M * 2});
c${M} += 2;
vout = wasm_u32x4_shr(vout, 16);
}
// If bit 0 of nc is set, store one more byte per row from byte 4 * M.
if (nc & 1) {
$for M in range(MR):
*c${M} = (int8_t) wasm_i8x16_extract_lane(vout, ${M * 4});
}

// The tail is by construction the last pass of the outer loop.
nc = 0;
}
} while (nc != 0);
}
Loading