Skip to content

Commit 69ef03d

Browse files
committed
Performance improvement of AVX2 code
1 parent d8d4e86 commit 69ef03d

File tree

1 file changed

+65
-35
lines changed

1 file changed

+65
-35
lines changed

ggml.c

+65-35
Original file line numberDiff line numberDiff line change
@@ -1959,45 +1959,75 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19591959
// Horizontal sum of all lanes of the accumulator
19601960
sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
19611961
#elif defined(__AVX2__)
1962+
19621963
// Initialize accumulator with zeros
19631964
__m256 acc = _mm256_setzero_ps();
19641965

1965-
// Main loop
1966-
// TODO: figure a way to do this in a portable way
1967-
#ifdef __GNUC__
1968-
#pragma GCC unroll 16
1969-
#endif
1970-
for (int i = 0; i < nb; ++i) {
1971-
// Compute combined scale for the block
1972-
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
1973-
1974-
// Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
1975-
__m256i bx = bytesFromNibbles( x[i].qs );
1976-
__m256i by = bytesFromNibbles( y[i].qs );
1977-
1978-
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
1979-
const __m256i off = _mm256_set1_epi8( 8 );
1980-
bx = _mm256_sub_epi8( bx, off );
1981-
by = _mm256_sub_epi8( by, off );
1982-
1983-
// Get absolute values of x vectors
1984-
const __m256i ax = _mm256_sign_epi8(bx, bx);
1985-
1986-
// Sign the values of the y vectors
1987-
const __m256i sy = _mm256_sign_epi8(by, bx);
1988-
1989-
// Perform multiplication and create 16-bit values
1990-
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
1991-
1992-
const __m256i ones = _mm256_set1_epi16(1);
1993-
const __m256i i32 = _mm256_madd_epi16(ones, dot);
1966+
/* Prepare the constants we will need during execution */
1967+
const __m256i lowMask = _mm256_set1_epi8( 0xF );
1968+
const __m256i offset_8 = _mm256_set1_epi16( 8 );
19941969

1995-
// Convert int32_t to float
1996-
const __m256 p = _mm256_cvtepi32_ps( i32 );
1970+
#define UNROLL_COUNT 8
1971+
// make sure the block count is a multiple of the unroll count
1972+
assert(nb % UNROLL_COUNT == 0);
19971973

1998-
// Apply the scale, and accumulate
1999-
acc = _mm256_fmadd_ps( d, p, acc );
2000-
}
1974+
// Main loop
1975+
for (int i = 0; i < nb; i+=UNROLL_COUNT) {
1976+
1977+
// This loop will be unrolled by the compiler
1978+
for (int u=0;u<UNROLL_COUNT;u++) {
1979+
/* Compute combined scale for the block */
1980+
const __m256 scale = _mm256_mul_ps(
1981+
_mm256_broadcast_ss( &x[i+u].d ),
1982+
_mm256_broadcast_ss( &y[i+u].d ) );
1983+
1984+
/* get input from x
1985+
Input: 32 Nibbles (16 bytes) at *x[i+u]
1986+
Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
1987+
1988+
/* Load 16 bytes from memory */
1989+
const __m128i tmp_x = _mm_loadu_si128( (const __m128i_u *) x[i+u].qs);
1990+
/* Expand bytes into uint16_t values */
1991+
const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
1992+
/* Unpack values into individual bytes */
1993+
__m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
1994+
const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
1995+
__m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
1996+
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
1997+
x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
1998+
x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
1999+
2000+
/* get input from y
2001+
Input: 32 Nibbles (16 bytes) at *y[i+u]
2002+
Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
2003+
2004+
/* Load 16 bytes from memory */
2005+
const __m128i tmp_y = _mm_loadu_si128( (const __m128i_u *) y[i+u].qs);
2006+
/* Expand bytes into uint16_t values */
2007+
const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
2008+
/* Unpack values into individual bytes */
2009+
const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
2010+
__m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
2011+
__m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
2012+
/* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
2013+
y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
2014+
y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
2015+
2016+
/* Compute products of int16_t integers, add pairwise, store as int32_t */
2017+
__m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
2018+
__m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
2019+
2020+
/* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
2021+
__m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
2022+
2023+
/* Convert the vector of 8 int32_t to 8 floats */
2024+
__m256 q = _mm256_cvtepi32_ps( xy_q );
2025+
2026+
/* Multiply q with scale and accumulate */
2027+
acc = _mm256_fmadd_ps( scale, q, acc );
2028+
}
2029+
2030+
}
20012031

20022032
// Return horizontal sum of the acc vector
20032033
__m128 res = _mm256_extractf128_ps( acc, 1 );
@@ -2026,7 +2056,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
20262056
bx = _mm_sub_epi8( bx, off );
20272057
by = _mm_sub_epi8( by, off );
20282058

2029-
// Get absolute values of x vectors
2059+
// Get absolute values of x vectors
20302060
const __m128i ax = _mm_sign_epi8(bx, bx);
20312061

20322062
// Sign the values of the y vectors

0 commit comments

Comments
 (0)