Skip to content

Commit 1d08882

Browse files
authored
Optimize AVX2 ggml_vec_dot_q4_0 (#642)
1 parent 02c5b27 commit 1d08882

File tree

1 file changed

+18
-13
lines changed

1 file changed

+18
-13
lines changed

ggml.c

+18 -13
Original file line number | Diff line number | Diff line change
@@ -1833,7 +1833,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
18331833
const block_q4_0 * restrict x = vx;
18341834
const block_q4_0 * restrict y = vy;
18351835

1836-
ggml_float sumf = 0.0;
1836+
float sumf = 0.0;
18371837

18381838
#if defined(__ARM_NEON)
18391839
float sum0 = 0.0f;
@@ -1928,7 +1928,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19281928
#endif
19291929
}
19301930

1931-
sumf = (ggml_float)(sum0 + sum1);
1931+
sumf = sum0 + sum1;
19321932
#elif defined(__AVX512F__)
19331933
// Initialize accumulator with zeros
19341934
__m512 acc0 = _mm512_setzero_ps();
@@ -1962,6 +1962,10 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19621962
__m256 acc = _mm256_setzero_ps();
19631963

19641964
// Main loop
1965+
// TODO: figure a way to do this in a portable way
1966+
#ifdef __GNUC__
1967+
#pragma GCC unroll 16
1968+
#endif
19651969
for (int i = 0; i < nb; ++i) {
19661970
// Compute combined scale for the block
19671971
const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
@@ -1975,20 +1979,21 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
19751979
bx = _mm256_sub_epi8( bx, off );
19761980
by = _mm256_sub_epi8( by, off );
19771981

1978-
// Sign-extend first 16 signed bytes into int16_t
1979-
__m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
1980-
__m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
1981-
// Compute products of int16_t integers, add pairwise
1982-
__m256i i32 = _mm256_madd_epi16( x16, y16 );
1982+
// Get absolute values of x vectors
1983+
const __m256i ax = _mm256_sign_epi8(bx, bx);
19831984

1984-
// Sign-extend last 16 signed bytes into int16_t vectors
1985-
x16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
1986-
y16 = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
1987-
// Accumulate products of int16_t integers
1988-
i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16, y16 ) );
1985+
// Sign the values of the y vectors
1986+
const __m256i sy = _mm256_sign_epi8(by, bx);
1987+
1988+
// Perform multiplication and create 16-bit values
1989+
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
1990+
1991+
const __m256i ones = _mm256_set1_epi16(1);
1992+
const __m256i i32 = _mm256_madd_epi16(ones, dot);
19891993

19901994
// Convert int32_t to float
1991-
__m256 p = _mm256_cvtepi32_ps( i32 );
1995+
const __m256 p = _mm256_cvtepi32_ps( i32 );
1996+
19921997
// Apply the scale, and accumulate
19931998
acc = _mm256_fmadd_ps( d, p, acc );
19941999
}

0 commit comments

Comments
 (0)