@@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                     int const* __restrict__ perm_int_ptr,
                                     int4* __restrict__ out_int4_ptr, int size_m,
                                     int size_k, int lda, int block_rows) {
-  int start_row = block_rows * blockIdx.x;
+  auto start_row = block_rows * blockIdx.x;
   int finish_row = start_row + block_rows;
   if (finish_row > size_m) {
     finish_row = size_m;
@@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
     int base_k = 0;

     for (int i = 0; i < iters; i++) {
-      int cur_k = base_k + threadIdx.x;
+      auto cur_k = base_k + threadIdx.x;
       int src_pos = perm_int_ptr[cur_k];

       out_half[cur_k] = a_row_half[src_pos];
@@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,

     if (rest) {
       if (threadIdx.x < rest) {
-        int cur_k = base_k + threadIdx.x;
+        auto cur_k = base_k + threadIdx.x;
         int src_pos = perm_int_ptr[cur_k];

         out_half[cur_k] = a_row_half[src_pos];
@@ -723,8 +723,8 @@ __global__ void Marlin(
                 (threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
   b_gl_rd += b_sh_stride * slice_col;
   b_gl_rd += b_gl_rd_delta_o * slice_row;
-  int b_sh_wr = threadIdx.x * b_thread_vecs;
-  int b_sh_rd = threadIdx.x * b_thread_vecs;
+  auto b_sh_wr = threadIdx.x * b_thread_vecs;
+  auto b_sh_rd = threadIdx.x * b_thread_vecs;

   // For act_order
   constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -743,7 +743,7 @@ __global__ void Marlin(
                 s_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int s_sh_wr = threadIdx.x;
+  auto s_sh_wr = threadIdx.x;
   bool s_sh_wr_pred = threadIdx.x < s_sh_stride;

   // Zero-points
@@ -756,7 +756,7 @@ __global__ void Marlin(
                  zp_sh_stride * slice_col + threadIdx.x;
     }
   }
-  int zp_sh_wr = threadIdx.x;
+  auto zp_sh_wr = threadIdx.x;
   bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;

   // We use a different scale layout for grouped and column-wise quantization as
@@ -1047,7 +1047,7 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1085,7 +1085,7 @@ __global__ void Marlin(

       // Determine "position" inside the thread-block (based on warp and
       // thread-id)
-      int warp_id = threadIdx.x / 32;
+      auto warp_id = threadIdx.x / 32;
       int n_warps =
           thread_n_blocks / 4;  // Each warp processes 4 16-size tiles over N

@@ -1094,7 +1094,7 @@ __global__ void Marlin(

       cur_k += warp_row * 16;

-      int th_id = threadIdx.x % 32;
+      auto th_id = threadIdx.x % 32;
       cur_k += (th_id % 4) * 2;  // Due to tensor-core layout for fp16 B matrix

       int s_col_shift =
@@ -1159,7 +1159,7 @@ __global__ void Marlin(
               (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
         }
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1197,7 +1197,7 @@ __global__ void Marlin(
                                    (pipe / (group_blocks / thread_k_blocks)));
         reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
       } else {
-        int warp_id = threadIdx.x / 32;
+        auto warp_id = threadIdx.x / 32;
         int n_warps = thread_n_blocks / 4;

         int warp_row = warp_id / n_warps;
@@ -1323,7 +1323,7 @@ __global__ void Marlin(
   auto thread_block_reduce = [&]() {
     constexpr int red_off = threads / b_sh_stride_threads / 2;
     if (red_off >= 1) {
-      int red_idx = threadIdx.x / b_sh_stride_threads;
+      auto red_idx = threadIdx.x / b_sh_stride_threads;
       constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
       constexpr int red_sh_delta = b_sh_stride_threads;
       int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -1390,7 +1390,7 @@ __global__ void Marlin(
                     4 * (threadIdx.x / 32) + threadIdx.x % 4;
       c_gl_wr += (2 * thread_n_blocks) * slice_col;
       constexpr int c_sh_wr_delta = active_threads;
-      int c_sh_wr = threadIdx.x;
+      auto c_sh_wr = threadIdx.x;

       int row = (threadIdx.x % 32) / 4;

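All of the int -> auto swaps above share one rationale: threadIdx.x and blockIdx.x are unsigned int (fields of CUDA's built-in uint3 index vectors), so an expression mixing them with int operands is evaluated in unsigned int, and assigning the result to an int is an implicit sign conversion that compilers can flag (e.g. under GCC/Clang's -Wsign-conversion). auto simply keeps the deduced unsigned type. A minimal sketch of the difference, using a hypothetical demo kernel that is not part of this diff:

__global__ void index_type_demo(int block_rows) {
  // blockIdx.x is unsigned int, so the product is computed in unsigned int.
  int narrowed = block_rows * blockIdx.x;   // implicit unsigned -> signed conversion
  auto deduced = block_rows * blockIdx.x;   // deduced as unsigned int; no conversion
  (void)narrowed;  // suppress unused-variable warnings in this sketch
  (void)deduced;
}

The trade-off is that the affected variables become unsigned, which is only behavior-neutral as long as the surrounding arithmetic (divisions, modulo, comparisons) never relies on negative values.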