@@ -771,6 +771,40 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
 
+#if defined(__AVX2__) && QK % 32 == 0
+    for (int i = 0; i < nb; i++) {
+        // scale factor
+        const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
+
+        const uint8_t * restrict pp = pb + i*bs;
+
+        for (int l = 0; l < QK; l += 32) {
+            // Load 32x4-bit integers into 32x8-bit integers
+            __m256i vx8 = bytesFromNibbles(pp + l/2);
+
+            // Subtract 8 from the integers
+            vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
+
+            // Convert to 16-bit int
+            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
+            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
+
+            // Convert to 32-bit int -> float 32
+            const __m256 vf[4] = {
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
+                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
+            };
+
+            // Scale and store
+            for (int j = 0; j < 4; j++) {
+                __m256 result = _mm256_mul_ps(vf[j], d_v);
+                _mm256_storeu_ps(y + i*QK + l + j*8, result);
+            }
+        }
+    }
+#else
     // scalar
     for (int i = 0; i < nb; i++) {
         const float d = *(const float *) (pd + i*bs);
@@ -795,6 +829,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
             assert(!isnan(y[i*QK + l + 1]));
         }
     }
+#endif
 }
 
 void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {
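
A note on the helper: the new AVX2 path calls bytesFromNibbles(), whose definition is not part of this hunk. What the call site requires is a routine that expands 16 bytes of packed 4-bit values into 32 unsigned 8-bit lanes, with the low nibble of each input byte landing before the high nibble, matching the scalar path's `vi & 0xf` / `vi >> 4` ordering. A minimal sketch of such a helper, under that assumption (the implementation below is illustrative, not necessarily the one in the tree):

#include <immintrin.h>
#include <stdint.h>

// Sketch only: expand 32 packed 4-bit values (16 bytes) into 32 uint8 lanes.
static inline __m256i bytesFromNibbles(const uint8_t * rsi) {
    // Load 16 bytes = 32 nibbles
    const __m128i tmp = _mm_loadu_si128((const __m128i *) rsi);

    // Zero-extend each byte into a 16-bit lane; high half of each lane is zero
    const __m256i bytes = _mm256_cvtepu8_epi16(tmp);

    // Keep the low nibble in the even byte of each 16-bit lane
    const __m256i low_mask = _mm256_set1_epi8(0xF);
    const __m256i lo = _mm256_and_si256(low_mask, bytes);

    // Shift the high nibble into the odd byte of each 16-bit lane
    const __m256i hi = _mm256_slli_epi16(_mm256_andnot_si256(low_mask, bytes), 4);

    // Result: byte 2k = low nibble of input byte k, byte 2k+1 = high nibble
    return _mm256_or_si256(lo, hi);
}

With that expansion in place, the vector path mirrors the scalar arithmetic step for step: the _mm256_sub_epi8 brings each nibble into [-8, 7] just like `(vi0 - 8)`, and the sign-extending _mm256_cvtepi8_epi16 / _mm256_cvtepi16_epi32 chain preserves those negative values before the convert-to-float and multiply by the block scale.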