Skip to content

Commit e04d706

Browse files
houseroad authored and r-barnes committed
Fix CUDA kernel index data type in vllm/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +10 (vllm-project#15160)
Signed-off-by: Lu Fang <lufang@fb.com>
Co-authored-by: Richard Barnes <rbarnes@meta.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
1 parent 79495c2 commit e04d706

File tree

7 files changed

+73
-73
lines changed

7 files changed

+73
-73
lines changed

csrc/quantization/gptq_marlin/awq_marlin_repack.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
1414
int n_tiles = size_n / tile_n_size;
1515
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
1616

17-
int start_k_tile = blockIdx.x * block_k_tiles;
17+
auto start_k_tile = blockIdx.x * block_k_tiles;
1818
if (start_k_tile >= k_tiles) {
1919
return;
2020
}
@@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
5151
int4* sh_ptr = sh + stage_size * pipe;
5252

5353
if (threadIdx.x < stage_size) {
54-
int k_id = threadIdx.x / stage_n_threads;
55-
int n_id = threadIdx.x % stage_n_threads;
54+
auto k_id = threadIdx.x / stage_n_threads;
55+
auto n_id = threadIdx.x % stage_n_threads;
5656

5757
int first_k = k_tile_id * tile_k_size;
5858

@@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
7070
return;
7171
}
7272

73-
int warp_id = threadIdx.x / 32;
74-
int th_id = threadIdx.x % 32;
73+
auto warp_id = threadIdx.x / 32;
74+
auto th_id = threadIdx.x % 32;
7575

7676
if (warp_id >= 4) {
7777
return;
@@ -265,4 +265,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
265265

266266
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
267267
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
268-
}
268+
}

csrc/quantization/gptq_marlin/gptq_marlin.cu

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
460460
int const* __restrict__ perm_int_ptr,
461461
int4* __restrict__ out_int4_ptr, int size_m,
462462
int size_k, int lda, int block_rows) {
463-
int start_row = block_rows * blockIdx.x;
463+
auto start_row = block_rows * blockIdx.x;
464464
int finish_row = start_row + block_rows;
465465
if (finish_row > size_m) {
466466
finish_row = size_m;
@@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
484484
int base_k = 0;
485485

486486
for (int i = 0; i < iters; i++) {
487-
int cur_k = base_k + threadIdx.x;
487+
auto cur_k = base_k + threadIdx.x;
488488
int src_pos = perm_int_ptr[cur_k];
489489

490490
out_half[cur_k] = a_row_half[src_pos];
@@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
494494

495495
if (rest) {
496496
if (threadIdx.x < rest) {
497-
int cur_k = base_k + threadIdx.x;
497+
auto cur_k = base_k + threadIdx.x;
498498
int src_pos = perm_int_ptr[cur_k];
499499

500500
out_half[cur_k] = a_row_half[src_pos];
@@ -723,8 +723,8 @@ __global__ void Marlin(
723723
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
724724
b_gl_rd += b_sh_stride * slice_col;
725725
b_gl_rd += b_gl_rd_delta_o * slice_row;
726-
int b_sh_wr = threadIdx.x * b_thread_vecs;
727-
int b_sh_rd = threadIdx.x * b_thread_vecs;
726+
auto b_sh_wr = threadIdx.x * b_thread_vecs;
727+
auto b_sh_rd = threadIdx.x * b_thread_vecs;
728728

729729
// For act_order
730730
constexpr int k_iter_size = tb_k / b_sh_wr_iters;
@@ -743,7 +743,7 @@ __global__ void Marlin(
743743
s_sh_stride * slice_col + threadIdx.x;
744744
}
745745
}
746-
int s_sh_wr = threadIdx.x;
746+
auto s_sh_wr = threadIdx.x;
747747
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
748748

749749
// Zero-points
@@ -756,7 +756,7 @@ __global__ void Marlin(
756756
zp_sh_stride * slice_col + threadIdx.x;
757757
}
758758
}
759-
int zp_sh_wr = threadIdx.x;
759+
auto zp_sh_wr = threadIdx.x;
760760
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
761761

762762
// We use a different scale layout for grouped and column-wise quantization as
@@ -1047,7 +1047,7 @@ __global__ void Marlin(
10471047
int4* sh_s_stage = sh_s + s_sh_stage * pipe;
10481048
reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
10491049
} else {
1050-
int warp_id = threadIdx.x / 32;
1050+
auto warp_id = threadIdx.x / 32;
10511051
int n_warps = thread_n_blocks / 4;
10521052

10531053
int warp_row = warp_id / n_warps;
@@ -1085,7 +1085,7 @@ __global__ void Marlin(
10851085

10861086
// Determine "position" inside the thread-block (based on warp and
10871087
// thread-id)
1088-
int warp_id = threadIdx.x / 32;
1088+
auto warp_id = threadIdx.x / 32;
10891089
int n_warps =
10901090
thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N
10911091

@@ -1094,7 +1094,7 @@ __global__ void Marlin(
10941094

10951095
cur_k += warp_row * 16;
10961096

1097-
int th_id = threadIdx.x % 32;
1097+
auto th_id = threadIdx.x % 32;
10981098
cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix
10991099

11001100
int s_col_shift =
@@ -1159,7 +1159,7 @@ __global__ void Marlin(
11591159
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
11601160
}
11611161
} else {
1162-
int warp_id = threadIdx.x / 32;
1162+
auto warp_id = threadIdx.x / 32;
11631163
int n_warps = thread_n_blocks / 4;
11641164

11651165
int warp_row = warp_id / n_warps;
@@ -1197,7 +1197,7 @@ __global__ void Marlin(
11971197
(pipe / (group_blocks / thread_k_blocks)));
11981198
reinterpret_cast<int4*>(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd];
11991199
} else {
1200-
int warp_id = threadIdx.x / 32;
1200+
auto warp_id = threadIdx.x / 32;
12011201
int n_warps = thread_n_blocks / 4;
12021202

12031203
int warp_row = warp_id / n_warps;
@@ -1323,7 +1323,7 @@ __global__ void Marlin(
13231323
auto thread_block_reduce = [&]() {
13241324
constexpr int red_off = threads / b_sh_stride_threads / 2;
13251325
if (red_off >= 1) {
1326-
int red_idx = threadIdx.x / b_sh_stride_threads;
1326+
auto red_idx = threadIdx.x / b_sh_stride_threads;
13271327
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
13281328
constexpr int red_sh_delta = b_sh_stride_threads;
13291329
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -1390,7 +1390,7 @@ __global__ void Marlin(
13901390
4 * (threadIdx.x / 32) + threadIdx.x % 4;
13911391
c_gl_wr += (2 * thread_n_blocks) * slice_col;
13921392
constexpr int c_sh_wr_delta = active_threads;
1393-
int c_sh_wr = threadIdx.x;
1393+
auto c_sh_wr = threadIdx.x;
13941394

13951395
int row = (threadIdx.x % 32) / 4;
13961396

csrc/quantization/gptq_marlin/gptq_marlin_repack.cu

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
1515
int n_tiles = size_n / tile_n_size;
1616
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
1717

18-
int start_k_tile = blockIdx.x * block_k_tiles;
18+
auto start_k_tile = blockIdx.x * block_k_tiles;
1919
if (start_k_tile >= k_tiles) {
2020
return;
2121
}
@@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
7171

7272
if constexpr (has_perm) {
7373
if (threadIdx.x < stage_size) {
74-
int k_id = threadIdx.x / stage_n_threads;
75-
int n_id = threadIdx.x % stage_n_threads;
74+
auto k_id = threadIdx.x / stage_n_threads;
75+
auto n_id = threadIdx.x % stage_n_threads;
7676

7777
uint32_t const* sh_perm_int_ptr =
7878
reinterpret_cast<uint32_t const*>(sh_perm_ptr);
@@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
8888

8989
} else {
9090
if (threadIdx.x < stage_size) {
91-
int k_id = threadIdx.x / stage_n_threads;
92-
int n_id = threadIdx.x % stage_n_threads;
91+
auto k_id = threadIdx.x / stage_n_threads;
92+
auto n_id = threadIdx.x % stage_n_threads;
9393

9494
int first_k = k_tile_id * tile_k_size;
9595
int first_k_packed = first_k / pack_factor;
@@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
109109
return;
110110
}
111111

112-
int warp_id = threadIdx.x / 32;
113-
int th_id = threadIdx.x % 32;
112+
auto warp_id = threadIdx.x / 32;
113+
auto th_id = threadIdx.x % 32;
114114

115115
if (warp_id >= 4) {
116116
return;
@@ -339,4 +339,4 @@ TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
339339

340340
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
341341
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
342-
}
342+
}

csrc/quantization/marlin/dense/marlin_cuda_kernel.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,12 +277,12 @@ __global__ void Marlin(
277277
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
278278
b_gl_rd += b_sh_stride * slice_col;
279279
b_gl_rd += b_gl_rd_delta_o * slice_row;
280-
int b_sh_wr = threadIdx.x;
281-
int b_sh_rd = threadIdx.x;
280+
auto b_sh_wr = threadIdx.x;
281+
auto b_sh_rd = threadIdx.x;
282282

283283
int s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
284284
s_sh_stride * slice_col + threadIdx.x;
285-
int s_sh_wr = threadIdx.x;
285+
auto s_sh_wr = threadIdx.x;
286286
int s_sh_rd;
287287
// We use a different scale layout for grouped and column-wise quantization as
288288
// we scale a `half2` tile in column-major layout in the former and in
@@ -455,7 +455,7 @@ __global__ void Marlin(
455455
auto thread_block_reduce = [&]() {
456456
constexpr int red_off = threads / b_sh_stride / 2;
457457
if (red_off >= 1) {
458-
int red_idx = threadIdx.x / b_sh_stride;
458+
auto red_idx = threadIdx.x / b_sh_stride;
459459
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
460460
constexpr int red_sh_delta = b_sh_stride;
461461
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -522,7 +522,7 @@ __global__ void Marlin(
522522
4 * (threadIdx.x / 32) + threadIdx.x % 4;
523523
c_gl_wr += (2 * thread_n_blocks) * slice_col;
524524
constexpr int c_sh_wr_delta = active_threads;
525-
int c_sh_wr = threadIdx.x;
525+
auto c_sh_wr = threadIdx.x;
526526

527527
int row = (threadIdx.x % 32) / 4;
528528

csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -353,10 +353,10 @@ __global__ void Marlin(
353353
b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride);
354354
b_gl_rd += b_sh_stride * slice_col;
355355
b_gl_rd += b_gl_rd_delta_o * slice_row;
356-
int b_sh_wr = threadIdx.x;
357-
int b_sh_rd = threadIdx.x;
356+
auto b_sh_wr = threadIdx.x;
357+
auto b_sh_rd = threadIdx.x;
358358

359-
int s_tok_gl_rd = threadIdx.x;
359+
auto s_tok_gl_rd = threadIdx.x;
360360
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
361361
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
362362
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
@@ -368,8 +368,8 @@ __global__ void Marlin(
368368
int s_tok_sh_rd = (threadIdx.x % 32) / 4;
369369
bool s_tok_sh_wr_pred = threadIdx.x < prob_m;
370370

371-
int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
372-
int s_ch_sh_wr = threadIdx.x;
371+
auto s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x;
372+
auto s_ch_sh_wr = threadIdx.x;
373373
int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
374374
2 * ((threadIdx.x % 32) % 4);
375375
bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride;
@@ -558,7 +558,7 @@ __global__ void Marlin(
558558
auto thread_block_reduce = [&]() {
559559
constexpr int red_off = threads / b_sh_stride / 2;
560560
if (red_off >= 1) {
561-
int red_idx = threadIdx.x / b_sh_stride;
561+
auto red_idx = threadIdx.x / b_sh_stride;
562562
constexpr int red_sh_stride = b_sh_stride * 4 * 2;
563563
constexpr int red_sh_delta = b_sh_stride;
564564
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) +
@@ -628,7 +628,7 @@ __global__ void Marlin(
628628
8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2;
629629
c_gl_wr += (4 * thread_n_blocks) * slice_col;
630630
constexpr int c_sh_wr_delta = active_threads * 2;
631-
int c_sh_wr = 2 * threadIdx.x;
631+
auto c_sh_wr = 2 * threadIdx.x;
632632

633633
int row = (threadIdx.x % 32) / 4;
634634

csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -273,15 +273,15 @@ __global__ void Marlin_24(
273273
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
274274
b_gl_rd += b_sh_stride * slice_col;
275275
b_gl_rd += b_gl_rd_delta_o * slice_row;
276-
int b_sh_wr = threadIdx.x * b_thread_vecs;
277-
int b_sh_rd = threadIdx.x * b_thread_vecs;
276+
auto b_sh_wr = threadIdx.x * b_thread_vecs;
277+
auto b_sh_rd = threadIdx.x * b_thread_vecs;
278278

279279
int m_gl_rd = m_gl_stride * (threadIdx.x / (m_sh_stride)) +
280280
(threadIdx.x % (m_sh_stride));
281281
m_gl_rd += (m_sh_stride)*slice_col;
282282
m_gl_rd += m_gl_rd_delta_o * slice_row;
283-
int m_sh_wr = threadIdx.x;
284-
int m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
283+
auto m_sh_wr = threadIdx.x;
284+
auto m_sh_rd = threadIdx.x % 16 + (threadIdx.x / 32) * 16;
285285

286286
int s_gl_rd;
287287
if constexpr (group_blocks == -1) {
@@ -291,7 +291,7 @@ __global__ void Marlin_24(
291291
s_sh_stride * slice_col + threadIdx.x;
292292
}
293293

294-
int s_sh_wr = threadIdx.x;
294+
auto s_sh_wr = threadIdx.x;
295295
int s_sh_rd;
296296
// We use a different scale layout for grouped and column-wise quantization as
297297
// we scale a `half2` tile in column-major layout in the former and in
@@ -516,7 +516,7 @@ __global__ void Marlin_24(
516516
auto thread_block_reduce = [&]() {
517517
constexpr int red_off = threads / b_sh_stride_threads / 2;
518518
if (red_off >= 1) {
519-
int red_idx = threadIdx.x / b_sh_stride_threads;
519+
auto red_idx = threadIdx.x / b_sh_stride_threads;
520520
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
521521
constexpr int red_sh_delta = b_sh_stride_threads;
522522
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
@@ -583,7 +583,7 @@ __global__ void Marlin_24(
583583
8 * (threadIdx.x / 32) + (threadIdx.x % 32) / 4;
584584
c_gl_wr += (2 * thread_n_blocks) * slice_col;
585585
constexpr int c_sh_wr_delta = active_threads;
586-
int c_sh_wr = threadIdx.x;
586+
auto c_sh_wr = threadIdx.x;
587587

588588
int col = 2 * ((threadIdx.x % 32) % 4);
589589

0 commit comments

Comments (0)