@@ -771,6 +771,40 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
771
771
const uint8_t * restrict pd = ((const uint8_t * )x + 0 * bs );
772
772
const uint8_t * restrict pb = ((const uint8_t * )x + 0 * bs + sizeof (float ));
773
773
774
+ #if defined(__AVX2__ ) && QK % 32 == 0
775
+ for (int i = 0 ; i < nb ; i ++ ) {
776
+ // scale factor
777
+ const __m256 d_v = _mm256_broadcast_ss ((const float * ) (pd + i * bs ));
778
+
779
+ const uint8_t * restrict pp = pb + i * bs ;
780
+
781
+ for (int l = 0 ; l < QK ; l += 32 ) {
782
+ // Load 32x4-bit integers into 32x8-bit integers
783
+ __m256i vx8 = bytesFromNibbles (pp + l /2 );
784
+
785
+ // Subtract 8 from the integers
786
+ vx8 = _mm256_sub_epi8 (vx8 , _mm256_set1_epi8 (8 ));
787
+
788
+ // Convert to 16-bit int
789
+ const __m256i vx16_lo = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (vx8 , 0 ));
790
+ const __m256i vx16_hi = _mm256_cvtepi8_epi16 (_mm256_extracti128_si256 (vx8 , 1 ));
791
+
792
+ // Convert to 32-bit int -> float 32
793
+ const __m256 vf [4 ] = {
794
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_lo , 0 ))),
795
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_lo , 1 ))),
796
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_hi , 0 ))),
797
+ _mm256_cvtepi32_ps (_mm256_cvtepi16_epi32 (_mm256_extracti128_si256 (vx16_hi , 1 )))
798
+ };
799
+
800
+ // Scale and store
801
+ for (int j = 0 ; j < 4 ; j ++ ) {
802
+ __m256 result = _mm256_mul_ps (vf [j ], d_v );
803
+ _mm256_storeu_ps (y + i * QK + l + j * 8 , result );
804
+ }
805
+ }
806
+ }
807
+ #else
774
808
// scalar
775
809
for (int i = 0 ; i < nb ; i ++ ) {
776
810
const float d = * (const float * ) (pd + i * bs );
@@ -795,6 +829,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
795
829
assert (!isnan (y [i * QK + l + 1 ]));
796
830
}
797
831
}
832
+ #endif
798
833
}
799
834
800
835
void dequantize_row_q4_1 (const void * restrict x , float * restrict y , int k ) {
0 commit comments