@@ -137,6 +137,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
+// tell the compiler to use as many registers as it wants, see nwarps definition below
 template <ggml_type type, int ncols_dst, bool has_fusion>
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
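The restored comment refers to the second argument of `__launch_bounds__`: setting the minimum number of resident blocks per SM to 1 means the compiler only has to fit a single block per SM and is therefore free to spend as many registers per thread as the block size allows. A minimal standalone sketch (toy kernel and names, not from this patch) showing the attribute and how to inspect the resulting register count:

```cpp
// Toy sketch: __launch_bounds__(max_threads_per_block, min_blocks_per_sm).
// min_blocks_per_sm = 1 tells nvcc it may use as many registers as fit with one block per SM.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int BLOCK_SIZE = 256;

__launch_bounds__(BLOCK_SIZE, 1)
static __global__ void scale_kernel(float * x, const float s, const int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        x[i] *= s;
    }
}

int main() {
    cudaFuncAttributes attr;
    cudaFuncGetAttributes(&attr, scale_kernel);
    printf("registers per thread: %d\n", attr.numRegs); // compare with and without the attribute
    return 0;
}
```

In the real kernel the maximum block size is `calc_nwarps(...)` warps times the physical warp size, as in the line above.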
@@ -198,14 +199,17 @@ static __global__ void mul_mat_vec_q(
         }
     }
 
+    // partial sum for each thread
     float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
     float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
-        const int kby = kbx * (qk/QK8_1);
+        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
+
+        // x block quant index when casting the quants to int
         const int kqs = vdr * (tid % (qi/vdr));
 
 #pragma unroll
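The `tmp`/`tmp_gate` arrays and the `kbx` loop follow the usual pattern of per-thread partial sums over a strided walk of the row: `kby` selects the q8_1 block of `y` that lines up with the current `x` block, and `kqs` is the quant offset once the block data is reinterpreted as `int`. A much-simplified sketch of the same accumulation pattern with plain floats instead of quantized blocks (hypothetical kernel, not from this patch):

```cpp
// Simplified sketch: each thread keeps a partial sum (analogue of tmp[...][...]) and walks
// the row with a stride of blockDim.x (analogue of blocks_per_iter in the kernel above).
static __global__ void dot_partial(const float * x, const float * y, float * partial, const int n) {
    const int tid = threadIdx.x;

    float tmp = 0.0f;                           // per-thread partial sum
    for (int i = tid; i < n; i += blockDim.x) { // strided walk over the row
        tmp += x[i] * y[i];
    }
    partial[tid] = tmp;                         // reduced across the block afterwards
}
```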
@@ -253,6 +257,7 @@ static __global__ void mul_mat_vec_q(
 
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
+    // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_dst; ++j) {
 #pragma unroll
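The "sum up partial sums and write back result" step reduces each thread's `tmp[j][i]` before one thread stores the final value; the kernel also has to combine results from different warps, which is omitted here. A hedged sketch of just the warp-level reduction, assuming a single warp per block:

```cpp
// Sketch: warp-level reduction of per-thread partial sums, assuming blockDim.x == 32.
static __device__ float warp_sum(float v) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_xor_sync(0xFFFFFFFF, v, offset);
    }
    return v;
}

static __global__ void reduce_and_store(const float * partial, float * dst) {
    float v = partial[threadIdx.x]; // this lane's partial sum
    v = warp_sum(v);                // afterwards every lane holds the full sum
    if (threadIdx.x == 0) {
        *dst = v;                   // single write of the reduced result
    }
}
```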
@@ -307,7 +312,7 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
-static inline std::pair<dim3, dim3> calc_launch_params(
+static std::pair<dim3, dim3> calc_launch_params(
     const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
     const int warp_size, const mmvq_parameter_table_id table_id) {
     const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
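Apart from dropping `inline`, `calc_launch_params` is unchanged; the last line above is plain ceiling division, so the grid covers `nrows_x` rows in chunks of `calc_rows_per_block(...)` rows per block. A self-contained sketch of the same arithmetic with hypothetical fixed values in place of the per-architecture tables (the exact split of the block into warps is an assumption, not shown in this hunk):

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>
#include <cuda_runtime.h>

// Hypothetical stand-ins for calc_rows_per_block()/calc_nwarps(); the real values come
// from per-architecture tables selected via mmvq_parameter_table_id.
static std::pair<dim3, dim3> example_launch_params(
        const int nrows_x, const int nchannels, const int nsamples,
        const int rows_per_block, const int nwarps, const int warp_size) {
    const int64_t nblocks = (nrows_x + rows_per_block - 1) / rows_per_block; // ceil(nrows_x / rows_per_block)
    const dim3 grid(nblocks, nchannels, nsamples); // one block per (row chunk, channel, sample)
    const dim3 block(nwarps*warp_size, 1, 1);      // flattened 1-D block, an assumption for the sketch
    return {grid, block};
}

int main() {
    const auto [grid, block] = example_launch_params(/*nrows_x=*/4096, 1, 1, /*rows_per_block=*/2, /*nwarps=*/4, /*warp_size=*/32);
    printf("grid = (%u, %u, %u), block = (%u, %u, %u)\n", grid.x, grid.y, grid.z, block.x, block.y, block.z);
    return 0;
}
```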
@@ -626,6 +631,7 @@ void ggml_cuda_mul_mat_vec_q(
         fusion_local.glu_op = fusion->glu_op;
     }
 
+    // If src0 is a temporary compute buffer, clear any potential padding.
     if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
         const size_t size_data  = ggml_nbytes(src0);
         const size_t size_alloc = ggml_backend_buffer_get_alloc_size(src0->buffer, src0);
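The new comment documents why this block exists: if `src0` lives in a temporary compute buffer whose allocation is larger than `ggml_nbytes(src0)`, the padding bytes past the logical data are cleared so that any reads into that region are well defined. A generic sketch of the pattern with a hypothetical helper name and plain CUDA runtime calls instead of the ggml wrappers:

```cpp
#include <cstddef>
#include <cuda_runtime.h>

// Hypothetical helper mirroring the padding-clear step above: zero the bytes between the
// logical tensor size and the actual allocation size on the given stream.
static void clear_padding_async(void * data, size_t size_data, size_t size_alloc, cudaStream_t stream) {
    if (size_alloc > size_data) {
        cudaMemsetAsync((char *) data + size_data, 0, size_alloc - size_data, stream);
    }
}
```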
@@ -656,6 +662,7 @@ void ggml_cuda_mul_mat_vec_q(
     const int64_t s12 = ne11*s11;
     const int64_t s13 = ne12*s12;
 
+    // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst      = ids ? ne2  : ne1;
     const int64_t nchannels_y    = ids ? ne11 : ne12;
     const int64_t nchannels_dst  = ids ? ne1  : ne2;
@@ -687,6 +694,8 @@ void ggml_cuda_op_mul_mat_vec_q(
 
     int id = ggml_cuda_get_device();
 
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
     const int stride_row_x = ne00 / ggml_blck_size(src0->type);
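The restored comment explains the ternary: when several GPUs split the rows of the result, the main device's destination buffer holds the full matrix (`ne0` rows), while every other device writes into a compact buffer holding only its `row_diff` rows, so `nrows_dst` is the row count of whatever matrix this device's kernel writes into. A toy illustration with made-up sizes:

```cpp
#include <cstdio>

// Made-up sizes, only to illustrate the two cases of the ternary above.
int main() {
    const int ne0 = 4096;                    // rows of the full result matrix
    const int row_low = 1024, row_high = 2048;
    const int row_diff = row_high - row_low; // rows this device is responsible for

    for (const bool is_main_device : {true, false}) {
        const int nrows_dst = is_main_device ? ne0 : row_diff;
        printf("main device: %d -> kernel writes into a matrix with %d rows\n",
               is_main_device, nrows_dst);
    }
    return 0;
}
```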