@@ -3341,8 +3341,57 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
33413341 }
33423342}
33433343
3344- void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
3345- quantize_row_q8_K_reference(x, y, k);
3344+ void quantize_row_q8_K(const float * restrict x, void * restrict vy, int64_t k) {
3345+ #ifdef __AVX2__
3346+ assert(k % QK_K == 0);
3347+ const int nb = k / QK_K;
3348+ block_q8_K * y = vy;
3349+ const __m256 signBit = _mm256_set1_ps( -0.0f );
3350+ const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
3351+ for (int i = 0; i < nb; i++) {
3352+ const float * xb = x + i*QK_K;
3353+ __m256 maxAbs = _mm256_setzero_ps();
3354+ const float * xx = xb;
3355+ for (int ib = 0; ib < QK_K/8; ++ib) {
3356+ const __m256 v = _mm256_loadu_ps(xx); xx += 8;
3357+ maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps(signBit, v));
3358+ }
3359+ __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
3360+ max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
3361+ max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
3362+ const float maxScalar = _mm_cvtss_f32( max4 );
3363+ const float d = maxScalar / 127.f;
3364+ y[i].d = d;
3365+ const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f;
3366+ const __m256 mul = _mm256_set1_ps( id );
3367+ xx = xb;
3368+ int8_t * q8 = y[i].qs;
3369+ for (int ib = 0; ib < QK_K/32; ++ib) {
3370+ __m256 v0 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
3371+ __m256 v1 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
3372+ __m256 v2 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
3373+ __m256 v3 = _mm256_mul_ps(mul, _mm256_loadu_ps(xx)); xx += 8;
3374+ v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
3375+ v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
3376+ v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
3377+ v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
3378+ __m256i i0 = _mm256_cvtps_epi32( v0 );
3379+ __m256i i1 = _mm256_cvtps_epi32( v1 );
3380+ __m256i i2 = _mm256_cvtps_epi32( v2 );
3381+ __m256i i3 = _mm256_cvtps_epi32( v3 );
3382+ y[i].bsums[2*ib+0] = hsum_i32_8(_mm256_add_epi32(i0, i1));
3383+ y[i].bsums[2*ib+1] = hsum_i32_8(_mm256_add_epi32(i2, i3));
3384+ i0 = _mm256_packs_epi32( i0, i1 );
3385+ i2 = _mm256_packs_epi32( i2, i3 );
3386+ i0 = _mm256_packs_epi16( i0, i2 );
3387+ i0 = _mm256_permutevar8x32_epi32( i0, perm );
3388+ _mm256_storeu_si256((__m256i *)q8, i0);
3389+ q8 += 32;
3390+ }
3391+ }
3392+ #else
3393+ quantize_row_q8_K_reference(x, vy, k);
3394+ #endif
33463395}
33473396
33483397//===================================== Dot ptoducts =================================
0 commit comments