Commit 161248c

CUDA: support topk_moe with weight clamp
1 parent 32a3f26

File tree

4 files changed: +44 -27 lines

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/topk-moe.cu
ggml/src/ggml-cuda/topk-moe.cuh
tests/test-backend-ops.cpp


ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 9 additions & 8 deletions
@@ -2828,7 +2828,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops_with_norm.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops_with_norm, { node_idx + 3, node_idx + 8 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 9];

         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
@@ -2838,7 +2838,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     if (ops.size() == topk_moe_ops.size() &&
         ggml_can_fuse_subgraph(cgraph, node_idx, topk_moe_ops, { node_idx + 3, node_idx + 4 })) {
         ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx+4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
         if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
             return true;
         }
@@ -2945,17 +2945,18 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         if (!disable_fusion) {

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+8];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 9];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
+                ggml_tensor * clamp = cgraph->nodes[i + 7];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
-                                      /*delayed softmax*/ false);
-                i += 8;
+                                      /*delayed softmax*/ false, clamp);
+                i += 9;
                 continue;
             }

             if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-                ggml_tensor * weights = cgraph->nodes[i+4];
-                ggml_tensor * selected_experts = cgraph->nodes[i+3];
+                ggml_tensor * weights = cgraph->nodes[i + 4];
+                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
                 ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
                                       /*delayed softmax*/ false);
                 i += 4;

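For orientation, the index changes above follow directly from the op list in ggml_cuda_topk_moe_ops (see the topk-moe.cu hunks below): the new GGML_OP_CLAMP sits between GGML_OP_SUM_ROWS and GGML_OP_DIV, which shifts the final weights node from +8 to +9. A sketch of the with-norm subgraph layout, annotated with the offsets the fusion code reads:

// Offsets into cgraph->nodes, relative to the GGML_OP_SOFT_MAX node at node_idx:
//   +0 GGML_OP_SOFT_MAX    expert probabilities
//   +1 GGML_OP_RESHAPE
//   +2 GGML_OP_ARGSORT
//   +3 GGML_OP_VIEW        -> selected_experts (top-k ids)
//   +4 GGML_OP_GET_ROWS
//   +5 GGML_OP_RESHAPE
//   +6 GGML_OP_SUM_ROWS
//   +7 GGML_OP_CLAMP       -> clamp (carries the clamp bounds as op params)
//   +8 GGML_OP_DIV
//   +9 GGML_OP_RESHAPE     -> weights (normalized output)
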
ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 32 additions & 18 deletions
@@ -63,7 +63,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
                                                                   float * weights,
                                                                   int32_t * ids,
                                                                   const int n_rows,
-                                                                  const int n_expert_used) {
+                                                                  const int n_expert_used,
+                                                                  const float clamp_val) {
     const int row = blockIdx.x * blockDim.y + threadIdx.y;
     if (row >= n_rows) {
         return;
@@ -139,6 +140,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *

     if constexpr (with_norm) {
         wt_sum = warp_reduce_sum(wt_sum);
+        wt_sum = max(wt_sum, clamp_val);
         const float inv_sum = 1.0f / wt_sum;

         for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
@@ -157,6 +159,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
             weights[idx] = output_weights[i];
         }
     }
+
+    if (!with_norm) {
+        GGML_UNUSED(clamp_val);
+    }
 }

 template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +172,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  int32_t * ids,
                                  const int n_rows,
                                  const int n_expert,
-                                 const int n_expert_used) {
+                                 const int n_expert_used,
+                                 const float clamp_val) {
     static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
-
     const int rows_per_block = 4;
     dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +183,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     switch (n_expert) {
         case 1:
             topk_moe_cuda<1, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 2:
             topk_moe_cuda<2, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 4:
             topk_moe_cuda<4, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 8:
             topk_moe_cuda<8, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 16:
             topk_moe_cuda<16, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 32:
             topk_moe_cuda<32, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 64:
             topk_moe_cuda<64, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 128:
             topk_moe_cuda<128, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 256:
             topk_moe_cuda<256, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         case 512:
             topk_moe_cuda<512, with_norm, delayed_softmax>
-                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
             break;
         default:
             GGML_ASSERT(false && "fatal error");
@@ -226,7 +232,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax) {
+                           const bool delayed_softmax,
+                           ggml_tensor * clamp) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,13 +249,19 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,

     const int n_expert_used = weights->ne[1];

+    float clamp_val = 0.0f;
     if (with_norm) {
-        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (clamp) {
+            clamp_val = ggml_get_op_params_f32(clamp, 0);
+        }
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
     } else {
         if (delayed_softmax) {
-            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                              clamp_val);
         } else {
-            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
+                                               clamp_val);
         }
     }
 }
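
A note on the ggml_get_op_params_f32(clamp, 0) call above: GGML_OP_CLAMP stores its bounds as two float op params, min at index 0 and max at index 1, so the fused path applies only the lower bound. That matches the graph this fusion targets, where the upper bound is INFINITY (see the test change below). A hypothetical sketch of reading both bounds:

// Sketch, not code from this commit: both clamp bounds live in op_params.
const float clamp_min = ggml_get_op_params_f32(clamp, 0); // used as clamp_val
const float clamp_max = ggml_get_op_params_f32(clamp, 1); // INFINITY in the fused graph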
@@ -285,7 +298,8 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
 std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
+                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
+                                                            GGML_OP_RESHAPE };

     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };

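Net effect of the kernel change: the normalized branch divides by max(wt_sum, clamp_val) instead of wt_sum, so a near-zero row sum can no longer blow up the output weights. A minimal scalar reference of that branch, as a sketch rather than code from this commit:

#include <algorithm>

// Reference for the with_norm path of topk_moe_cuda: divide each kept weight
// by the clamped sum, mirroring wt_sum = max(wt_sum, clamp_val) in the kernel.
static void topk_moe_norm_ref(const float * topk_weights, float * out,
                              int n_expert_used, float clamp_val) {
    float wt_sum = 0.0f;
    for (int i = 0; i < n_expert_used; ++i) {
        wt_sum += topk_weights[i];
    }
    wt_sum = std::max(wt_sum, clamp_val); // guard against a tiny denominator
    const float inv_sum = 1.0f / wt_sum;
    for (int i = 0; i < n_expert_used; ++i) {
        out[i] = topk_weights[i] * inv_sum;
    }
}
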
ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 2 additions & 1 deletion
@@ -8,7 +8,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            ggml_tensor * weights,
                            ggml_tensor * ids,
                            const bool with_norm,
-                           const bool delayed_softmax = false);
+                           const bool delayed_softmax = false,
+                           ggml_tensor * weight_clamp = nullptr);

 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

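Because the new parameter defaults to nullptr, existing callers compile unchanged and keep the old behavior (clamp_val stays 0.0f, and max(wt_sum, 0.0f) is a no-op for softmax sums). A hypothetical pair of call sites, where clamp_node stands in for the GGML_OP_CLAMP tensor found at i + 7:

// Unchanged caller: no clamp node, nothing new fused.
ggml_cuda_op_topk_moe(ctx, logits, weights, ids, /*with_norm=*/false);

// Fused caller: pass the GGML_OP_CLAMP node so its lower bound reaches the kernel.
ggml_cuda_op_topk_moe(ctx, logits, weights, ids, /*with_norm=*/true,
                      /*delayed_softmax=*/false, /*weight_clamp=*/clamp_node);
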
tests/test-backend-ops.cpp

Lines changed: 1 addition & 0 deletions
@@ -4712,6 +4712,7 @@ struct test_topk_moe: public test_case {
         out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
         ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]

+        weights_sum = ggml_clamp(ctx, weights_sum, 6.103515625e-5, INFINITY);
         out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
         out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
     }

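The clamp floor used in the test, 6.103515625e-5, is exactly 2^-14, the smallest positive normal float16 value, presumably chosen so the normalization stays stable even when the row sum would underflow half precision. A quick standalone check, as a sketch rather than part of the test suite:

#include <cassert>
#include <cmath>

int main() {
    // 6.103515625e-5 is a dyadic rational equal to 2^-14, the smallest
    // positive normal value of IEEE-754 half precision.
    assert(6.103515625e-5 == std::ldexp(1.0, -14));
    return 0;
}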