@@ -63,7 +63,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                   float *   weights,
                                                   int32_t * ids,
                                                   const int n_rows,
-                                                  const int n_expert_used) {
+                                                  const int n_expert_used,
+                                                  const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +140,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;
 
         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
@@ -157,6 +159,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }
 
 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +172,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
-                                 const int n_expert_used) {
+                                 const int n_expert_used,
+                                 const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +183,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -226,7 +232,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,13 +249,19 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
 
     const int n_expert_used = weights->ne[1];
 
+    float clamp_val = 0.0f;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }
@@ -285,7 +298,8 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                              GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                             GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                 GGML_OP_VIEW, GGML_OP_GET_ROWS };
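
For reference, below the patch: a minimal host-side C++ sketch (not part of this change; norm_topk_weights, the 1e-6f minimum, and the toy inputs are illustrative) of what the normalized path now computes per row. The selected expert weights are summed, the sum is clamped from below by the min parameter of the fused GGML_OP_CLAMP node (read on the host via ggml_get_op_params_f32(clamp, 0) and passed to the kernel as clamp_val), and each weight is divided by the clamped sum, so a near-zero weight sum can no longer produce inf/nan weights.

#include <algorithm>
#include <cstdio>
#include <vector>

// Mirrors the kernel's new normalization step:
//     wt_sum = max(wt_sum, clamp_val);
//     const float inv_sum = 1.0f / wt_sum;
// clamp_val stands in for the CLAMP node's min parameter; 1e-6f below is illustrative.
static void norm_topk_weights(std::vector<float> & w, float clamp_val) {
    float wt_sum = 0.0f;
    for (float v : w) {
        wt_sum += v;
    }
    wt_sum = std::max(wt_sum, clamp_val);  // lower-bound the sum before dividing
    const float inv_sum = 1.0f / wt_sum;   // finite even if the raw sum is ~0
    for (float & v : w) {
        v *= inv_sum;
    }
}

int main() {
    std::vector<float> w = { 1e-9f, 2e-9f };  // pathologically small top-k weights
    norm_topk_weights(w, 1e-6f);
    std::printf("%g %g\n", w[0], w[1]);       // bounded values instead of inf/nan
    return 0;
}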