@@ -633,31 +633,37 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
     float scale = d16;
     const block_iq1_bn * bq1 = (const block_iq1_bn *)((const char *)vbq + sizeof(d16)) + kbx;

-    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
-
     // iqs is 0 or 1

     int sumi = 0;
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    uint16_t mult[2];
+    mult[1] = iqs == 0 ? 27 : 3;
+    mult[0] = mult[1] + (mult[1] << 1);
     const int * q8 = (const int *)bq8_1[iqs].qs;
     int val[4];
     for (int l = 0; l < 2; ++l) {
         int8_t * a = (int8_t *)val;
         const int i16 = 2*iqs + l;
         for (int k = 0; k < 3; ++k) {
-            uint8_t q = bq1->ql[3*i16+k];
-            for (int j = 0; j < 5; ++j) {
-                uint8_t v = k_mult[j]*q;
-                int8_t vs = 3*v >> 8; // (v + (v >> 1)) >> 7;
-                *a++ = vs-1;
+            uint16_t q = bq1->ql[3*i16+k];
+            for (int j = 4; j >= 0; --j) {
+                uint16_t v = q & 0xff;
+                v += v << 1;
+                a[j] = v >> 8;
+                q += q << 1;
             }
+            a += 5;
         }
-        uint8_t v = k_mult[i16]*bq1->extra;
-        int8_t vs = 3*v >> 8; // (v + (v >> 1)) >> 7;
-        *a++ = vs-1;
+        uint16_t v = (mult[l]*bq1->extra) & 0xff;
+        v += v << 1;
+        *a = v >> 8;
         sumi = __dp4a(val[0], q8[4*l+0], __dp4a(val[1], q8[4*l+1], __dp4a(val[2], q8[4*l+2], __dp4a(val[3], q8[4*l+3], sumi))));
     }
+    float2 d8 = __half22float2(bq8_1[iqs].ds);
+    return scale * (d8.x * sumi - d8.y);
 #else
+    static const uint16_t k_mult[5] = {81, 27, 9, 3, 1};
     const int8_t * q8 = bq8_1[iqs].qs;
     for (int l = 0; l < 2; ++l) {
         const int i16 = 2*iqs + l;
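Note on the hunk above (reviewer sketch, not part of the diff): the old code decoded the five ternary digits packed in each ql byte through the k_mult table — (k_mult[j]*q) wraps mod 256 and (3*v) >> 8 recovers digit j. The new code obtains the same digits with a running multiply-by-three: ((q & 0xff)*3) >> 8 pops the current top digit and q += q << 1 (i.e. q *= 3) moves the next digit into position, trading the table lookups for shifts and adds. The new mult[] pair plays the same role for the extra byte: k_mult entries are successive powers of 3, so mult[0] = 3*mult[1] reproduces k_mult[2*iqs] and k_mult[2*iqs+1]. The -1 shift from digits {0,1,2} to weights {-1,0,1} is also gone from the inner loop; it is folded into the new return statement (see the note after the next hunk). A minimal host-side check that the two extraction schemes agree for every byte value:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    static const uint8_t k_mult[5] = {81, 27, 9, 3, 1};
    for (int b = 0; b < 256; ++b) {
        uint8_t old_d[5], new_d[5];
        // old scheme: one table multiply per digit, wrapping in 8 bits
        for (int j = 0; j < 5; ++j) {
            uint8_t v = k_mult[j]*(uint8_t)b;
            old_d[j] = (3*v) >> 8;         // digit in {0,1,2}; the old kernel then did vs-1
        }
        // new scheme: running q *= 3, the top digit pops out at bit 8
        uint16_t q = (uint16_t)b;
        for (int j = 4; j >= 0; --j) {
            uint16_t v = q & 0xff;
            v += v << 1;                   // v *= 3
            new_d[j] = (uint8_t)(v >> 8);
            q += q << 1;                   // q *= 3 (stays below 2^16: 81*255 < 65536)
        }
        for (int j = 0; j < 5; ++j)
            if (old_d[j] != new_d[j]) { printf("mismatch at byte %d\n", b); return 1; }
    }
    printf("all 256 byte values agree\n");
    return 0;
}
```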
@@ -675,8 +681,8 @@ __device__ __forceinline__ float vec_dot_iq1_bn_q8_1(
         sumi += q8[0]*(vs - 1);
         q8++;
     }
-#endif
     return scale * __low2float(bq8_1[iqs].ds) * sumi;
+#endif
 }
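Note (reviewer sketch, not part of the diff): with the -1 dropped from the digit decode, sumi in the DP4A branch now sums t*q8 over digits t in {0,1,2} instead of (t-1)*q8. Since block_q8_1 stores ds = {d8, d8*Σq8}, the shift can be applied once at the end: d8*Σ(t-1)*q8 = d8.x*sumi - d8.y, which is exactly the new return in the #if branch. The scalar #else path still subtracts 1 per element, so its unchanged return moves inside the #endif rather than being shared. A tiny numeric check of the identity (values are illustrative):

```cpp
#include <cstdio>

int main() {
    const int   t[8]  = {0, 2, 1, 1, 2, 0, 1, 2};    // decoded ternary digits
    const int   q8[8] = {5, -3, 7, 1, -8, 2, 0, 4};  // arbitrary int8-range activations
    const float d8    = 0.125f;                      // q8_1 block scale (ds.x)
    float ref = 0.0f; int sumi = 0, sumq8 = 0;
    for (int i = 0; i < 8; ++i) {
        ref   += d8 * (t[i] - 1) * q8[i];            // true weights in {-1,0,1}
        sumi  += t[i]*q8[i];                         // what the kernel accumulates
        sumq8 += q8[i];                              // d8*sumq8 is ds.y
    }
    printf("ref = %g, folded = %g\n", ref, d8*sumi - d8*sumq8); // identical
    return 0;
}
```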

 __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
@@ -688,20 +694,21 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
     // iqs is 0 or 1

 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    auto qs = (const uint16_t *)bq2->qs + 4*iqs;
+    auto qs = (const int *)bq2->qs + 2*iqs;
     auto q8l = (const int *)bq8_1[0].qs + 2*iqs;
     auto q8h = (const int *)bq8_1[1].qs + 2*iqs;
     int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
     for (int j = 0; j < 2; ++j) {
-        int vl = qs[2*j+0] | (uint32_t(qs[2*j+1]) << 16);
-        int vh = vl >> 4;
+        int vl = qs[j];
+        int vh = qs[j] >> 4;
         sumi1 = __dp4a(vl & 0x03030303, q8l[j+0], sumi1);
         sumi2 = __dp4a(vl & 0x0c0c0c0c, q8l[j+4], sumi2);
         sumi3 = __dp4a(vh & 0x03030303, q8h[j+0], sumi3);
         sumi4 = __dp4a(vh & 0x0c0c0c0c, q8h[j+4], sumi4);
     }
     auto d8l = __half22float2(bq8_1[0].ds);
     auto d8h = __half22float2(bq8_1[1].ds);
+    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 #else
     int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
     auto q8l = bq8_1[0].qs + 8*iqs;
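Note (reviewer sketch, not part of the diff): the pointer change in this hunk removes the manual word assembly — on a little-endian target (all CUDA GPUs) two adjacent uint16_t are bit-identical to one aligned 32-bit load, so qs[2*j] | (uint32_t(qs[2*j+1]) << 16) is just qs[j] through a const int *. The arithmetic shift in vh = qs[j] >> 4 is also unchanged in effect, because the masks only keep bits below the sign-extended region. Host-side illustration of the load equivalence:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint8_t bytes[4] = {0xD8, 0x27, 0x4E, 0x9B}; // arbitrary packed data
    uint16_t half[2];
    int32_t  whole;
    memcpy(half,   bytes, sizeof half);    // two 16-bit views of the same bytes
    memcpy(&whole, bytes, sizeof whole);   // one 32-bit view
    int32_t assembled = half[0] | (uint32_t(half[1]) << 16); // the old form
    printf("assembled = %08x, loaded = %08x\n",
           (uint32_t)assembled, (uint32_t)whole);  // identical on little-endian
    return 0;
}
```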
@@ -717,7 +724,6 @@ __device__ __forceinline__ float vec_dot_iq2_bn_q8_1(
     auto d8h = __half22float2(bq8_1[1].ds);
     return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + 0.0625f * d8h.x * (sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 #endif
-    return scale * (d8l.x * (sumi1 + 0.25f*sumi2) + d8h.x * (sumi3 + 0.25f*sumi4) - 0.5f*d8l.y - 0.5f*d8h.y);
 }

 } // namespace
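Note (reviewer sketch, not part of the diff): in both branches of vec_dot_iq2_bn_q8_1 each 32-bit word carries sixteen 2-bit values, and the kernel never unpacks them: it dot-products four masked views of the word, leaving the second crumb of every byte shifted left by 2 (so those products are 4x too large) and compensating with the 0.25f weights in the float reduction; the 0.5f*ds.y terms subtract the -1 offset for the half of each q8_1 block a given call covers, mirroring the iq1_bn change above. A scalar sketch checking the masked form against a plain unpack-and-multiply (the q8 grouping here is a simplified stand-in for the kernel's q8l/q8h indexing):

```cpp
#include <cstdint>
#include <cstdio>

// scalar stand-in for __dp4a: per-byte signed dot product plus accumulator
static int dp4a(uint32_t a, uint32_t b, int acc) {
    for (int i = 0; i < 4; ++i)
        acc += (int8_t)(a >> 8*i) * (int8_t)(b >> 8*i);
    return acc;
}

int main() {
    uint32_t vl = 0x9B4E27D8u;  // one packed word: byte i holds crumbs 0..3 of four values
    uint32_t vh = vl >> 4;      // exposes crumbs 2 and 3 at crumb positions 0 and 1
    uint32_t q8[4] = {0x01FF02FEu, 0x03FD04FCu, 0x05FB06FAu, 0x07F908F8u}; // per-group activations

    // reference: unpack every 2-bit value and multiply with its activation byte
    int ref = 0;
    for (int g = 0; g < 4; ++g)
        for (int i = 0; i < 4; ++i)
            ref += (int)((vl >> (8*i + 2*g)) & 3u) * (int8_t)(q8[g] >> 8*i);

    // masked form: groups 1 and 3 stay shifted by 2 (values x4),
    // exactly what the 0.25f factors in the kernel's reduction undo
    int s1 = dp4a(vl & 0x03030303u, q8[0], 0);
    int s2 = dp4a(vl & 0x0c0c0c0cu, q8[1], 0);
    int s3 = dp4a(vh & 0x03030303u, q8[2], 0);
    int s4 = dp4a(vh & 0x0c0c0c0cu, q8[3], 0);
    printf("ref = %d, masked = %d\n", ref, s1 + s2/4 + s3 + s4/4); // identical
    return 0;
}
```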