Skip to content

Commit

Permalink
ggml : update WASM SIMD
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed May 20, 2023
1 parent b8ee340 commit fab49c6
Showing 1 changed file with 85 additions and 15 deletions.
100 changes: 85 additions & 15 deletions ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -740,19 +740,19 @@ inline static float vaddvq_f32(float32x4_t v) {
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
}

float vminvq_f32(float32x4_t v) {
inline static float vminvq_f32(float32x4_t v) {
return
MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
}

float vmaxvq_f32(float32x4_t v) {
inline static float vmaxvq_f32(float32x4_t v) {
return
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
}

int32x4_t vcvtnq_s32_f32(float32x4_t v) {
inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
int32x4_t res;

res[0] = roundf(vgetq_lane_f32(v, 0));
Expand All @@ -766,7 +766,6 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) {
#endif
#endif


#define QK4_0 32
typedef struct {
ggml_fp16_t d; // delta
Expand Down Expand Up @@ -1056,6 +1055,39 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
y[i].qs[4*j + 3] = vgetq_lane_s32(vi, 3);
}
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];

for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
wasm_f32x4_extract_lane(amaxv[0], 3)));

const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;

y[i].d = GGML_FP32_TO_FP16(d);

for (int j = 0; j < 8; j++) {
const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);
}
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
Expand Down Expand Up @@ -1224,6 +1256,48 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int

y[i].s = d * vaddvq_s32(accv);
}
#elif defined(__wasm_simd128__)
for (int i = 0; i < nb; i++) {
v128_t srcv [8];
v128_t asrcv[8];
v128_t amaxv[8];

for (int j = 0; j < 8; j++) srcv[j] = wasm_v128_load(x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = wasm_f32x4_abs(srcv[j]);

for (int j = 0; j < 4; j++) amaxv[2*j] = wasm_f32x4_max(asrcv[2*j], asrcv[2*j+1]);
for (int j = 0; j < 2; j++) amaxv[4*j] = wasm_f32x4_max(amaxv[4*j], amaxv[4*j+2]);
for (int j = 0; j < 1; j++) amaxv[8*j] = wasm_f32x4_max(amaxv[8*j], amaxv[8*j+4]);

const float amax = MAX(MAX(wasm_f32x4_extract_lane(amaxv[0], 0),
wasm_f32x4_extract_lane(amaxv[0], 1)),
MAX(wasm_f32x4_extract_lane(amaxv[0], 2),
wasm_f32x4_extract_lane(amaxv[0], 3)));

const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;

y[i].d = d;

v128_t accv = wasm_i32x4_splat(0);

for (int j = 0; j < 8; j++) {
const v128_t v = wasm_f32x4_mul(srcv[j], wasm_f32x4_splat(id));
const v128_t vi = wasm_i32x4_trunc_sat_f32x4(v);

y[i].qs[4*j + 0] = wasm_i32x4_extract_lane(vi, 0);
y[i].qs[4*j + 1] = wasm_i32x4_extract_lane(vi, 1);
y[i].qs[4*j + 2] = wasm_i32x4_extract_lane(vi, 2);
y[i].qs[4*j + 3] = wasm_i32x4_extract_lane(vi, 3);

accv = wasm_i32x4_add(accv, vi);
}

y[i].s = d * (wasm_i32x4_extract_lane(accv, 0) +
wasm_i32x4_extract_lane(accv, 1) +
wasm_i32x4_extract_lane(accv, 2) +
wasm_i32x4_extract_lane(accv, 3));
}
#elif defined(__AVX2__) || defined(__AVX__)
for (int i = 0; i < nb; i++) {
// Load elements into 4 AVX vectors
Expand Down Expand Up @@ -2598,7 +2672,6 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y0 = &y[i];

const v128_t m4b = wasm_i8x16_splat(0x0F);
const v128_t s16b = wasm_i8x16_splat(0x10);

// extract the 5th bit
memcpy(&qh, x0->qh, sizeof(qh));
Expand Down Expand Up @@ -2636,15 +2709,14 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h);
const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h);

const float x0d = GGML_FP16_TO_FP32(x0->d);

// dot product
sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4(
wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d)));
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * GGML_FP16_TO_FP32(y0->d))));
}

*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
Expand Down Expand Up @@ -2868,8 +2940,6 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const v128_t v0l = wasm_v128_and (v0, m4b);
const v128_t v0h = wasm_u8x16_shr(v0, 4);

static bool x = true;

// add high bit
const v128_t v0lf = wasm_v128_or(v0l, qhl);
const v128_t v0hf = wasm_v128_or(v0h, qhh);
Expand All @@ -2892,11 +2962,11 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
// dot product
sumv = wasm_f32x4_add(sumv,
wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add(
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d));
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll),
wasm_i32x4_dot_i16x8(v0lfh, v1lh)),
wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl),
wasm_i32x4_dot_i16x8(v0hfh, v1hh)))),
wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)));
}

*s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) +
Expand Down

0 comments on commit fab49c6

Please sign in to comment.