@@ -538,6 +538,7 @@ __global__ void Marlin(
     int prob_n,            // output dimension n
     int prob_k,            // reduction dimension k
     int* locks,            // extra global storage for barrier synchronization
+    bool use_atomic_add,   // whether to use atomic add to reduce
     bool use_fp32_reduce   // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add && slice_count > 1) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
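When `use_atomic_add` is set and a slice is split across more than one threadblock, each block now folds its staged outputs straight into global `C`: the 16-byte output chunk and the shared-memory staging entry are reinterpreted as four `scalar_t2` pairs (`__half2` or `__nv_bfloat162`) and merged with vectorized `atomicAdd`, so partial results from different blocks combine without the lock-protected global-reduce pass. A minimal standalone sketch of the same accumulation pattern, with hypothetical kernel and buffer names (not the Marlin kernel itself; `__half2` atomics need sm_60+, `__nv_bfloat162` atomics sm_80+):

    #include <cuda_fp16.h>

    // Each block owns one strip of partial results and folds it into a shared
    // output with vectorized __half2 atomics. `out` must be zero-initialized,
    // because every block only ever adds onto it.
    __global__ void accumulate_partials(__half2* out, const __half2* partials,
                                        int n_half2) {
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      if (idx < n_half2) {
        // No lock or inter-block barrier is needed: the hardware serializes
        // the conflicting adds on each 4-byte location.
        atomicAdd(&out[idx], partials[blockIdx.x * n_half2 + idx]);
      }
    }
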
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
         }
         cp_async_fence();
       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           if (s_sh_wr_pred) {
             cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
       }
 
     } else {
-      if (last) {
+      if (last || use_atomic_add) {
         cp_async_wait<0>();
         __syncthreads();
         if (threadIdx.x / 32 < thread_n_blocks / 4) {
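Both `last || use_atomic_add` changes follow from the new write path: with atomic adds, every block of a slice writes scaled results itself, so every block has to fetch and unpack the scales, not only the block that used to perform the final write. The fetch goes through the cp.async pipeline visible in the context (`cp_async4`, `cp_async_fence`, `cp_async_wait<0>`). A hedged sketch of what a 16-byte asynchronous global-to-shared copy helper of that kind typically looks like on sm_80+ (illustrative name, not the exact Marlin utility):

    #include <cstdint>

    // Issue one 16-byte cp.async transfer from global to shared memory; the
    // caller commits the group and waits on it (the roles played here by
    // cp_async_fence / cp_async_wait) before reading the shared buffer.
    __device__ inline void async_copy_16b(void* smem_dst, const void* gmem_src) {
      uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
      asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(smem),
                   "l"(gmem_src));
    }
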
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
         }
       }
 
-      if (slice_count > 1) {  // only globally reduce if there is more than one
-                              // block in a slice
+      if (slice_count > 1 && !use_atomic_add) {
+        // only globally reduce if there is more than one block in a slice
         barrier_acquire(&locks[slice_col], slice_idx);
         if (use_fp32_reduce) {
           global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
         }
         barrier_release(&locks[slice_col], last);
       }
-      if (last)  // only the last block in a slice actually writes the result
+      if (last || use_atomic_add)
+        // only the last block in a slice actually writes the result
         write_result();
       slice_row = 0;
       slice_col_par++;
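Skipping the `slice_count > 1` branch is where the flag pays off: the classic path makes the blocks of a slice take turns on a per-column lock and lets only the last block write the result, while the atomic path lets every block commit its partials as soon as they are ready. For reference, a simplified sketch of the kind of lock ordering that gets bypassed (hypothetical helper names, not the exact `barrier_acquire` / `barrier_release` implementation):

    // Blocks of one slice update C in rank order: block `my_rank` spins until
    // all lower-ranked blocks have released the lock, then does its reduce and
    // hands the turn to the next block (or rearms the lock if it is the last).
    __device__ void slice_lock_acquire(int* lock, int my_rank) {
      if (threadIdx.x == 0) {
        while (atomicAdd(lock, 0) != my_rank) {
          // spin until it is this block's turn
        }
      }
      __syncthreads();
    }

    __device__ void slice_lock_release(int* lock, bool is_last) {
      __syncthreads();
      if (threadIdx.x == 0) {
        __threadfence();        // publish this block's C updates before releasing
        if (is_last) {
          atomicExch(lock, 0);  // last block resets the counter for reuse
        } else {
          atomicAdd(lock, 1);   // pass the turn to the next block in the slice
        }
      }
    }
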
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
             HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT> \
         <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
             A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
-            num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
+            num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add, \
+            use_fp32_reduce); \
       } \
     }
 
@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }
 
   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
 
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
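The switch from `torch::empty` to `torch::zeros` for `c` is what makes the atomic path correct: each block does `C[i] += partial`, so the accumulator has to start at zero or the first add would land on uninitialized memory, whereas the classic path overwrites `C` and `empty` suffices. The buffers that are not needed (`a_tmp` without act-order, `c_tmp` without the fp32 reduce) now shrink to zero-sized tensors. A compact sketch of the allocation choice (hypothetical helper name):

    #include <torch/torch.h>

    // Match the output allocation to the kernel's write mode: accumulate-into
    // (atomic add) needs a zero-initialized buffer, plain overwrite does not.
    torch::Tensor alloc_output(int64_t m, int64_t n, bool use_atomic_add,
                               const torch::TensorOptions& options) {
      return use_atomic_add ? torch::zeros({m, n}, options)
                            : torch::empty({m, n}, options);
    }
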
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;
 
   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }