@@ -124,11 +124,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
124124 type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
125125 0
126126
127- static int mmq_get_mma_tile_x_k_host (const ggml_type type) {
128- MMQ_MMA_GET_TILE_X_K_BODY;
129- }
130-
131- static constexpr __device__ int mmq_get_mma_tile_x_k_device (ggml_type type) {
127+ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k (ggml_type type) {
132128 MMQ_MMA_GET_TILE_X_K_BODY;
133129}
134130
@@ -2424,9 +2420,10 @@ struct mmq_args {
24242420 int64_t ne0;
24252421};
24262422
2427- static int mmq_get_shmem (const ggml_type type, const int mmq_x, const int mmq_y, const int cc) {
2423+ template <ggml_type type>
2424+ static int mmq_get_shmem (const int mmq_x, const int mmq_y, const int cc) {
24282425 const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host (type, mmq_y);
2429- const int mmq_tile_x_k = mmq_get_mma_tile_x_k_host (type);
2426+ const int mmq_tile_x_k = mmq_get_mma_tile_x_k (type);
24302427 const int shmem_x = int8_mma_available (cc) ? mmq_y*mmq_tile_x_k*sizeof (int ) : txs.qs *sizeof (int ) + txs.dm *sizeof (half2) + txs.sc *sizeof (int );
24312428 const int shmem_y = mmq_x*sizeof (block_q8_1_mmq);
24322429 return shmem_x + GGML_PAD (shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof (int ));
@@ -2441,7 +2438,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
24412438
24422439 const dim3 block_dims (WARP_SIZE, MMQ_NWARPS, 1 );
24432440
2444- const int shmem = mmq_get_shmem ( type, mmq_x, mmq_y, cc);
2441+ const int shmem = mmq_get_shmem< type>( mmq_x, mmq_y, cc);
24452442
24462443#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
24472444 static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false };
@@ -2512,7 +2509,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
25122509 for (int mmq_x = 8 ; mmq_x <= mmq_x_max && nparts_best > 1 ; mmq_x += 8 ) {
25132510 const int granularity = mmq_get_granularity_host (mmq_x, cc);
25142511
2515- if (mmq_x % granularity != 0 || mmq_get_shmem ( type, mmq_x, mmq_y, cc) > smpbo) {
2512+ if (mmq_x % granularity != 0 || mmq_get_shmem< type>( mmq_x, mmq_y, cc) > smpbo) {
25162513 continue ;
25172514 }
25182515
0 commit comments