Skip to content

Commit

Permalink
ggml : add WASM SIMD for Q4_0
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Feb 27, 2023
1 parent 24330b8 commit 705bc79
Showing 1 changed file with 118 additions and 24 deletions.
142 changes: 118 additions & 24 deletions src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,45 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
#else
#error "not implemented for QK"
#endif
#elif defined(__wasm_simd128__)
#if QK == 32
for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max

v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];

for (int l = 0; l < 8; l++) srcv[l] = wasm_v128_load(x + i*32 + 4*l);
for (int l = 0; l < 8; l++) asrcv[l] = wasm_f32x4_abs(srcv[l]);

for (int l = 0; l < 4; l++) amaxv[2*l] = wasm_f32x4_max(asrcv[2*l], asrcv[2*l+1]);
for (int l = 0; l < 2; l++) amaxv[4*l] = wasm_f32x4_max(amaxv[4*l], amaxv[4*l+2]);
for (int l = 0; l < 1; l++) amaxv[8*l] = wasm_f32x4_max(amaxv[8*l], amaxv[8*l+4]);

amax = MAX(
MAX(wasm_f32x4_extract_lane(amaxv[0], 0), wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2), wasm_f32x4_extract_lane(amaxv[0], 3)));

const float d = amax / ((1 << 3) - 1);
const float id = d ? 1.0/d : 0.0;

pd[i] = d;

for (int l = 0; l < 8; l++) {
const v128_t v = wasm_f32x4_mul(srcv[l], wasm_f32x4_splat(id));
const v128_t vf = wasm_f32x4_add(v, wasm_f32x4_splat(8.5f));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(vf);

pp[2*l + 0] = wasm_i32x4_extract_lane(vi, 0) | (wasm_i32x4_extract_lane(vi, 1) << 4);
pp[2*l + 1] = wasm_i32x4_extract_lane(vi, 2) | (wasm_i32x4_extract_lane(vi, 3) << 4);
}

memcpy(pb + i*16, pp, sizeof(pp));
}
#else
#error "not implemented for QK"
#endif
#else
// scalar
for (int i = 0; i < nb; i++) {
Expand Down Expand Up @@ -1216,9 +1255,6 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);

//printf("p_0: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_0, 0), vgetq_lane_s16(p_0, 1), vgetq_lane_s16(p_0, 2), vgetq_lane_s16(p_0, 3), vgetq_lane_s16(p_0, 4), vgetq_lane_s16(p_0, 5), vgetq_lane_s16(p_0, 6), vgetq_lane_s16(p_0, 7));
//printf("p_1: %d %d %d %d %d %d %d %d\n", vgetq_lane_s16(p_1, 0), vgetq_lane_s16(p_1, 1), vgetq_lane_s16(p_1, 2), vgetq_lane_s16(p_1, 3), vgetq_lane_s16(p_1, 4), vgetq_lane_s16(p_1, 5), vgetq_lane_s16(p_1, 6), vgetq_lane_s16(p_1, 7));

// scalar
#if defined(__ARM_FEATURE_QRDMX)
sum0 += d0_0*d1_0*vaddvq_s16(p_0);
Expand All @@ -1230,35 +1266,93 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
}

sumf = sum0 + sum1;
#else
#error "not implemented for QK"
#endif
#elif defined(__wasm_simd128__)
#if QK == 32
// wasm simd
float sum0 = 0.0f;
float sum1 = 0.0f;

//printf("sumf SIMD = %f\n", sumf);
for (int i = 0; i < nb; i += 2) {
const float d0_0 = pd0[i + 0];
const float d0_1 = pd0[i + 1];
const float d1_0 = pd1[i + 0];
const float d1_1 = pd1[i + 1];

//// scalar
//sumf = 0.0f;
//for (int i = 0; i < nb; i++) {
// const float d0 = pd0[i];
// const float d1 = pd1[i];
const uint8_t * restrict p0 = pb0 + i*16;
const uint8_t * restrict p1 = pb1 + i*16;

// const uint8_t * restrict p0 = pb0 + i*QK/2;
// const uint8_t * restrict p1 = pb1 + i*QK/2;
const v128_t m4b = wasm_u8x16_splat(0xf);
const v128_t s8b = wasm_i8x16_splat(0x8);

// for (int j = 0; j < QK/2; j++) {
// const uint8_t v0 = p0[j];
// const uint8_t v1 = p1[j];
const v128_t v0_0 = wasm_v128_load(p0);
const v128_t v0_1 = wasm_v128_load(p0 + 16);
const v128_t v1_0 = wasm_v128_load(p1);
const v128_t v1_1 = wasm_v128_load(p1 + 16);

// const float f0 = d0*((int8_t) (v0 & 0xf) - 8);
// const float f1 = d0*((int8_t) (v0 >> 4) - 8);
// 4-bit -> 8-bit
const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
const v128_t v1_0l = wasm_v128_and(v1_0, m4b);

// const float f2 = d1*((int8_t) (v1 & 0xf) - 8);
// const float f3 = d1*((int8_t) (v1 >> 4) - 8);
const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);

// sumf += f0*f2 + f1*f3;
// }
//}
//printf("sumf scalar = %f\n", sumf);
//printf("--------\n");
const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
const v128_t v1_1l = wasm_v128_and(v1_1, m4b);

const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);

// sub 8
const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);

//exit(0);
const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);

const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);

const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);

// dot product into int16x8_t
const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));

const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));

const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));

const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));

const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);

const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);

const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);

sum0 += d0_0*d1_0*(
wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
sum1 += d0_1*d1_1*(
wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
}

sumf = sum0 + sum1;
#else
#error "not implemented for QK"
#endif
Expand Down

0 comments on commit 705bc79

Please sign in to comment.