
Commit 948d124

AVX implementations (#1370)
1 parent d155f0f commit 948d124

2 files changed, +33 -65 lines changed

SHA256SUMS  +4 -12

@@ -1,24 +1,19 @@
 700df0d3013b703a806d2ae7f1bfb8e59814e3d06ae78be0c66368a50059f33d models/7B/consolidated.00.pth
 666a4bb533b303bdaf89e1b6a3b6f93535d868de31d903afdc20983dc526c847 models/7B/ggml-model-f16.bin
-99aeb35f26b577fa2732716cca4d8b5ada39a78ea9b2dca2651fc632b5d101b6 models/7B/ggml-model-q4_0.bin
-cc061458339a3eb8bcecbf0a825e9924fb7d1a8150f63cd5d091caa99215aafe models/7B/ggml-model-q4_1.bin
-25b050337a87344da687a7f2adddc03bd99b7f6c140450e836649f3585fb6496 models/7B/ggml-model-q4_2.bin
+ae89af479ab4d31c4e555ad8cc1dc9bf1f68d617186158cc381cd5a0fccd10bd models/7B/ggml-model-q4_0.bin
+862072e2036a1bdb1a01ec2e159381f332a9e2357b886031c075fb7efa86db9b models/7B/ggml-model-q4_1.bin
+0bef7cefa880a67a0b6d2a7e4559ded235823535ad616808dd8b5e47ff0a202f models/7B/ggml-model-q5_0.bin
+97b9c38b2b8aed0c0aa90e0a975570ce3455c47d62128b382c55acbf6e2035f6 models/7B/ggml-model-q5_1.bin
 7e89e242ddc0dd6f060b43ca219ce8b3e8f08959a72cb3c0855df8bb04d46265 models/7B/params.json
 745bf4e29a4dd6f411e72976d92b452da1b49168a4f41c951cfcc8051823cf08 models/13B/consolidated.00.pth
 d5ccbcc465c71c0de439a5aeffebe8344c68a519bce70bc7f9f92654ee567085 models/13B/consolidated.01.pth
 2b206e9b21fb1076f11cafc624e2af97c9e48ea09312a0962153acc20d45f808 models/13B/ggml-model-f16.bin
-eecb575d325d935157761172e2bf05984dad216eb2b06777b73463cf9b818bab models/13B/ggml-model-q4_0.bin
-d9581b5b88e5622532fe897c9f9b0e67a317d22dd27a6f90fa4ab8c6d23ccdbb models/13B/ggml-model-q4_1.bin
-75a218a47df03f5f96354656329864613abcb67779412b9bc2282b28c1c3cbaa models/13B/ggml-model-q4_2.bin
 4ab77bec4d4405ccb66a97b282574c89a94417e3c32e5f68f37e2876fc21322f models/13B/params.json
 e23294a58552d8cdec5b7e8abb87993b97ea6eced4178ff2697c02472539d067 models/30B/consolidated.00.pth
 4e077b7136c7ae2302e954860cf64930458d3076fcde9443f4d0e939e95903ff models/30B/consolidated.01.pth
 24a87f01028cbd3a12de551dcedb712346c0b5cbdeff1454e0ddf2df9b675378 models/30B/consolidated.02.pth
 1adfcef71420886119544949767f6a56cb6339b4d5fcde755d80fe68b49de93b models/30B/consolidated.03.pth
 7e1b524061a9f4b27c22a12d6d2a5bf13b8ebbea73e99f218809351ed9cf7d37 models/30B/ggml-model-f16.bin
-517b9e525742c42b5478a6280a4b41ec66f46298c57aba7f0453d491682fe42d models/30B/ggml-model-q4_0.bin
-7b75ac615fa369ee593493a7e6ef87542bf0350255db928b22c5a24f6d598bcd models/30B/ggml-model-q4_1.bin
-aadbc9cf806313a55be570f62884eed289d30c313fac3b7838717e01bd553204 models/30B/ggml-model-q4_2.bin
 2c07118ea98d69dbe7810d88520e30288fa994751b337f8fca02b171955f44cb models/30B/params.json
 135c563f6b3938114458183afb01adc9a63bef3d8ff7cccc3977e5d3664ecafe models/65B/consolidated.00.pth
 9a600b37b19d38c7e43809485f70d17d1dc12206c07efa83bc72bb498a568bde models/65B/consolidated.01.pth
@@ -29,8 +24,5 @@ a287c0dfe49081626567c7fe87f74cce5831f58e459b427b5e05567641f47b78 models/65B/con
 72b4eba67a1a3b18cb67a85b70f8f1640caae9b40033ea943fb166bd80a7b36b models/65B/consolidated.06.pth
 d27f5b0677d7ff129ceacd73fd461c4d06910ad7787cf217b249948c3f3bc638 models/65B/consolidated.07.pth
 60758f2384d74e423dffddfd020ffed9d3bb186ebc54506f9c4a787d0f5367b0 models/65B/ggml-model-f16.bin
-01672072136f8be6ca9d7cebe5f86ed316e8b85851b9fe3de951809233cea4f2 models/65B/ggml-model-q4_0.bin
-4743a28aac3e5f32a6e838a815f51d3779de44fbbe251d745251e66c23c5950f models/65B/ggml-model-q4_1.bin
-1b6f6588d0e2ecfe6c4d849088e48e5e3083466b962daa32e3261363e21fc5e9 models/65B/ggml-model-q4_2.bin
 999ed1659b469ccc2a941714c0a9656fa571d17c9f7c8c7589817ca90edef51b models/65B/params.json
 9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347 models/tokenizer.model

ggml.c  +29 -53

@@ -472,23 +472,16 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
 //
 
 #if __AVX__ || __AVX2__ || __AVX512F__
-// Unpack 16 4-bit fields into 16 bytes
-// The output vector contains 16 bytes, each one in [ 0 .. 15 ] interval
-static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
-{
-    // Load 8 bytes from memory
-    __m128i tmp = _mm_loadl_epi64( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m128i bytes = _mm_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
-    const __m128i lowMask = _mm_set1_epi8( 0xF );
-    __m128i high = _mm_andnot_si128( lowMask, bytes );
-    __m128i low = _mm_and_si128( lowMask, bytes );
-    high = _mm_slli_epi16( high, 4 );
-    bytes = _mm_or_si128( low, high );
-    return bytes;
+// multiply int8_t, add results pairwise twice
+static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
+    // Get absolute values of x vectors
+    const __m128i ax = _mm_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m128i sy = _mm_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m128i dot = _mm_maddubs_epi16(ax, sy);
+    const __m128i ones = _mm_set1_epi16(1);
+    return _mm_madd_epi16(ones, dot);
 }
 
 // horizontally add 8 floats
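
The new mul_sum_i8_pairs helper works around _mm_maddubs_epi16 taking one unsigned and one signed operand: the sign of x is folded onto y with _mm_sign_epi8, the multiply then runs on |x| and the sign-adjusted y with pairwise 16-bit sums, and _mm_madd_epi16 against a vector of ones collapses those sums into four int32 lanes. Below is a minimal standalone sketch, not part of the commit, that checks the helper against a plain scalar dot product per 4-byte lane; the test values and the gcc -mssse3 build line are illustrative.

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

// copied from the commit above
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
    const __m128i ax = _mm_sign_epi8(x, x);
    const __m128i sy = _mm_sign_epi8(y, x);
    const __m128i dot = _mm_maddubs_epi16(ax, sy);
    const __m128i ones = _mm_set1_epi16(1);
    return _mm_madd_epi16(ones, dot);
}

int main(void) {
    int8_t x[16], y[16];
    for (int i = 0; i < 16; ++i) { x[i] = (int8_t)(i*7 - 60); y[i] = (int8_t)(90 - i*11); }

    int32_t out[4];
    _mm_storeu_si128((__m128i *)out,
        mul_sum_i8_pairs(_mm_loadu_si128((const __m128i *)x),
                         _mm_loadu_si128((const __m128i *)y)));

    // each 32-bit lane should equal the signed dot product of its four byte pairs
    for (int lane = 0; lane < 4; ++lane) {
        int32_t ref = 0;
        for (int k = 0; k < 4; ++k) ref += (int32_t)x[4*lane + k] * y[4*lane + k];
        printf("lane %d: simd=%d ref=%d\n", lane, out[lane], ref);
    }
    return 0;
}
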
@@ -535,19 +528,10 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
 static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
 {
-    // Load 16 bytes from memory
-    __m128i tmp = _mm_loadu_si128( ( const __m128i* )rsi );
-
-    // Expand bytes into uint16_t values
-    __m256i bytes = _mm256_cvtepu8_epi16( tmp );
-
-    // Unpack values into individual bytes
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
     const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    __m256i high = _mm256_andnot_si256( lowMask, bytes );
-    __m256i low = _mm256_and_si256( lowMask, bytes );
-    high = _mm256_slli_epi16( high, 4 );
-    bytes = _mm256_or_si256( low, high );
-    return bytes;
+    return _mm256_and_si256(lowMask, bytes);
 }
 
 // add int16_t pairwise and return as float vector
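
bytes_from_nibbles_32 now expands 16 packed bytes without the cvtepu8/shift/or dance: the 256-bit vector gets the raw bytes in its low half and the same bytes shifted right by 4 in its high half, and a single AND with 0x0F keeps just the nibbles (the 16-bit-lane shift drags bits of the neighbouring byte into positions 4..7, which the mask clears). A minimal standalone check, not part of the commit, assuming a compiler that provides _mm256_set_m128i and an AVX2 build such as gcc -mavx2; the test values are arbitrary.

#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>

// copied from the commit above
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
{
    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
    const __m256i bytes = _mm256_set_m128i(_mm_srli_epi16(tmp, 4), tmp);
    const __m256i lowMask = _mm256_set1_epi8( 0xF );
    return _mm256_and_si256(lowMask, bytes);
}

int main(void) {
    uint8_t packed[16];
    for (int i = 0; i < 16; ++i) packed[i] = (uint8_t)(i*37 + 11);

    uint8_t out[32];
    _mm256_storeu_si256((__m256i *)out, bytes_from_nibbles_32(packed));

    int ok = 1;
    for (int i = 0; i < 16; ++i) {
        ok &= (out[i]      == (packed[i] & 0x0F));  // low nibbles land in lanes 0..15
        ok &= (out[i + 16] == (packed[i] >>    4)); // high nibbles land in lanes 16..31
    }
    printf("%s\n", ok ? "nibbles match" : "MISMATCH");
    return 0;
}
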
@@ -2146,31 +2130,23 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
         // Compute combined scale for the block
         const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
 
-        __m128i i32[2];
-        for (int j = 0; j < 2; ++j) {
-            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytes_from_nibbles_16(x[i].qs + 8*j);
-            __m128i by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16*j));
-
-            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-            const __m128i off = _mm_set1_epi8( 8 );
-            bx = _mm_sub_epi8( bx, off );
+        const __m128i lowMask = _mm_set1_epi8(0xF);
+        const __m128i off = _mm_set1_epi8(8);
 
-            // Get absolute values of x vectors
-            const __m128i ax = _mm_sign_epi8(bx, bx);
+        const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
 
-            // Sign the values of the y vectors
-            const __m128i sy = _mm_sign_epi8(by, bx);
+        __m128i bx = _mm_and_si128(lowMask, tmp);
+        __m128i by = _mm_loadu_si128((const __m128i *)y[i].qs);
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_0 = mul_sum_i8_pairs(bx, by);
 
-            // Perform multiplication and create 16-bit values
-            const __m128i dot = _mm_maddubs_epi16(ax, sy);
-
-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
-        }
+        bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
+        by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
+        bx = _mm_sub_epi8(bx, off);
+        const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
 
         // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
+        __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i(i32_0, i32_1));
         // Apply the scale, and accumulate
         acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
     }
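
With bytes_from_nibbles_16 gone, this block loads all 16 packed q4_0 bytes once, pairs the low nibbles (minus the offset 8) with the first 16 q8_0 values and the shifted high nibbles with the last 16, and hands both halves to mul_sum_i8_pairs before scaling by d. As a plain-C reference for what one block iteration computes, here is a sketch that is not from the commit; the helper name, the raw-array signature and the test values are illustrative.

#include <stdio.h>
#include <stdint.h>

// One q4_0 x q8_0 block pair: qs holds 16 packed bytes (32 4-bit weights),
// ys holds 32 int8 quants, dx/dy are the two per-block scales.
static float q4_0_q8_0_block_dot_ref(const uint8_t qs[16], const int8_t ys[32],
                                     float dx, float dy) {
    int32_t sumi = 0;
    for (int j = 0; j < 16; ++j) {
        const int v0 = (qs[j] & 0x0F) - 8;   // low nibble,  paired with ys[j]
        const int v1 = (qs[j] >>   4) - 8;   // high nibble, paired with ys[j + 16]
        sumi += v0 * ys[j] + v1 * ys[j + 16];
    }
    return dx * dy * (float)sumi;            // same role as the 'd * p' accumulation above
}

int main(void) {
    uint8_t qs[16]; int8_t ys[32];
    for (int j = 0; j < 16; ++j) qs[j] = (uint8_t)(j * 17);
    for (int j = 0; j < 32; ++j) ys[j] = (int8_t)(j - 16);
    printf("%f\n", q4_0_q8_0_block_dot_ref(qs, ys, 0.01f, 0.02f));
    return 0;
}
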
@@ -2484,8 +2460,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
         int sumi = 0;
 
         for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+            const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 
             const int32_t x0 = ((x[i].qs[j] & 0x0F) | xh_0) - 16;
             const int32_t x1 = ((x[i].qs[j] >>   4) | xh_1) - 16;
@@ -2673,8 +2649,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
         int sumi = 0;
 
         for (int j = 0; j < qk/2; ++j) {
-            const uint8_t xh_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
-            const uint8_t xh_1 = ((qh & (1u << (j + 16))) >> (j + 12));
+            const uint8_t xh_0 = ((qh >> (j + 0)) << 4) & 0x10;
+            const uint8_t xh_1 = ((qh >> (j + 12))     ) & 0x10;
 
             const int32_t x0 = (x[i].qs[j] & 0xF) | xh_0;
             const int32_t x1 = (x[i].qs[j] >>  4) | xh_1;
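
In both scalar q5 paths the spilled 5th bit is now recovered with a shift followed by a mask against 0x10 instead of isolating the bit first; both the old and the new expressions evaluate to either 0x00 or 0x10. A tiny standalone check, not part of the commit, comparing the two forms over every j and a sweep of qh patterns; the pattern-generator constant is arbitrary.

#include <stdio.h>
#include <stdint.h>

int main(void) {
    int ok = 1;
    for (uint32_t step = 0; step < 100000; ++step) {
        const uint32_t qh = step * 2654435761u;  // cheap pseudo-random bit patterns
        for (int j = 0; j < 16; ++j) {           // qk/2 == 16 for these block types
            const uint8_t old_0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4;
            const uint8_t old_1 = ((qh & (1u << (j + 16))) >> (j + 12));
            const uint8_t new_0 = ((qh >> (j + 0)) << 4) & 0x10;
            const uint8_t new_1 = ((qh >> (j + 12))     ) & 0x10;
            ok &= (old_0 == new_0) && (old_1 == new_1);
        }
    }
    printf("%s\n", ok ? "old and new forms agree" : "MISMATCH");
    return 0;
}
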
