Commit e6532f7

Faster AVX2 prompt processing for k-quants and IQ4_XS (#394)
1 parent 911d58f commit e6532f7

File tree

3 files changed: +843 -2 lines changed

llama.cpp/ggml-quants.inc

Lines changed: 51 additions & 2 deletions
@@ -3341,8 +3341,57 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int64_t k) {
     }
 }
 
-void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q8_K_reference(x, y, k);
+void quantize_row_q8_K(const float * restrict x, void * restrict vy, int64_t k) {
+#ifdef __AVX2__
+    assert(k % QK_K == 0);
+    const int nb = k / QK_K;
+    block_q8_K * y = vy;
+    const __m256 signBit = _mm256_set1_ps( -0.0f );
+    const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
+    for (int i = 0; i < nb; i++) {
+        const float * xb = x + i*QK_K;
+        __m256 maxAbs = _mm256_setzero_ps();
+        const float * xx = xb;
+        for (int ib = 0; ib < QK_K/8; ++ib) {
+            const __m256 v = _mm256_loadu_ps(xx); xx += 8;
+            maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
+        }
+        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
+        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
+        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
+        const float maxScalar = _mm_cvtss_f32( max4 );
+        const float d = maxScalar / 127.f;
+        y[i].d = d;
+        const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
+        const __m256 mul = _mm256_set1_ps( id );
+        xx = xb;
+        int8_t * q8 = y[i].qs;
+        for (int ib = 0; ib < QK_K/32; ++ib) {
+            __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
+            __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
+            __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
+            __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
+            v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
+            v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
+            v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
+            v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
+            __m256i i0 = _mm256_cvtps_epi32( v0 );
+            __m256i i1 = _mm256_cvtps_epi32( v1 );
+            __m256i i2 = _mm256_cvtps_epi32( v2 );
+            __m256i i3 = _mm256_cvtps_epi32( v3 );
+            y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
+            y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
+            i0 = _mm256_packs_epi32( i0, i1 );
+            i2 = _mm256_packs_epi32( i2, i3 );
+            i0 = _mm256_packs_epi16( i0, i2 );
+            i0 = _mm256_permutevar8x32_epi32( i0, perm );
+            _mm256_storeu_si256((__m256i *)q8, i0);
+            q8 += 32;
+        }
+    }
+#else
+    quantize_row_q8_K_reference(x, vy, k);
+#endif
 }
 
 //===================================== Dot ptoducts =================================
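For orientation, here is a minimal scalar sketch of what the new AVX2 path computes for each super-block. It is an illustration written for this note rather than code from the commit; it assumes QK_K = 256 and a block_q8_K layout of the form { float d; int8_t qs[QK_K]; int16_t bsums[QK_K/16]; } as in ggml, and the *_sketch names are hypothetical.

// Hedged scalar sketch of the AVX2 quantize_row_q8_K above (illustrative only).
#include <math.h>
#include <stdint.h>

#define QK_K 256

typedef struct {
    float   d;                 // per-super-block scale: max|x| / 127
    int8_t  qs[QK_K];          // quantized values: round(x / d)
    int16_t bsums[QK_K / 16];  // sums of 16 consecutive quants, used later by the k-quant dot products
} block_q8_K_sketch;

static void quantize_row_q8_K_sketch(const float * x, block_q8_K_sketch * y, int64_t k) {
    const int64_t nb = k / QK_K;                 // number of super-blocks in the row
    for (int64_t i = 0; i < nb; ++i) {
        const float * xb = x + i*QK_K;
        float amax = 0.0f;                       // max absolute value in the super-block
        for (int j = 0; j < QK_K; ++j) {
            const float ax = fabsf(xb[j]);
            if (ax > amax) amax = ax;
        }
        const float d  = amax / 127.0f;
        const float id = amax != 0.0f ? 127.0f / amax : 0.0f;
        y[i].d = d;
        for (int j = 0; j < QK_K; ++j) {
            // round to nearest, matching _mm256_round_ps(v, _MM_ROUND_NEAREST)
            y[i].qs[j] = (int8_t)nearbyintf(id * xb[j]);
        }
        for (int j = 0; j < QK_K/16; ++j) {      // partial sums over groups of 16 quants
            int sum = 0;
            for (int l = 0; l < 16; ++l) sum += y[i].qs[16*j + l];
            y[i].bsums[j] = (int16_t)sum;
        }
    }
}

The AVX2 path in the diff produces the same fields: the first loop finds max|x| with _mm256_andnot_ps(signBit, v) and a horizontal max reduction, and the second loop scales by id, rounds, converts to int32, accumulates bsums with hsum_i32_8, and packs four 8-float vectors down to 32 int8 values per iteration.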

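One non-obvious step in the inner loop is the final cross-lane shuffle: _mm256_packs_epi32 and _mm256_packs_epi16 pack within each 128-bit lane, so after the two packs the eight 4-byte groups of the register hold the source groups in the order 0,2,4,6,1,3,5,7. Permuting the dwords with perm = (0,4,1,5,2,6,3,7) via _mm256_permutevar8x32_epi32 restores source order before the store. The stand-alone demonstration below is not part of the commit (compile with -mavx2); it packs the integers 0..31 and prints the byte order with and without the permute.

// Hedged demonstration of why perm = {0,4,1,5,2,6,3,7} is needed after the two packs.
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    __m256i i0 = _mm256_setr_epi32( 0,  1,  2,  3,  4,  5,  6,  7);
    __m256i i1 = _mm256_setr_epi32( 8,  9, 10, 11, 12, 13, 14, 15);
    __m256i i2 = _mm256_setr_epi32(16, 17, 18, 19, 20, 21, 22, 23);
    __m256i i3 = _mm256_setr_epi32(24, 25, 26, 27, 28, 29, 30, 31);

    // Same pack sequence as the diff: 32-bit -> 16-bit -> 8-bit, per 128-bit lane.
    __m256i p = _mm256_packs_epi16(_mm256_packs_epi32(i0, i1),
                                   _mm256_packs_epi32(i2, i3));

    int8_t before[32], after[32];
    _mm256_storeu_si256((__m256i *)before, p);

    const __m256i perm = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
    _mm256_storeu_si256((__m256i *)after, _mm256_permutevar8x32_epi32(p, perm));

    for (int j = 0; j < 32; ++j) printf("%d ", before[j]);  // 0..3 8..11 16..19 24..27 4..7 ...
    printf("\n");
    for (int j = 0; j < 32; ++j) printf("%d ", after[j]);   // 0 1 2 ... 31, i.e. source order
    printf("\n");
    return 0;
}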