@@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1833
1833
const block_q4_0 * restrict x = vx ;
1834
1834
const block_q4_0 * restrict y = vy ;
1835
1835
1836
- ggml_float sumf = 0.0 ;
1836
+ float sumf = 0.0 ;
1837
1837
1838
1838
#if defined(__ARM_NEON )
1839
1839
float sum0 = 0.0f ;
@@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1928
1928
#endif
1929
1929
}
1930
1930
1931
- sumf = ( ggml_float )( sum0 + sum1 ) ;
1931
+ sumf = sum0 + sum1 ;
1932
1932
#elif defined(__AVX512F__ )
1933
1933
// Initialize accumulator with zeros
1934
1934
__m512 acc0 = _mm512_setzero_ps ();
@@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1962
1962
__m256 acc = _mm256_setzero_ps ();
1963
1963
1964
1964
// Main loop
1965
+ // TODO: figure a way to do this in a portable way
1966
+ #ifdef __GNUC__
1967
+ #pragma GCC unroll 16
1968
+ #endif
1965
1969
for (int i = 0 ; i < nb ; ++ i ) {
1966
1970
// Compute combined scale for the block
1967
1971
const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
@@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1975
1979
bx = _mm256_sub_epi8 ( bx , off );
1976
1980
by = _mm256_sub_epi8 ( by , off );
1977
1981
1978
- // Sign-extend first 16 signed bytes into int16_t
1979
- __m256i x16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( bx ) );
1980
- __m256i y16 = _mm256_cvtepi8_epi16 ( _mm256_castsi256_si128 ( by ) );
1981
- // Compute products of int16_t integers, add pairwise
1982
- __m256i i32 = _mm256_madd_epi16 ( x16 , y16 );
1982
+ // Get absolute values of x vectors
1983
+ const __m256i ax = _mm256_sign_epi8 (bx , bx );
1983
1984
1984
- // Sign-extend last 16 signed bytes into int16_t vectors
1985
- x16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( bx , 1 ) );
1986
- y16 = _mm256_cvtepi8_epi16 ( _mm256_extracti128_si256 ( by , 1 ) );
1987
- // Accumulate products of int16_t integers
1988
- i32 = _mm256_add_epi32 ( i32 , _mm256_madd_epi16 ( x16 , y16 ) );
1985
+ // Sign the values of the y vectors
1986
+ const __m256i sy = _mm256_sign_epi8 (by , bx );
1987
+
1988
+ // Perform multiplication and create 16-bit values
1989
+ const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
1990
+
1991
+ const __m256i ones = _mm256_set1_epi16 (1 );
1992
+ const __m256i i32 = _mm256_madd_epi16 (ones , dot );
1989
1993
1990
1994
// Convert int32_t to float
1991
- __m256 p = _mm256_cvtepi32_ps ( i32 );
1995
+ const __m256 p = _mm256_cvtepi32_ps ( i32 );
1996
+
1992
1997
// Apply the scale, and accumulate
1993
1998
acc = _mm256_fmadd_ps ( d , p , acc );
1994
1999
}
0 commit comments