@@ -207,33 +207,97 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
207207
208208 ggml_float sumf = 0.0 ;
209209
210- #if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
211- const int np = (n & ~(GGML_F16_STEP - 1 ));
212210
213- GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
211+ #if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
212+ #if defined(__ARM_FEATURE_SVE)
213+ const int sve_register_length = svcntb () * 8 ; // get vector length
214+ const int ggml_f16_epr = sve_register_length / 16 ; // running when 16
215+ const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
216+
217+ const int np= (n & ~(ggml_f16_step - 1 ));
218+ svfloat16_t sum1 = svdup_n_f16 (0 .0f );
219+ svfloat16_t sum2 = svdup_n_f16 (0 .0f );
220+ svfloat16_t sum3 = svdup_n_f16 (0 .0f );
221+ svfloat16_t sum4 = svdup_n_f16 (0 .0f );
222+
223+ svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
224+ svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
225+ for (int i = 0 ; i < np; i += ggml_f16_step) {
226+ ax1 = GGML_F16x_VEC_LOAD (x + i + 0 * ggml_f16_epr, 0 );
227+ ay1 = GGML_F16x_VEC_LOAD (y + i + 0 * ggml_f16_epr, 0 );
228+ sum1 = GGML_F16x_VEC_FMA (sum1, ax1, ay1);
229+
230+ ax2 = GGML_F16x_VEC_LOAD (x + i + 1 * ggml_f16_epr, 1 );
231+ ay2 = GGML_F16x_VEC_LOAD (y + i + 1 * ggml_f16_epr, 1 );
232+ sum2 = GGML_F16x_VEC_FMA (sum2, ax2, ay2);
233+
234+ ax3 = GGML_F16x_VEC_LOAD (x + i + 2 * ggml_f16_epr, 2 );
235+ ay3 = GGML_F16x_VEC_LOAD (y + i + 2 * ggml_f16_epr, 2 );
236+ sum3 = GGML_F16x_VEC_FMA (sum3, ax3, ay3);
237+
238+ ax4 = GGML_F16x_VEC_LOAD (x + i + 3 * ggml_f16_epr, 3 );
239+ ay4 = GGML_F16x_VEC_LOAD (y + i + 3 * ggml_f16_epr, 3 );
240+ sum4 = GGML_F16x_VEC_FMA (sum4, ax4, ay4);
241+
242+ ax5 = GGML_F16x_VEC_LOAD (x + i + 4 * ggml_f16_epr, 4 );
243+ ay5 = GGML_F16x_VEC_LOAD (y + i + 4 * ggml_f16_epr, 4 );
244+ sum1 = GGML_F16x_VEC_FMA (sum1, ax5, ay5);
245+
246+ ax6 = GGML_F16x_VEC_LOAD (x + i + 5 * ggml_f16_epr, 5 );
247+ ay6 = GGML_F16x_VEC_LOAD (y + i + 5 * ggml_f16_epr, 5 );
248+ sum2 = GGML_F16x_VEC_FMA (sum2, ax6, ay6);
249+
250+ ax7 = GGML_F16x_VEC_LOAD (x + i + 6 * ggml_f16_epr, 6 );
251+ ay7 = GGML_F16x_VEC_LOAD (y + i + 6 * ggml_f16_epr, 6 );
252+ sum3 = GGML_F16x_VEC_FMA (sum3, ax7, ay7);
253+
254+ ax8 = GGML_F16x_VEC_LOAD (x + i + 7 * ggml_f16_epr, 7 );
255+ ay8 = GGML_F16x_VEC_LOAD (y + i + 7 * ggml_f16_epr, 7 );
256+ sum4 = GGML_F16x_VEC_FMA (sum4, ax8, ay8);
257+ }
214258
215- GGML_F16_VEC ax[GGML_F16_ARR];
216- GGML_F16_VEC ay[GGML_F16_ARR];
259+ const int np2 = (n & ~(ggml_f16_epr - 1 )); // round down to multiple of 8
260+ for (int k = np; k < np2; k += ggml_f16_epr) {
261+ svfloat16_t rx = GGML_F16x_VEC_LOAD (x + k, 0 );
262+ svfloat16_t ry = GGML_F16x_VEC_LOAD (y + k, 0 );
263+ sum1 = GGML_F16x_VEC_FMA (sum1, rx, ry);
264+ }
217265
218- for ( int i = 0 ; i < np; i += GGML_F16_STEP ) {
219- for ( int j = 0 ; j < GGML_F16_ARR; j++) {
220- ax[j] = GGML_F16_VEC_LOAD (x + i + j*GGML_F16_EPR, j );
221- ay[j] = GGML_F16_VEC_LOAD (y + i + j*GGML_F16_EPR, j );
266+ if (np2 < n ) {
267+ svbool_t pg = svwhilelt_b16 (np2, n);
268+ svfloat16_t hx = svld1_f16 (pg, ( const __fp16 *)(x + np2) );
269+ svfloat16_t hy = svld1_f16 (pg, ( const __fp16 *)(y + np2) );
222270
223- sum[j] = GGML_F16_VEC_FMA (sum[j], ax[j], ay[j] );
271+ sum1 = svmad_f16_x (pg, hx, hy, sum1 );
224272 }
225- }
273+ GGML_F16x_VEC_REDUCE (sumf, sum1, sum2, sum3, sum4);
274+ #else
275+ const int np = (n & ~(GGML_F16_STEP - 1 ));
226276
227- // reduce sum0..sum3 to sum0
228- GGML_F16_VEC_REDUCE (sumf, sum);
277+ GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
229278
230- // leftovers
231- for (int i = np; i < n; ++i) {
232- sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));
233- }
279+ GGML_F16_VEC ax[GGML_F16_ARR];
280+ GGML_F16_VEC ay[GGML_F16_ARR];
234281
235- // if you hit this, you are likely running outside the FP range
236- assert (!isnan (sumf) && !isinf (sumf));
282+ for (int i = 0 ; i < np; i += GGML_F16_STEP) {
283+ for (int j = 0 ; j < GGML_F16_ARR; j++) {
284+ ax[j] = GGML_F16_VEC_LOAD (x + i + j*GGML_F16_EPR, j);
285+ ay[j] = GGML_F16_VEC_LOAD (y + i + j*GGML_F16_EPR, j);
286+
287+ sum[j] = GGML_F16_VEC_FMA (sum[j], ax[j], ay[j]);
288+ }
289+ }
290+
291+ // reduce sum0..sum3 to sum0
292+ GGML_F16_VEC_REDUCE (sumf, sum);
293+
294+ // leftovers
295+ for (int i = np; i < n; ++i) {
296+ sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));
297+ }
298+ // if you hit this, you are likely running outside the FP range
299+ assert (!isnan (sumf) && !isinf (sumf));
300+ #endif
237301#else
238302 for (int i = 0 ; i < n; ++i) {
239303 sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32 (x[i])*GGML_CPU_FP16_TO_FP32 (y[i]));
0 commit comments