@@ -538,6 +538,7 @@ __global__ void Marlin(
     int prob_n,           // output dimension n
     int prob_k,           // reduction dimension k
     int* locks,           // extra global storage for barrier synchronization
+    bool use_atomic_add,  // whether to use atomic add to reduce
     bool use_fp32_reduce  // whether to use fp32 global reduce
 ) {
   // Each threadblock processes one "stripe" of the B matrix with (roughly) the
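The new `use_atomic_add` flag selects how the partial results of the blocks in a slice are combined: instead of the existing lock-guarded global reduce, each block can accumulate its partial tile directly into C with hardware atomics. A minimal sketch of the idea (illustrative names, plain float for simplicity; C must be zero-initialized by the host):

    // Each split-k block adds its partial sums straight into the output;
    // no inter-block locking is required, only atomicAdd.
    __global__ void splitk_accumulate(const float* partial, float* C, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) atomicAdd(&C[i], partial[i]);
    }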
@@ -1542,7 +1543,17 @@ __global__ void Marlin(
          i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
          i++) {
       if (c_gl_wr < c_gl_wr_end) {
-        C[c_gl_wr] = sh_red[c_sh_rd];
+        if (use_atomic_add && slice_count > 1) {
+          scalar_t2* C_half2 = reinterpret_cast<scalar_t2*>(&C[c_gl_wr]);
+          scalar_t2* sh_red_half2 =
+              reinterpret_cast<scalar_t2*>(&sh_red[c_sh_rd]);
+  #pragma unroll
+          for (int a = 0; a < 4; a++) {
+            atomicAdd(&C_half2[a], sh_red_half2[a]);
+          }
+        } else {
+          C[c_gl_wr] = sh_red[c_sh_rd];
+        }
         c_gl_wr += c_gl_wr_delta;
         c_sh_rd += c_sh_rd_delta;
       }
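The non-atomic branch above stores one 16-byte `int4` per iteration; since CUDA has no 16-byte atomic add, the atomic branch reinterprets the same 16 bytes as four `scalar_t2` pairs (half2 or bfloat162) and issues one atomicAdd per pair. Note that `atomicAdd` on `__half2` requires compute capability 6.x or newer, and on `__nv_bfloat162` it requires 8.0 or newer. A self-contained sketch of the same decomposition (illustrative, fp16 only):

    #include <cuda_fp16.h>

    // Decompose a 16-byte vector store into four 4-byte half2 atomics,
    // mirroring the kernel's atomic write path (requires sm_60+).
    __global__ void atomic_store16(int4* C, const int4* src) {
      __half2* dst2 = reinterpret_cast<__half2*>(&C[threadIdx.x]);
      const __half2* src2 = reinterpret_cast<const __half2*>(&src[threadIdx.x]);
    #pragma unroll
      for (int a = 0; a < 4; a++) atomicAdd(&dst2[a], src2[a]);
    }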
@@ -1644,7 +1655,7 @@ __global__ void Marlin(
         }
         cp_async_fence();
       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           if (s_sh_wr_pred) {
             cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
           }
@@ -1664,7 +1675,7 @@ __global__ void Marlin(
         }
 
       } else {
-        if (last) {
+        if (last || use_atomic_add) {
           cp_async_wait<0>();
           __syncthreads();
           if (threadIdx.x / 32 < thread_n_blocks / 4) {
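Both scale-loading sites now fetch the group scales when `use_atomic_add` is set, not only in the last block of a slice: under atomic accumulation every block writes scaled output. This is correct because scaling distributes over the partial sums, as this sketch illustrates (hypothetical names, not the kernel's code):

    // s * (p0 + p1) == s * p0 + s * p1, so each block may apply the
    // dequantization scale to its own partial tile before atomicAdd.
    __device__ void scale_then_accumulate(float* C, const float* partial,
                                          const float* s, int n) {
      for (int i = threadIdx.x; i < n; i += blockDim.x)
        atomicAdd(&C[i], s[i] * partial[i]);
    }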
@@ -1703,8 +1714,8 @@ __global__ void Marlin(
       }
     }
 
-    if (slice_count > 1) {  // only globally reduce if there is more than one
-                            // block in a slice
+    if (slice_count > 1 && !use_atomic_add) {
+      // only globally reduce if there is more than one block in a slice
       barrier_acquire(&locks[slice_col], slice_idx);
       if (use_fp32_reduce) {
         global_reduce_fp32(slice_idx == 0, last);
@@ -1713,7 +1724,8 @@ __global__ void Marlin(
       }
       barrier_release(&locks[slice_col], last);
     }
-    if (last)  // only the last block in a slice actually writes the result
+    if (last || use_atomic_add)
+      // only the last block in a slice actually writes the result
       write_result();
     slice_row = 0;
     slice_col_par++;
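With atomic accumulation the whole barrier_acquire / global_reduce / barrier_release sequence is bypassed, and `write_result()` runs in every block of the slice rather than only the last one. The serialization this removes looks roughly like the following simplified sketch (the kernel's real barriers use PTX acquire/release loads; names here are illustrative):

    // Blocks of one slice take turns: block i spins until i predecessor
    // blocks have released the lock, merges its result, then signals.
    __device__ void barrier_acquire_sketch(int* lock, int count) {
      if (threadIdx.x == 0)
        while (atomicAdd(lock, 0) != count) {}  // atomic spin-read
      __syncthreads();
    }
    __device__ void barrier_release_sketch(int* lock) {
      __syncthreads();
      if (threadIdx.x == 0) atomicAdd(lock, 1);  // admit the next block
    }

The trade-off: the atomic path avoids this serialization and the extra C_tmp traffic, at the cost of a nondeterministic summation order and half-precision accumulation (the optional fp32 global reduce never runs when `use_atomic_add` is set).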
@@ -1768,7 +1780,8 @@ __global__ void Marlin(
              HAS_ZP, GROUP_BLOCKS, IS_ZP_FLOAT>                              \
           <<<blocks, NUM_THREADS, max_shared_mem, stream>>>(                 \
               A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr,      \
-              num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce);   \
+              num_groups, prob_m, prob_n, prob_k, locks, use_atomic_add,     \
+              use_fp32_reduce);                                              \
     }                                                                        \
   }
@@ -2062,7 +2075,8 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
                vllm::ScalarType const& q_type, bool has_act_order,
                bool is_k_full, bool has_zp, int num_groups, int group_size,
                int dev, cudaStream_t stream, int thread_k, int thread_n,
-               int sms, int max_par, bool use_fp32_reduce, bool is_zp_float) {
+               int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce,
+               bool is_zp_float) {
   if (has_zp) {
     TORCH_CHECK(
         q_type == vllm::kU4 || q_type == vllm::kU8,
@@ -2243,7 +2257,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
                                torch::Tensor& workspace,
                                vllm::ScalarTypeId const& b_q_type_id,
                                int64_t size_m, int64_t size_n, int64_t size_k,
-                               bool is_k_full, bool has_zp,
+                               bool is_k_full, bool has_zp, bool use_atomic_add,
                                bool use_fp32_reduce, bool is_zp_float) {
   vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id);
   if (has_zp) {
@@ -2306,19 +2320,34 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Alloc buffers
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
   auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
-  torch::Tensor c = torch::empty({size_m, size_n}, options);
-  torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
+  torch::Tensor c;
+  if (use_atomic_add) {
+    c = torch::zeros({size_m, size_n}, options);
+  } else {
+    c = torch::empty({size_m, size_n}, options);
+  }
+
+  torch::Tensor a_tmp;
+  bool has_act_order = g_idx.size(0) != 0;
+  if (has_act_order) {
+    a_tmp = torch::empty({size_m, size_k}, options);
+  } else {
+    a_tmp = torch::empty({0}, options);
+  }
 
   // Alloc C tmp buffer that is going to be used for the global reduce
+  torch::Tensor c_tmp;
   int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
   int reduce_n = size_n;
   auto options_fp32 =
       torch::TensorOptions().dtype(at::kFloat).device(a.device());
-  if (!use_fp32_reduce) {
+  if (use_fp32_reduce) {
+    c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
+  } else {
     reduce_max_m = 0;
     reduce_n = 0;
+    c_tmp = torch::empty({0}, options_fp32);
   }
-  torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
 
   // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
   // auto -1)
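Three allocation changes follow from the new flag. C must start zeroed when it is accumulated into atomically, a_tmp is only materialized when act_order needs the permuted activations, and c_tmp is only materialized when the fp32 global reduce will actually use it (note that `has_act_order` is now computed here, which is why its original definition is removed below). The C requirement is the essential one; restated compactly (PyTorch C++ API):

    // atomicAdd sums into whatever C already holds, so the atomic path
    // must start from zeros; the overwrite path can skip initialization.
    torch::Tensor c = use_atomic_add
                          ? torch::zeros({size_m, size_n}, options)
                          : torch::empty({size_m, size_n}, options);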
@@ -2339,7 +2368,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Detect groupsize and act_order
   int num_groups = -1;
   int group_size = -1;
-  bool has_act_order = g_idx.size(0) != 0;
 
   int rank = b_scales.sizes().size();
   TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
@@ -2407,7 +2435,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else if (a.scalar_type() == at::ScalarType::BFloat16) {
     marlin::marlin_mm<nv_bfloat16>(
         a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
@@ -2416,7 +2445,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
         perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
         workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp,
         num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
-        thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce, is_zp_float);
+        thread_k, thread_n, sms, marlin::max_par, use_atomic_add,
+        use_fp32_reduce, is_zp_float);
   } else {
     TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
   }