 #include "ggml.h"
 #include "topk-moe.cuh"
 
+#include <initializer_list>
+
 /*
     This kernel does the following:
     1. softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
+    4. optionally normalize the weights
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <size_t n_experts>
+template <size_t n_experts, bool with_norm>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float *       weights,
                                                                   int32_t *     ids,
@@ -68,6 +71,11 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     // we do the argmax reduce over n_expert_used, each time marking
     // the expert weight as -inf to exclude from the next iteration
 
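+    // stage the selected top-k weights in shared memory and accumulate their
+    // running sum so they can optionally be normalized before the final write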
+    float wt_sum = 0.f;
+
+    extern __shared__ float data_topk_shared[];
+    float *                 wt_shared_ptr = data_topk_shared + row * n_expert_used;
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -94,12 +102,33 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            weights[k] = max_val;
-            ids[k]     = max_expert;
+            wt_shared_ptr[k] = max_val;
+            ids[k]           = max_expert;
+            if constexpr (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
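+    // with normalization enabled: reduce the per-thread partial sums across the
+    // warp and rescale the staged weights so the selected top-k weights sum to 1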
+    if constexpr (with_norm) {
+        wt_sum              = warp_reduce_sum(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        if (threadIdx.x == 0) {
+            for (int i = 0; i < n_expert_used; i++) {
+                wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            }
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        for (int i = 0; i < n_expert_used; i++) {
+            weights[i] = wt_shared_ptr[i];
         }
     }
 }
 
+template <bool with_norm>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
@@ -112,36 +141,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
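+    // one float of shared memory per selected expert for every row handled by the block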
+    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
+
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -152,7 +193,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
-                           ggml_tensor *               ids) {
+                           ggml_tensor *               ids,
+                           const bool                  with_norm) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -170,7 +212,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     const int n_expert_used = weights->ne[1];
 
-    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
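+    // dispatch on with_norm so the normalization path is selected at compile time via the template parameter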
+    if (with_norm) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
 }
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
@@ -201,3 +247,17 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
 
     return true;
 }
+
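+// the op sequences that make up the softmax->top-k->get_rows pattern this fusion replaces,
+// with the trailing sum_rows->div->reshape chain included when the weights are normalized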
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+
+    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
+
+    if (norm) {
+        return norm_ops;
+    }
+    return no_norm_ops;
+}