Skip to content

Commit cab5981

Browse files
only a single get_mma_tile_x_k function
1 parent db6dae7 commit cab5981

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

ggml-cuda/mmq.cuh

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
124124
type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : \
125125
0
126126

127-
static int mmq_get_mma_tile_x_k_host(const ggml_type type) {
128-
MMQ_MMA_GET_TILE_X_K_BODY;
129-
}
130-
131-
static constexpr __device__ int mmq_get_mma_tile_x_k_device(ggml_type type) {
127+
static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
132128
MMQ_MMA_GET_TILE_X_K_BODY;
133129
}
134130

@@ -2424,9 +2420,10 @@ struct mmq_args {
24242420
int64_t ne0;
24252421
};
24262422

2427-
static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y, const int cc) {
2423+
template<ggml_type type>
2424+
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
24282425
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes_host(type, mmq_y);
2429-
const int mmq_tile_x_k = mmq_get_mma_tile_x_k_host(type);
2426+
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
24302427
const int shmem_x = int8_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
24312428
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
24322429
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
@@ -2441,7 +2438,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
24412438

24422439
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
24432440

2444-
const int shmem = mmq_get_shmem(type, mmq_x, mmq_y, cc);
2441+
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
24452442

24462443
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
24472444
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -2512,7 +2509,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
25122509
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
25132510
const int granularity = mmq_get_granularity_host(mmq_x, cc);
25142511

2515-
if (mmq_x % granularity != 0 || mmq_get_shmem(type, mmq_x, mmq_y, cc) > smpbo) {
2512+
if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
25162513
continue;
25172514
}
25182515

0 commit comments

Comments
 (0)