From 5389c53f2c859965ce0f00716456132399ab0cf1 Mon Sep 17 00:00:00 2001
From: wchen61
Date: Tue, 24 Dec 2024 16:51:06 +0800
Subject: [PATCH 1/5] Resolve race conditions in Marlin kernel

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 33 ++++++++++----------
 1 file changed, 16 insertions(+), 17 deletions(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 0c698ced7713d..2971532706bc5 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -834,6 +834,7 @@ __global__ void Marlin(
   int4* sh_g_idx = sh_b + (stages * b_sh_stage);
   int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
   int4* sh_s = sh_zp + (stages * zp_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);
 
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
@@ -932,11 +933,11 @@ __global__ void Marlin(
         int4* sh_s_stage = sh_s + s_sh_stage * pipe;
 
         if constexpr (group_blocks >= thread_k_blocks) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
           // Only fetch scales if this tile starts a new group
-          if (pipe % (group_blocks / thread_k_blocks) == 0) {
-            if (s_sh_wr_pred) {
-              cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]);
-            }
+          if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) {
             s_gl_rd += s_gl_rd_delta;
           }
         } else {
@@ -1038,9 +1039,7 @@ __global__ void Marlin(
     // No act-order case
     if constexpr (group_blocks != -1) {
       if constexpr (group_blocks >= thread_k_blocks) {
-        int4* sh_s_stage =
-            sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) *
-                                 (pipe / (group_blocks / thread_k_blocks)));
+        int4* sh_s_stage = sh_s + s_sh_stage * pipe;
         reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd];
       } else {
         int warp_id = threadIdx.x / 32;
@@ -1340,14 +1339,14 @@ __global__ void Marlin(
                   red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
               if (i < red_off) {
                 float* c_rd =
-                    reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                    reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
                 for (int k = 0; k < 4; k++)
                   reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
                       c_rd[k] + c_wr[k];
               }
-              sh[red_sh_wr] =
+              sh_red[red_sh_wr] =
                   reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
             }
           }
@@ -1357,7 +1356,7 @@ __global__ void Marlin(
 #pragma unroll
           for (int i = 0; i < 4 * 2; i++) {
             float* c_rd =
-                reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+                reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
 #pragma unroll
             for (int j = 0; j < 4; j++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
@@ -1397,7 +1396,7 @@ __global__ void Marlin(
 #pragma unroll
         for (int i = 0; i < thread_m_blocks * 4; i++) {
           cp_async4_pred(
-              &sh[c_sh_wr + c_sh_wr_delta * i],
+              &sh_red[c_sh_wr + c_sh_wr_delta * i],
               &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
               i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
@@ -1410,7 +1409,7 @@ __global__ void Marlin(
         for (int i = 0; i < thread_m_blocks * 4; i++) {
           if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
             if (!first) {
-              int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+              int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
 #pragma unroll
               for (int j = 0; j < 2 * 4; j++) {
                 reinterpret_cast<scalar_t*>(
@@ -1461,10 +1460,10 @@ __global__ void Marlin(
       float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
 #pragma unroll
       for (int k = 0; k < th_size; k++) {
-        sh[threadIdx.x] =
+        sh_red[threadIdx.x] =
             C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
-        float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
+        float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]);
 #pragma unroll
         for (int f = 0; f < 4; f++) {
           frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
         }
@@ -1515,7 +1514,7 @@ __global__ void Marlin(
           res = __hmul2(res, s[0]);
         }
 
-        ((scalar_t2*)sh)[idx] = res;
+        ((scalar_t2*)sh_red)[idx] = res;
       };
 
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
@@ -1543,7 +1542,7 @@ __global__ void Marlin(
            i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
            i++) {
         if (c_gl_wr < c_gl_wr_end) {
-          C[c_gl_wr] = sh[c_sh_rd];
+          C[c_gl_wr] = sh_red[c_sh_rd];
           c_gl_wr += c_gl_wr_delta;
           c_sh_rd += c_sh_rd_delta;
         }
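Note on PATCH 1/5: as the diff suggests, the thread-block reduction and the epilogue previously staged partial results through `sh` -- the base of the pipelined A/B operand buffer -- so asynchronous prefetches for upcoming tiles could overlap the reduction's reads and writes; the patch carves out a dedicated `sh_red` region after the scale cache instead. (The scale-fetch change in the same patch points the same direction: scales are now staged into a per-pipe slot every stage rather than a shared slot reused across stages.) Below is a minimal, standalone sketch of the layout idea only; every stage size in it is an invented placeholder, not the kernel's real configuration, and only the ordering of the regions and the new `sh_red` offset mirror the diff above.

    // Standalone C++ sketch: the kernel's shared-memory regions as
    // consecutive slices of one buffer, with a dedicated sh_red region
    // following the scales instead of aliasing the in-flight sh_a/sh_b
    // pipeline buffers. All sizes below are assumptions.
    #include <cstdio>

    int main() {
      // Placeholder per-stage region sizes in units of int4 (16 bytes).
      const int stages = 4;
      const int a_sh_stage = 256;
      const int b_sh_stage = 512;
      const int g_idx_stage = 32;
      const int zp_sh_stage = 8;
      const int s_sh_stage = 16;

      // Region offsets, in the same order the kernel derives its pointers.
      const int sh_a = 0;
      const int sh_b = sh_a + stages * a_sh_stage;
      const int sh_g_idx = sh_b + stages * b_sh_stage;
      const int sh_zp = sh_g_idx + stages * g_idx_stage;
      const int sh_s = sh_zp + stages * zp_sh_stage;
      const int sh_red = sh_s + stages * s_sh_stage;  // new reduction scratch

      printf("sh_red starts at int4 offset %d (byte offset %d)\n",
             sh_red, sh_red * 16);
      return 0;
    }

Because `sh_red` no longer overlaps the pipeline, the reduction needs no extra synchronization against the prefetcher -- but the block's total shared-memory footprint grows, which is what the follow-up patches account for.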
From ab586fbc878dd801929b749134d820c4a566aadd Mon Sep 17 00:00:00 2001
From: wchen61
Date: Thu, 26 Dec 2024 13:55:02 +0800
Subject: [PATCH 2/5] Add reduce size when verifying configs

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 2971532706bc5..cc1787312b393 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1338,8 +1338,8 @@ __global__ void Marlin(
               int red_sh_wr =
                   red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
               if (i < red_off) {
-                float* c_rd =
-                    reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
+                float* c_rd = reinterpret_cast<float*>(
+                    &sh_red[red_sh_delta * j + red_sh_rd]);
                 float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
 #pragma unroll
                 for (int k = 0; k < 4; k++)
@@ -1864,9 +1864,11 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
+  float reduce_size = th_config.num_threads * 2 * 4 * 4;
+
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
 
-  return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
+  return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size);
 }
 
 bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,

From 3a6e41fa63328bff4aa6330f8981552ccdc96882 Mon Sep 17 00:00:00 2001
From: wchen61
Date: Thu, 2 Jan 2025 10:54:32 +0800
Subject: [PATCH 3/5] Optimize the calculation of reduction shared memory size

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index cc1787312b393..95b5362dc8ca6 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1864,7 +1864,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
-  float reduce_size = th_config.num_threads * 2 * 4 * 4;
+  float reduce_size = max(th_config.num_threads * 2 * 4 * 4, (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
 
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity

From aa2f07ae27390814dfcc2e1a2c48935a605db01f Mon Sep 17 00:00:00 2001
From: wchen61
Date: Thu, 2 Jan 2025 11:10:07 +0800
Subject: [PATCH 4/5] Fix clang-format issue

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index 95b5362dc8ca6..d8595d2355fb3 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1864,7 +1864,8 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
-  float reduce_size = max(th_config.num_threads * 2 * 4 * 4, (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
+  float reduce_size = max(th_config.num_threads * 2 * 4 * 4,
+                          (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
 
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
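Note on PATCH 2/5 through 4/5: once the kernel claims a dedicated reduction region, the host-side validator must budget for it, or a tile shape could pass `is_valid_cache_size` and then fail at launch for exceeding shared memory. The sketch below restates the check as a free-standing function under assumed semantics: the name `fits_in_smem` and the example values in `main` are invented for illustration, while the formula mirrors the state of the code after PATCH 4/5 (all sizes in bytes).

    // Standalone C++ sketch of the shared-memory budget check.
    #include <algorithm>
    #include <cstdio>

    bool fits_in_smem(float a_size, float b_size, float scales_cache_size,
                      int num_threads, int tb_n, int tb_max_m,
                      int pipe_stages, int max_shared_mem) {
      // Bytes consumed by the staged A/B operand pipeline.
      float pipe_size = (a_size + b_size) * pipe_stages;
      // Reduction scratch: the larger of the per-thread staging term and
      // the partial-sum tile term introduced in PATCH 3/5.
      float reduce_size =
          std::max(num_threads * 2 * 4 * 4,
                   (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
      // Leave 5% slack after carving out the scale cache.
      return pipe_size + reduce_size <
             0.95f * (max_shared_mem - scales_cache_size);
    }

    int main() {
      // Invented example: 24 KiB of A+B per stage, 4 stages, 2 KiB of
      // scales, checked against a 164 KiB opt-in limit (roughly an A100's).
      bool ok = fits_in_smem(8192.0f, 16384.0f, 2048.0f,
                             256, 128, 64, 4, 164 * 1024);
      printf("config fits: %s\n", ok ? "yes" : "no");
      return 0;
    }

The key design point is that the reduction scratch is counted once, not per pipeline stage: `sh_red` is a single region reused across tiles, so only the maximum of its two possible footprints is added to the budget.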
From 476212796872928a9974910072e18e3a224698a3 Mon Sep 17 00:00:00 2001
From: wchen61
Date: Thu, 2 Jan 2025 14:13:55 +0800
Subject: [PATCH 5/5] Optimize calculation of shared memory size for reduction

Signed-off-by: wchen61
---
 csrc/quantization/gptq_marlin/gptq_marlin.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu
index d8595d2355fb3..04ef842fbdf95 100644
--- a/csrc/quantization/gptq_marlin/gptq_marlin.cu
+++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu
@@ -1864,7 +1864,7 @@ bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
 
   float pipe_size = (a_size + b_size) * pipe_stages;
 
-  float reduce_size = max(th_config.num_threads * 2 * 4 * 4,
+  float reduce_size = max(th_config.num_threads * 32 * 4,
                           (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2);
 
   TORCH_CHECK(max_shared_mem / 2 > scales_cache_size);  // Sanity
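Note on PATCH 5/5: the edit quadruples the first operand of the `max()` (2 * 4 * 4 = 32 bytes per thread becomes 32 * 4 = 128); the commit message does not spell out the derivation, presumably it widens the budget to cover the full per-thread staging footprint of the reduction path. Worked numbers for an assumed 256-thread block:

    old term: 256 * 2 * 4 * 4 =  8192 bytes ( 8 KiB)
    new term: 256 * 32 * 4    = 32768 bytes (32 KiB)

With the second operand unchanged, the net effect is a stricter bound: borderline tile configurations are now rejected during validation rather than risking a shared-memory overrun at launch, which is the conservative direction for this check.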