@@ -755,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
     const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
     const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
 
-#if defined(__AVX2__) && QK % 32 == 0
+#if defined(__AVX2__)
     for (int i = 0; i < nb; i++) {
         // scale factor
         const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
@@ -788,7 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
             }
         }
     }
-//#elif defined(__ARM_NEON)
+#elif defined(__ARM_NEON)
+    for (int i = 0; i < nb; i++) {
+        const float d = *(const float *) (pd + i*bs);
+
+        const uint8_t * restrict pp = pb + i*bs;
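+        // layout implied by pd/pb above: each bs-byte block is the float scale d
+        // followed by QK/2 bytes of 4-bit quants, two values per byte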
+
+        const float32x4_t vd = vdupq_n_f32(d);
+
+        for (int l = 0; l < QK; l += 16) {
+            // Load 16x4-bit integers into 8x8-bit integers
+            const uint8x8_t v8 = vld1_u8(pp + l/2);
+
+            // Expand 4-bit nibbles to 8-bit bytes
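+            // (low nibbles hold the even-indexed values, high nibbles the odd ones)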
+            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
+            const uint8x8_t v1 = vshr_n_u8(v8, 4);
+
+            // Convert to signed 8-bit integers
+            const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
+            const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
+
+            // Subtract 8 from each byte
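+            // (quants are stored with a +8 offset, so [0, 15] maps back to [-8, 7])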
+            const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
+            const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
+
+            // Interleave and combine
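+            // (zipping the even/odd lanes restores the original element order)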
+            const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
+            const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
+
+            const int8x16_t vq = vcombine_s8(vx_0, vx_1);
+
+            // convert to 2x int16x8_t
+            const int16x8_t vi_0 = vmovl_s8(vget_low_s8(vq));
+            const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
+
+            // convert to 4x float32x4_t
+            const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_0)));
+            const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
+            const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vi_1)));
+            const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
+
+            // Multiply by d
+            const float32x4_t r0 = vmulq_f32(vf_0, vd);
+            const float32x4_t r1 = vmulq_f32(vf_1, vd);
+            const float32x4_t r2 = vmulq_f32(vf_2, vd);
+            const float32x4_t r3 = vmulq_f32(vf_3, vd);
+
+            // Store
+            vst1q_f32(y + i*QK + l + 0, r0);
+            vst1q_f32(y + i*QK + l + 4, r1);
+            vst1q_f32(y + i*QK + l + 8, r2);
+            vst1q_f32(y + i*QK + l + 12, r3);
+        }
+    }
 #else
     // scalar
     for (int i = 0; i < nb; i++) {
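For reference, here is a scalar sketch of what one iteration of the new NEON `i` loop computes, assuming the Q4_0 block layout implied by `pd` and `pb` above (a float scale `d` followed by QK/2 bytes of packed nibbles); the helper name and the `QK` value here are illustrative, not taken from the patch:

```c
#include <stdint.h>
#include <string.h>

#define QK 32  // assumed block size for illustration

// Hypothetical scalar equivalent of one NEON block iteration above.
static void dequantize_block_q4_0_ref(const uint8_t *block, float *y) {
    float d;
    memcpy(&d, block, sizeof(float));           // per-block scale factor
    const uint8_t *pp = block + sizeof(float);  // QK/2 bytes, two 4-bit quants each

    for (int l = 0; l < QK; l += 2) {
        const uint8_t vi = pp[l/2];
        y[l + 0] = ((int8_t)(vi & 0x0f) - 8) * d;  // low nibble -> even element
        y[l + 1] = ((int8_t)(vi >> 4) - 8) * d;    // high nibble -> odd element
    }
}
```

The NEON path produces the same QK floats per block, 16 at a time; the `vzip1_s8`/`vzip2_s8` step reproduces the even/odd interleaving that this scalar version gets for free from its `l`/`l+1` stores.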