Skip to content

Commit 4262742

Browse files
~7% faster Q5_1 AVX2 code (ggml-org#1477)
1 parent 9560655 commit 4262742

File tree

1 file changed

+30
-9
lines changed

1 file changed

+30
-9
lines changed

Diff for: ggml.c

+30-9
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
543543
return _mm256_cvtepi32_ps(summed_pairs);
544544
}
545545

546-
// multiply int8_t, add results pairwise twice and return as float vector
547-
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
548-
// Get absolute values of x vectors
549-
const __m256i ax = _mm256_sign_epi8(x, x);
550-
// Sign the values of the y vectors
551-
const __m256i sy = _mm256_sign_epi8(y, x);
546+
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
552547
#if __AVXVNNI__
553548
const __m256i zero = _mm256_setzero_si256();
554549
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
@@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
560555
#endif
561556
}
562557

558+
// multiply int8_t, add results pairwise twice and return as float vector
559+
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
560+
#if __AVXVNNIINT8__
561+
const __m256i zero = _mm256_setzero_si256();
562+
const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y);
563+
return _mm256_cvtepi32_ps(summed_pairs);
564+
#else
565+
// Get absolute values of x vectors
566+
const __m256i ax = _mm256_sign_epi8(x, x);
567+
// Sign the values of the y vectors
568+
const __m256i sy = _mm256_sign_epi8(y, x);
569+
return mul_sum_us8_pairs_float(ax, sy);
570+
#endif
571+
}
572+
563573
static inline __m128i packNibbles( __m256i bytes )
564574
{
565575
// Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
619629
return _mm256_cvtepi32_ps(summed_pairs);
620630
}
621631

632+
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
633+
const __m128i axl = _mm256_castsi256_si128(ax);
634+
const __m128i axh = _mm256_extractf128_si256(ax, 1);
635+
const __m128i syl = _mm256_castsi256_si128(sy);
636+
const __m128i syh = _mm256_extractf128_si256(sy, 1);
637+
// Perform multiplication and create 16-bit values
638+
const __m128i dotl = _mm_maddubs_epi16(axl, syl);
639+
const __m128i doth = _mm_maddubs_epi16(axh, syh);
640+
return sum_i16_pairs_float(doth, dotl);
641+
}
642+
622643
// multiply int8_t, add results pairwise twice and return as float vector
623644
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
624645
const __m128i xl = _mm256_castsi256_si128(x);
@@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
24342455
const __m256i bx = bytes_from_nibbles_32(x[i].qs);
24352456
const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
24362457

2437-
const __m256 xy = mul_sum_i8_pairs_float(bx, by);
2458+
const __m256 xy = mul_sum_us8_pairs_float(bx, by);
24382459

24392460
// Accumulate d0*d1*x*y
24402461
#if defined(__AVX2__)
@@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29062927
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
29072928
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
29082929

2909-
const __m256 q = mul_sum_i8_pairs_float(bx, by);
2930+
const __m256 q = mul_sum_us8_pairs_float(bx, by);
29102931

29112932
acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
29122933
}
@@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
29402961
const __m256 dy = _mm256_broadcast_ss(&y[i].d);
29412962
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
29422963

2943-
const __m256 q = mul_sum_i8_pairs_float(bx, by);
2964+
const __m256 q = mul_sum_us8_pairs_float(bx, by);
29442965

29452966
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
29462967
}

0 commit comments

Comments
 (0)