Commit 1884411

ikawrakow (Iwan Kawrakow) authored and committed
Bitnet CUDA improvements (#109)
* iq1_bn: improve CUDA TG

  On RTX-3080, TG-128(Bitnet-1.58b-3B) goes from 318 t/s to 340 t/s. The front page still lists 301 t/s, so this is a nice improvement since then.

* iq2_bn (CUDA): quants are not 4-byte aligned

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
1 parent 704d311 · commit 1884411

1 file changed: +21 −15

ggml/src/ggml-cuda/iqk_mmvq.cu

Lines changed: 21 additions & 15 deletions
@@ -633,31 +633,37 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
     float scale = d16;
     const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx;
 
-    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
     // iqs is 0 or 1
 
     int sumi = 0;
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    uint16_t mult[2];
+    mult[1] = iqs == 0 ? 27 : 3;
+    mult[0] = mult[1] + (mult[1] << 1);
     const int * q8 = (const int *)bq8_1[iqs].qs;
     int val[4];
     for (int l = 0; l < 2; ++l) {
         int8_t * a = (int8_t *)val;
         const int i16 = 2*iqs + l;
         for (int k = 0; k < 3; ++k) {
-            uint8_t q = bq1->ql[3*i16+k];
-            for (int j = 0; j < 5; ++j) {
-                uint8_t v = k_mult[j]*q;
-                int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
-                *a++ = vs-1;
+            uint16_t q = bq1->ql[3*i16+k];
+            for (int j = 4; j >= 0; --j) {
+                uint16_t v = q & 0xff;
+                v += v << 1;
+                a[j] = v >> 8;
+                q += q << 1;
             }
+            a += 5;
         }
-        uint8_t v = k_mult[i16]*bq1->extra;
-        int8_t vs = 3*v >> 8; //(v + (v >> 1)) >> 7;
-        *a++ = vs-1;
+        uint16_t v = (mult[l]*bq1->extra) & 0xff;
+        v += v << 1;
+        *a = v >> 8;
         sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
     }
+    float2 d8 = __half22float2(bq8_1[iqs].ds);
+    return scale * (d8.x * sumi - d8.y);
 #else
+    static const uint16_t k_mult[5] = {81, 27, 9, 3, 1};
     const int8_t * q8 = bq8_1[iqs].qs;
     for (int l = 0; l < 2; ++l) {
         const int i16 = 2*iqs + l;
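A note on what this hunk does: the old code pulled each of the five ternary digits out of a packed byte with a per-digit multiplier from the `k_mult` table; the new code drops the table and instead triples a running value each step (`q += q << 1`), reading the next digit off the top of the low byte. It also leaves the digits as {0, 1, 2} rather than subtracting 1 per digit; the offset is paid once in the new return via `d8.y` (see the note after the next hunk). A minimal CPU sketch, not part of the commit, checking that the two extraction schemes produce identical digits for every packable value (five ternary digits per byte, so q in [0, 243)):

```c++
#include <cstdint>
#include <cstdio>

int main() {
    static const uint16_t k_mult[5] = {81, 27, 9, 3, 1};
    for (int q0 = 0; q0 < 243; ++q0) {
        uint8_t old_digits[5], new_digits[5];
        // Old scheme: one table lookup and multiply per digit.
        for (int j = 0; j < 5; ++j) {
            uint8_t v = k_mult[j]*q0;   // wraps mod 256
            old_digits[j] = (3*v) >> 8; // digit in {0, 1, 2}
        }
        // New scheme: no table; triple the running value each step and
        // read the digit out of bit 8 and up of 3*(q & 0xff).
        uint16_t q = q0;
        for (int j = 4; j >= 0; --j) {
            uint16_t v = q & 0xff;
            v += v << 1;                // v *= 3
            new_digits[j] = v >> 8;     // digit in {0, 1, 2}
            q += q << 1;                // q *= 3, stays below 2^16 here
        }
        for (int j = 0; j < 5; ++j) {
            if (old_digits[j] != new_digits[j]) {
                printf("mismatch at q=%d, j=%d\n", q0, j);
                return 1;
            }
        }
    }
    printf("old and new decodes agree for all q in [0, 243)\n");
    return 0;
}
```

The equivalence holds because `(k_mult[j]*q0) mod 256` equals `(3^(4-j) * q0) mod 256`, which is exactly what the repeated tripling computes (81*242 = 19602, so the running value never overflows the uint16_t).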
@@ -675,8 +681,8 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
         sumi += q8[0]*(vs - 1);
         q8++;
     }
-#endif
     return scale * __low2float(bq8_1[iqs].ds) * sumi;
+#endif
 }
 
 __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
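The other notable change in this function: the DP4A branch now has its own return, `scale * (d8.x * sumi - d8.y)`, instead of sharing the `__low2float(...) * sumi` return that used to sit below `#endif`. With digits kept at {0, 1, 2}, the identity sum((v-1)*q8) = sum(v*q8) - sum(q8) applies, and, assuming the usual q8_1 convention where `ds` packs {d, d*sum(qs)}, `d8.y` supplies exactly the `d * sum(q8)` being subtracted. A small sketch of that identity with made-up values:

```c++
#include <cstdio>

int main() {
    const int v[8]  = {0, 2, 1, 2, 0, 0, 1, 2};    // stored digits in {0, 1, 2}
    const int q8[8] = {-5, 7, 3, -2, 9, 1, -4, 6}; // int8 activations
    const float d   = 0.125f;                      // q8 scale

    // Reference: apply the -1 offset to every digit before the dot product.
    float ref = 0;
    for (int i = 0; i < 8; ++i) ref += d * (v[i] - 1) * q8[i];

    // Kernel-style: dot the raw digits, subtract d*sum(q8) once at the end,
    // which is what `d8.x * sumi - d8.y` does in the DP4A branch.
    int sumi = 0, sum8 = 0;
    for (int i = 0; i < 8; ++i) { sumi += v[i]*q8[i]; sum8 += q8[i]; }
    float fused = d * sumi - d * sum8;

    printf("ref = %g, fused = %g\n", ref, fused); // prints the same number twice
    return 0;
}
```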
@@ -688,20 +694,21 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
     // iqs is 0 or 1
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    auto qs = (const uint16_t *)bq2->qs + 4*iqs;
+    auto qs = (const int *)bq2->qs + 2*iqs;
     auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
     auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
     int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
     for (int j = 0; j < 2; ++j) {
-        int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
-        int vh = vl >> 4;
+        int vl = qs[j];
+        int vh = qs[j] >> 4;
         sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
         sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
         sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
         sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
     }
     auto d8l = __half22float2(bq8_1[0].ds);
     auto d8h = __half22float2(bq8_1[1].ds);
+    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 #else
     int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
     auto q8l = bq8_1[0].qs + 8*iqs;
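On the iq2_bn side, the fix swaps two 16-bit loads stitched into a word for a single 32-bit load; the masked `__dp4a` decomposition around it is untouched. Each byte of `vl` carries one 2-bit quant in bits 0-1 and another in bits 2-3 (the high nibble feeds the second q8 block via `vh`). Masking with `0x03030303` isolates the first set in place; masking with `0x0c0c0c0c` keeps the second set where it sits, scaled by 4, which the return undoes with the `0.25f` factor. A CPU emulation of `__dp4a`, not CUDA code, with made-up inputs:

```c++
#include <cstdint>
#include <cstdio>

// Emulates CUDA's __dp4a: four signed 8-bit products accumulated into acc.
static int dp4a(uint32_t a, uint32_t b, int acc) {
    for (int i = 0; i < 4; ++i) {
        acc += int(int8_t(a >> 8*i)) * int(int8_t(b >> 8*i));
    }
    return acc;
}

int main() {
    uint32_t vl  = 0x06090b02u; // per byte: 2-bit quant in bits 0-1, another in bits 2-3
    uint32_t q8a = 0x03FB0702u; // int8x4 activations paired with the bits-0..1 quants
    uint32_t q8b = 0xF904FE05u; // int8x4 activations paired with the bits-2..3 quants

    int sumi1 = dp4a(vl & 0x03030303u, q8a, 0); // first set, in place
    int sumi2 = dp4a(vl & 0x0c0c0c0cu, q8b, 0); // second set, carries a factor of 4

    // Direct reference: unpack each 2-bit value and dot it by hand.
    int ref1 = 0, ref2 = 0;
    for (int i = 0; i < 4; ++i) {
        int byte = (vl >> 8*i) & 0xff;
        ref1 += (byte & 0x3)        * int(int8_t(q8a >> 8*i));
        ref2 += ((byte >> 2) & 0x3) * int(int8_t(q8b >> 8*i));
    }
    // sumi2 is 4x the true dot because the bits stayed in place; the kernel
    // folds the 0.25f correction into the final scale instead of shifting.
    printf("sumi1=%d ref1=%d  sumi2=%d 4*ref2=%d\n", sumi1, ref1, sumi2, 4*ref2);
    return 0;
}
```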
@@ -717,7 +724,6 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
     auto d8h = __half22float2(bq8_1[1].ds);
     return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x*(sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 #endif
-    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f * sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 }
 
 } // namespace
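The last hunk deletes the stray return that sat after `#endif`: it was the de facto DP4A return before (the scalar branch returned earlier), but dead code in scalar builds; the DP4A build now returns inside its own branch. Worth noting why the scalar return keeps a `0.0625f` on the high half that the DP4A return does not: the DP4A path shifts the high nibble down first (`vh = vl >> 4`), whereas the scalar path presumably masks it in place, leaving every high-half product 16x too large, so `0.0625f = 1/16` folds the correction into the scale. A quick check of that arithmetic, under that assumption:

```c++
#include <cstdio>

int main() {
    for (int byte = 0; byte < 256; ++byte) {
        int shifted  = (byte >> 4) & 0x3; // DP4A path: shift, then mask
        int in_place = byte & 0x30;       // assumed scalar path: mask without shifting
        if (in_place != 16*shifted) { printf("mismatch\n"); return 1; }
    }
    // Every in-place value is 16x the shifted one, so scaling the high-half
    // sum by 0.0625f restores the same result as the shifted path.
    printf("byte & 0x30 == 16*((byte >> 4) & 0x3) for all bytes\n");
    return 0;
}
```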
