@@ -3950,6 +3950,35 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
+static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+#if QK_K == 256
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+
+    // iqs is 0...15
+    const int ib32 = iqs/2;
+    const int il = iqs%2;
+    const uint16_t * q2 = bq2->qs + 4*ib32;
+    const uint8_t * aux8 = (const uint8_t *)q2;
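+    // each aux8 byte indexes one of the 256 kgrid_iq2xxs entries, each packing 8 unsigned magnitudes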
+    const uint8_t * grid1 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+0]);
+    const uint8_t * grid2 = (const uint8_t *)(kgrid_iq2xxs + aux8[2*il+1]);
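+    // q2[2] and q2[3] pack the 4-bit block scale (top bits) and four 7-bit sign-pattern indices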
+    const uint32_t aux32 = q2[2] | (q2[3] << 16);
+    const float d = (float)bq2->d * (0.5f + (aux32 >> 28)) * (float)bq8_1[ib32].ds.x * 0.25f;
+    const uint8_t signs1 = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+    const uint8_t signs2 = ksigns_iq2xs[(aux32 >> (14*il + 7)) & 127];
+    const int8_t * q8 = bq8_1[ib32].qs + 16*il;
+    int sumi1 = 0, sumi2 = 0;
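+    // accumulate the two 8-element halves, flipping the sign of each grid value according to kmask_iq2xs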
+    for (int j = 0; j < 8; ++j) {
+        sumi1 += q8[j+0] * grid1[j] * (signs1 & kmask_iq2xs[j] ? -1 : 1);
+        sumi2 += q8[j+8] * grid2[j] * (signs2 & kmask_iq2xs[j] ? -1 : 1);
+    }
+    return d * (sumi1 + sumi2);
+#else
+    assert(false);
+    return 0.f;
+#endif
+}
+
 template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
           allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __device__ __forceinline__ void mul_mat_q(
@@ -6044,6 +6073,15 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
+static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % QK_K == 0);
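+    // grid: one block per GGML_CUDA_MMV_Y rows; block: WARP_SIZE x GGML_CUDA_MMV_Y threads (one warp per row)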
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
+        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
+}
+
 static void ggml_mul_mat_q4_0_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7608,6 +7646,9 @@ static void ggml_cuda_op_mul_mat_vec_q(
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_IQ2_XXS:
+            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;