@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
543
543
return _mm256_cvtepi32_ps (summed_pairs );
544
544
}
545
545
546
- // multiply int8_t, add results pairwise twice and return as float vector
547
- static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
548
- // Get absolute values of x vectors
549
- const __m256i ax = _mm256_sign_epi8 (x , x );
550
- // Sign the values of the y vectors
551
- const __m256i sy = _mm256_sign_epi8 (y , x );
546
+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax , const __m256i sy ) {
552
547
#if __AVXVNNI__
553
548
const __m256i zero = _mm256_setzero_si256 ();
554
549
const __m256i summed_pairs = _mm256_dpbusd_epi32 (zero , ax , sy );
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
560
555
#endif
561
556
}
562
557
558
+ // multiply int8_t, add results pairwise twice and return as float vector
559
+ static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
560
+ #if __AVXVNNIINT8__
561
+ const __m256i zero = _mm256_setzero_si256 ();
562
+ const __m256i summed_pairs = _mm256_dpbssd_epi32 (zero , x , y );
563
+ return _mm256_cvtepi32_ps (summed_pairs );
564
+ #else
565
+ // Get absolute values of x vectors
566
+ const __m256i ax = _mm256_sign_epi8 (x , x );
567
+ // Sign the values of the y vectors
568
+ const __m256i sy = _mm256_sign_epi8 (y , x );
569
+ return mul_sum_us8_pairs_float (ax , sy );
570
+ #endif
571
+ }
572
+
563
573
static inline __m128i packNibbles ( __m256i bytes )
564
574
{
565
575
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
619
629
return _mm256_cvtepi32_ps (summed_pairs );
620
630
}
621
631
632
+ static inline __m256 mul_sum_us8_pairs_float (const __m256i ax , const __m256i sy ) {
633
+ const __m128i axl = _mm256_castsi256_si128 (ax );
634
+ const __m128i axh = _mm256_extractf128_si256 (ax , 1 );
635
+ const __m128i syl = _mm256_castsi256_si128 (sy );
636
+ const __m128i syh = _mm256_extractf128_si256 (sy , 1 );
637
+ // Perform multiplication and create 16-bit values
638
+ const __m128i dotl = _mm_maddubs_epi16 (axl , syl );
639
+ const __m128i doth = _mm_maddubs_epi16 (axh , syh );
640
+ return sum_i16_pairs_float (doth , dotl );
641
+ }
642
+
622
643
// multiply int8_t, add results pairwise twice and return as float vector
623
644
static inline __m256 mul_sum_i8_pairs_float (const __m256i x , const __m256i y ) {
624
645
const __m128i xl = _mm256_castsi256_si128 (x );
@@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
2434
2455
const __m256i bx = bytes_from_nibbles_32 (x [i ].qs );
2435
2456
const __m256i by = _mm256_loadu_si256 ( (const __m256i * )y [i ].qs );
2436
2457
2437
- const __m256 xy = mul_sum_i8_pairs_float (bx , by );
2458
+ const __m256 xy = mul_sum_us8_pairs_float (bx , by );
2438
2459
2439
2460
// Accumulate d0*d1*x*y
2440
2461
#if defined(__AVX2__ )
@@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
2906
2927
const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
2907
2928
const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2908
2929
2909
- const __m256 q = mul_sum_i8_pairs_float (bx , by );
2930
+ const __m256 q = mul_sum_us8_pairs_float (bx , by );
2910
2931
2911
2932
acc = _mm256_fmadd_ps (q , _mm256_mul_ps (dx , dy ), acc );
2912
2933
}
@@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
2940
2961
const __m256 dy = _mm256_broadcast_ss (& y [i ].d );
2941
2962
const __m256i by = _mm256_loadu_si256 ((const __m256i * )y [i ].qs );
2942
2963
2943
- const __m256 q = mul_sum_i8_pairs_float (bx , by );
2964
+ const __m256 q = mul_sum_us8_pairs_float (bx , by );
2944
2965
2945
2966
acc = _mm256_add_ps (_mm256_mul_ps (q , _mm256_mul_ps (dx , dy )), acc );
2946
2967
}
0 commit comments