@@ -472,23 +472,16 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
472
472
//
473
473
474
474
#if __AVX__ || __AVX2__ || __AVX512F__
475
- // Unpack 16 4-bit fields into 16 bytes
476
- // The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
477
- static inline __m128i bytes_from_nibbles_16 (const uint8_t * rsi )
478
- {
479
- // Load 8 bytes from memory
480
- __m128i tmp = _mm_loadl_epi64 ( ( const __m128i * )rsi );
481
-
482
- // Expand bytes into uint16_t values
483
- __m128i bytes = _mm_cvtepu8_epi16 ( tmp );
484
-
485
- // Unpack values into individual bytes
486
- const __m128i lowMask = _mm_set1_epi8 ( 0xF );
487
- __m128i high = _mm_andnot_si128 ( lowMask , bytes );
488
- __m128i low = _mm_and_si128 ( lowMask , bytes );
489
- high = _mm_slli_epi16 ( high , 4 );
490
- bytes = _mm_or_si128 ( low , high );
491
- return bytes ;
475
+ // multiply int8_t lanes of x and y, add results pairwise twice: 16 byte products -> 8 x int16 sums -> 4 x int32 sums
476
+ static inline __m128i mul_sum_i8_pairs (const __m128i x , const __m128i y ) {
477
+ // Get absolute values of x vectors (x signed by itself); _mm_maddubs_epi16 treats its first operand as unsigned bytes
478
+ const __m128i ax = _mm_sign_epi8 (x , x );
479
+ // Sign the values of the y vectors with the sign of x (zeroed where x is 0), so ax * sy stands in for x * y. NOTE(review): presumably y never holds -128, where negation would wrap - confirm
480
+ const __m128i sy = _mm_sign_epi8 (y , x );
481
+ // Perform multiplication and create 16-bit values: adjacent unsigned*signed byte products are added with signed saturation
482
+ const __m128i dot = _mm_maddubs_epi16 (ax , sy );
483
+ const __m128i ones = _mm_set1_epi16 (1 ); // multiplier of 1 turns the madd below into a plain pairwise add
484
+ return _mm_madd_epi16 (ones , dot ); // adjacent 16-bit sums are widened and added into 32-bit lanes
492
485
}
493
486
494
487
// horizontally add 8 floats
@@ -535,19 +528,10 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
535
528
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval; in the added (+) version the low 16 bytes hold the low nibbles of the 16 input bytes and the high 16 bytes hold their high nibbles
536
529
static inline __m256i bytes_from_nibbles_32 (const uint8_t * rsi )
537
530
{
538
- // Load 16 bytes from memory
539
- __m128i tmp = _mm_loadu_si128 ( ( const __m128i * )rsi );
540
-
541
- // Expand bytes into uint16_t values
542
- __m256i bytes = _mm256_cvtepu8_epi16 ( tmp );
543
-
544
- // Unpack values into individual bytes
531
+ const __m128i tmp = _mm_loadu_si128 ((const __m128i * )rsi ); // unaligned load of the 16 packed bytes at rsi
532
+ const __m256i bytes = _mm256_set_m128i (_mm_srli_epi16 (tmp , 4 ), tmp ); // high half: tmp shifted right by 4 within 16-bit lanes; low half: tmp unchanged
545
533
const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
546
- __m256i high = _mm256_andnot_si256 ( lowMask , bytes );
547
- __m256i low = _mm256_and_si256 ( lowMask , bytes );
548
- high = _mm256_slli_epi16 ( high , 4 );
549
- bytes = _mm256_or_si256 ( low , high );
550
- return bytes ;
534
+ return _mm256_and_si256 (lowMask , bytes ); // keep the low 4 bits of every byte; this also discards bits the 16-bit shift carried over from the neighbouring byte
551
535
}
552
536
553
537
// add int16_t pairwise and return as float vector
@@ -2146,31 +2130,23 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2146
2130
// Compute combined scale for the block
2147
2131
const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
2148
2132
2149
- __m128i i32 [2 ];
2150
- for (int j = 0 ; j < 2 ; ++ j ) {
2151
- // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
2152
- __m128i bx = bytes_from_nibbles_16 (x [i ].qs + 8 * j );
2153
- __m128i by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 * j ));
2154
-
2155
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
2156
- const __m128i off = _mm_set1_epi8 ( 8 );
2157
- bx = _mm_sub_epi8 ( bx , off );
2133
+ const __m128i lowMask = _mm_set1_epi8 (0xF );
2134
+ const __m128i off = _mm_set1_epi8 (8 );
2158
2135
2159
- // Get absolute values of x vectors
2160
- const __m128i ax = _mm_sign_epi8 (bx , bx );
2136
+ const __m128i tmp = _mm_loadu_si128 ((const __m128i * )x [i ].qs );
2161
2137
2162
- // Sign the values of the y vectors
2163
- const __m128i sy = _mm_sign_epi8 (by , bx );
2138
+ __m128i bx = _mm_and_si128 (lowMask , tmp );
2139
+ __m128i by = _mm_loadu_si128 ((const __m128i * )y [i ].qs );
2140
+ bx = _mm_sub_epi8 (bx , off );
2141
+ const __m128i i32_0 = mul_sum_i8_pairs (bx , by );
2164
2142
2165
- // Perform multiplication and create 16-bit values
2166
- const __m128i dot = _mm_maddubs_epi16 (ax , sy );
2167
-
2168
- const __m128i ones = _mm_set1_epi16 (1 );
2169
- i32 [j ] = _mm_madd_epi16 (ones , dot );
2170
- }
2143
+ bx = _mm_and_si128 (lowMask , _mm_srli_epi64 (tmp , 4 ));
2144
+ by = _mm_loadu_si128 ((const __m128i * )(y [i ].qs + 16 ));
2145
+ bx = _mm_sub_epi8 (bx , off );
2146
+ const __m128i i32_1 = mul_sum_i8_pairs (bx , by );
2171
2147
2172
2148
// Convert int32_t to float
2173
- __m256 p = _mm256_cvtepi32_ps ( _mm256_set_m128i ( i32 [ 0 ], i32 [ 1 ] ));
2149
+ __m256 p = _mm256_cvtepi32_ps (_mm256_set_m128i (i32_0 , i32_1 ));
2174
2150
// Apply the scale, and accumulate
2175
2151
acc = _mm256_add_ps (_mm256_mul_ps ( d , p ), acc );
2176
2152
}
@@ -2484,8 +2460,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
2484
2460
int sumi = 0 ;
2485
2461
2486
2462
for (int j = 0 ; j < qk /2 ; ++ j ) {
2487
- const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2488
- const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
2463
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2464
+ const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2489
2465
2490
2466
const int32_t x0 = ((x [i ].qs [j ] & 0x0F ) | xh_0 ) - 16 ;
2491
2467
const int32_t x1 = ((x [i ].qs [j ] >> 4 ) | xh_1 ) - 16 ;
@@ -2673,8 +2649,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
2673
2649
int sumi = 0 ;
2674
2650
2675
2651
for (int j = 0 ; j < qk /2 ; ++ j ) {
2676
- const uint8_t xh_0 = ((qh & ( 1u << ( j + 0 ))) >> (j + 0 )) << 4 ;
2677
- const uint8_t xh_1 = ((qh & ( 1u << ( j + 16 ))) >> ( j + 12 )) ;
2652
+ const uint8_t xh_0 = ((qh >> (j + 0 )) << 4 ) & 0x10 ;
2653
+ const uint8_t xh_1 = ((qh >> ( j + 12 )) ) & 0x10 ;
2678
2654
2679
2655
const int32_t x0 = (x [i ].qs [j ] & 0xF ) | xh_0 ;
2680
2656
const int32_t x1 = (x [i ].qs [j ] >> 4 ) | xh_1 ;
0 commit comments