Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit face808

Browse files
authored Mar 25, 2023
SIMD-ify dequantize_row_q4_0() for ARM_NEON (#502)
* Attempt to SIMD-ify dequantize_row_q4_0() for ARM_NEON
* Fix dequantization - forgot to interleave the quants
1 parent 1e39d2b commit face808

File tree

1 file changed

+54
-2
lines changed

1 file changed

+54
-2
lines changed
 

‎ggml.c

+54 -2
Original file line number | Diff line number | Diff line change
@@ -755,7 +755,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
755755
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
756756
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
757757

758-
#if defined(__AVX2__) && QK % 32 == 0
758+
#if defined(__AVX2__)
759759
for (int i = 0; i < nb; i++) {
760760
// scale factor
761761
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
@@ -788,7 +788,59 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
788788
}
789789
}
790790
}
791-
//#elif defined(__ARM_NEON)
791+
#elif defined(__ARM_NEON)
792+
for (int i = 0; i < nb; i++) {
793+
const float d = *(const float *) (pd + i*bs);
794+
795+
const uint8_t * restrict pp = pb + i*bs;
796+
797+
const float32x4_t vd = vdupq_n_f32(d);
798+
799+
for (int l = 0; l < QK; l += 16) {
800+
// Load 16x4-bit integers into 8x8-bit integers
801+
const uint8x8_t v8 = vld1_u8(pp + l/2);
802+
803+
// Expand 4-bit nibbles to 8-bit bytes
804+
const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0f));
805+
const uint8x8_t v1 = vshr_n_u8(v8, 4);
806+
807+
// Convert to signed 8-bit integers
808+
const int8x8_t vs_0 = vreinterpret_s8_u8(v0);
809+
const int8x8_t vs_1 = vreinterpret_s8_u8(v1);
810+
811+
// Subtract 8 from each byte
812+
const int8x8_t vb_0 = vsub_s8(vs_0, vdup_n_s8(8));
813+
const int8x8_t vb_1 = vsub_s8(vs_1, vdup_n_s8(8));
814+
815+
// Interleave and combine
816+
const int8x8_t vx_0 = vzip1_s8(vb_0, vb_1);
817+
const int8x8_t vx_1 = vzip2_s8(vb_0, vb_1);
818+
819+
const int8x16_t vq = vcombine_s8(vx_0, vx_1);
820+
821+
// convert to 2x int16x8_t
822+
const int16x8_t vi_0 = vmovl_s8(vget_low_s8 (vq));
823+
const int16x8_t vi_1 = vmovl_s8(vget_high_s8(vq));
824+
825+
// convert to 4x float32x4_t
826+
const float32x4_t vf_0 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_0)));
827+
const float32x4_t vf_1 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_0)));
828+
const float32x4_t vf_2 = vcvtq_f32_s32(vmovl_s16(vget_low_s16 (vi_1)));
829+
const float32x4_t vf_3 = vcvtq_f32_s32(vmovl_s16(vget_high_s16(vi_1)));
830+
831+
// Multiply by d
832+
const float32x4_t r0 = vmulq_f32(vf_0, vd);
833+
const float32x4_t r1 = vmulq_f32(vf_1, vd);
834+
const float32x4_t r2 = vmulq_f32(vf_2, vd);
835+
const float32x4_t r3 = vmulq_f32(vf_3, vd);
836+
837+
// Store
838+
vst1q_f32(y + i*QK + l + 0, r0);
839+
vst1q_f32(y + i*QK + l + 4, r1);
840+
vst1q_f32(y + i*QK + l + 8, r2);
841+
vst1q_f32(y + i*QK + l + 12, r3);
842+
}
843+
}
792844
#else
793845
// scalar
794846
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)
Please sign in to comment.