Skip to content

Commit 09aecbf

Browse files
authored
Add AVX2 implementation of dequantize_row_q4_0 (#467)
1 parent 4640eff commit 09aecbf

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

Diff for: ggml.c

+35
Original file line number | Diff line number | Diff line change
@@ -771,6 +771,40 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
771771
const uint8_t * restrict pd = ((const uint8_t *)x + 0*bs);
772772
const uint8_t * restrict pb = ((const uint8_t *)x + 0*bs + sizeof(float));
773773

774+
#if defined(__AVX2__) && QK % 32 == 0
775+
for (int i = 0; i < nb; i++) {
776+
// scale factor
777+
const __m256 d_v = _mm256_broadcast_ss((const float *) (pd + i*bs));
778+
779+
const uint8_t * restrict pp = pb + i*bs;
780+
781+
for (int l = 0; l < QK; l += 32) {
782+
// Load 32x4-bit integers into 32x8-bit integers
783+
__m256i vx8 = bytesFromNibbles(pp+l/2);
784+
785+
// Subtract 8 from the integers
786+
vx8 = _mm256_sub_epi8(vx8, _mm256_set1_epi8(8));
787+
788+
// Convert to 16-bit int
789+
const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
790+
const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
791+
792+
// Convert to 32-bit int -> float 32
793+
const __m256 vf[4] = {
794+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
795+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
796+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
797+
_mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
798+
};
799+
800+
// Scale and store
801+
for (int j = 0; j < 4; j++) {
802+
__m256 result = _mm256_mul_ps(vf[j], d_v);
803+
_mm256_storeu_ps(y + i * QK + l + j*8, result);
804+
}
805+
}
806+
}
807+
#else
774808
// scalar
775809
for (int i = 0; i < nb; i++) {
776810
const float d = *(const float *) (pd + i*bs);
@@ -795,6 +829,7 @@ void dequantize_row_q4_0(const void * restrict x, float * restrict y, int k) {
795829
assert(!isnan(y[i*QK + l + 1]));
796830
}
797831
}
832+
#endif
798833
}
799834

800835
void dequantize_row_q4_1(const void * restrict x, float * restrict y, int k) {

0 commit comments

Comments
 (0)