@@ -1959,45 +1959,75 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
1959
1959
// Horizontal sum of all lanes of the accumulator
1960
1960
sumf = _mm512_reduce_add_ps ( acc0 ) + _mm512_reduce_add_ps ( acc1 );
1961
1961
#elif defined(__AVX2__ )
1962
+
1962
1963
// Initialize accumulator with zeros
1963
1964
__m256 acc = _mm256_setzero_ps ();
1964
1965
1965
- // Main loop
1966
- // TODO: figure a way to do this in a portable way
1967
- #ifdef __GNUC__
1968
- #pragma GCC unroll 16
1969
- #endif
1970
- for (int i = 0 ; i < nb ; ++ i ) {
1971
- // Compute combined scale for the block
1972
- const __m256 d = _mm256_mul_ps ( _mm256_broadcast_ss ( & x [i ].d ), _mm256_broadcast_ss ( & y [i ].d ) );
1973
-
1974
- // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1975
- __m256i bx = bytesFromNibbles ( x [i ].qs );
1976
- __m256i by = bytesFromNibbles ( y [i ].qs );
1977
-
1978
- // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1979
- const __m256i off = _mm256_set1_epi8 ( 8 );
1980
- bx = _mm256_sub_epi8 ( bx , off );
1981
- by = _mm256_sub_epi8 ( by , off );
1982
-
1983
- // Get absolute values of x vectors
1984
- const __m256i ax = _mm256_sign_epi8 (bx , bx );
1985
-
1986
- // Sign the values of the y vectors
1987
- const __m256i sy = _mm256_sign_epi8 (by , bx );
1988
-
1989
- // Perform multiplication and create 16-bit values
1990
- const __m256i dot = _mm256_maddubs_epi16 (ax , sy );
1991
-
1992
- const __m256i ones = _mm256_set1_epi16 (1 );
1993
- const __m256i i32 = _mm256_madd_epi16 (ones , dot );
1966
+ /* Prepare the constants we will need during execution */
1967
+ const __m256i lowMask = _mm256_set1_epi8 ( 0xF );
1968
+ const __m256i offset_8 = _mm256_set1_epi16 ( 8 );
1994
1969
1995
- // Convert int32_t to float
1996
- const __m256 p = _mm256_cvtepi32_ps ( i32 );
1970
+ #define UNROLL_COUNT 8
1971
+ // make sure we only unroll multiples of the block count
1972
+ assert (nb % UNROLL_COUNT == 0 );
1997
1973
1998
- // Apply the scale, and accumulate
1999
- acc = _mm256_fmadd_ps ( d , p , acc );
2000
- }
1974
+ // Main loop
1975
+ for (int i = 0 ; i < nb ; i += UNROLL_COUNT ) {
1976
+
1977
+ // This loop will be unrolled by the compiler
1978
+ for (int u = 0 ;u < UNROLL_COUNT ;u ++ ) {
1979
+ /* Compute combined scale for the block */
1980
+ const __m256 scale = _mm256_mul_ps (
1981
+ _mm256_broadcast_ss ( & x [i + u ].d ),
1982
+ _mm256_broadcast_ss ( & y [i + u ].d ) );
1983
+
1984
+ /* get input from x
1985
+ Input: 32 Nibbles (16 bytes) at *x[i+u]
1986
+ Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1987
+
1988
+ /* Load 16 bytes from memory */
1989
+ const __m128i tmp_x = _mm_loadu_si128 ( (const __m128i_u * ) x [i + u ].qs );
1990
+ /* Expand bytes into uint16_t values */
1991
+ const __m256i bytes_x = _mm256_cvtepu8_epi16 (tmp_x );
1992
+ /* Unpack values into individual bytes */
1993
+ __m256i x_low_q = _mm256_and_si256 ( lowMask , bytes_x );
1994
+ const __m256i pre_shift_x_high_q = _mm256_andnot_si256 ( lowMask , bytes_x );
1995
+ __m256i x_high_q = _mm256_srli_epi16 ( pre_shift_x_high_q , 4 );
1996
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1997
+ x_high_q = _mm256_sub_epi16 ( x_high_q , offset_8 );
1998
+ x_low_q = _mm256_sub_epi16 ( x_low_q , offset_8 );
1999
+
2000
+ /* get input from y
2001
+ Input: 32 Nibbles (16 bytes) at *y[i+u]
2002
+ Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2003
+
2004
+ /* Load 16 bytes from memory */
2005
+ const __m128i tmp_y = _mm_loadu_si128 ( (const __m128i_u * ) y [i + u ].qs );
2006
+ /* Expand bytes into uint16_t values */
2007
+ const __m256i bytes_y = _mm256_cvtepu8_epi16 (tmp_y );
2008
+ /* Unpack values into individual bytes */
2009
+ const __m256i pre_shift_y_high_q = _mm256_andnot_si256 ( lowMask , bytes_y );
2010
+ __m256i y_high_q = _mm256_srli_epi16 ( pre_shift_y_high_q , 4 );
2011
+ __m256i y_low_q = _mm256_and_si256 ( lowMask , bytes_y );
2012
+ /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2013
+ y_high_q = _mm256_sub_epi16 ( y_high_q , offset_8 );
2014
+ y_low_q = _mm256_sub_epi16 ( y_low_q , offset_8 );
2015
+
2016
+ /* Compute products of int16_t integers, add pairwise, store as int32_t */
2017
+ __m256i xy_high_q = _mm256_madd_epi16 ( x_high_q , y_high_q );
2018
+ __m256i xy_low_q = _mm256_madd_epi16 ( x_low_q , y_low_q );
2019
+
2020
+ /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2021
+ __m256i xy_q = _mm256_add_epi32 ( xy_high_q , xy_low_q );
2022
+
2023
+ /* Convert to vectore of 8 int32_t to 8 floats */
2024
+ __m256 q = _mm256_cvtepi32_ps ( xy_q );
2025
+
2026
+ /* Multiply q with scale and accumulate */
2027
+ acc = _mm256_fmadd_ps ( scale , q , acc );;
2028
+ }
2029
+
2030
+ }
2001
2031
2002
2032
// Return horizontal sum of the acc vector
2003
2033
__m128 res = _mm256_extractf128_ps ( acc , 1 );
@@ -2026,7 +2056,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
2026
2056
bx = _mm_sub_epi8 ( bx , off );
2027
2057
by = _mm_sub_epi8 ( by , off );
2028
2058
2029
- // Get absolute values of x vectors
2059
+ // Get absolute values of x vectors
2030
2060
const __m128i ax = _mm_sign_epi8 (bx , bx );
2031
2061
2032
2062
// Sign the values of the y vectors
0 commit comments