Commit 240b2c1

Add optional norm + clean-up code
1 parent 2930668 commit 240b2c1

File tree

4 files changed: +127 −30 lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 28 additions & 7 deletions
@@ -2826,12 +2826,25 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     GGML_ASSERT(unary_ops.size() == num_unary);
 #endif
 
-    //special case for topk-moe
-    if (ops.size() == 5 && ops.begin()[0] == GGML_OP_SOFT_MAX && ops.begin()[1] == GGML_OP_RESHAPE && ops.begin()[2] == GGML_OP_ARGSORT
-        && ops.begin()[3] == GGML_OP_VIEW && ops.begin()[4] == GGML_OP_GET_ROWS) {
+    //TODO: remove special case once ggml_can_fuse can handle empty nodes
+    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
 
-        for (int i = 0; i < 5; i++) {
-            if (cgraph->nodes[node_idx + i]->op != ops.begin()[i]) return false;
+    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
+        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
+        }
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx+8];
+
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
+        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
         }
 
         ggml_tensor * softmax = cgraph->nodes[node_idx];

@@ -2931,11 +2944,19 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
     if (!disable_fusion) {
 
-        if (ggml_cuda_can_fuse(cgraph, i, {GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS}, {})) {
+        if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
+            ggml_tensor * weights          = cgraph->nodes[i+8];
+            ggml_tensor * selected_experts = cgraph->nodes[i+3];
+            ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+            i += 8;
+            continue;
+        }
+
+        if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
 
             ggml_tensor * weights          = cgraph->nodes[i+4];
             ggml_tensor * selected_experts = cgraph->nodes[i+3];
-            ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts);
+            ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
             i += 4;
             continue;
         }
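
The fusion check above compares a fixed-length window of graph nodes against the op pattern returned by ggml_cuda_topk_moe_ops, trying the longer with-norm pattern first. A minimal standalone sketch of that window-matching idea (names here are hypothetical; the real check lives in ggml_cuda_can_fuse and additionally validates the tensors via ggml_cuda_should_use_topk_moe):

    // Sketch: match a window of graph ops against a fusion pattern.
    #include <algorithm>
    #include <cstddef>
    #include <initializer_list>
    #include <vector>

    enum demo_op { OP_SOFT_MAX, OP_RESHAPE, OP_ARGSORT, OP_VIEW, OP_GET_ROWS };

    // True when pattern occupies nodes[start .. start + pattern.size()).
    static bool window_matches(const std::vector<demo_op> & nodes, size_t start,
                               std::initializer_list<demo_op> pattern) {
        if (start + pattern.size() > nodes.size()) {
            return false; // window would run past the end of the graph
        }
        return std::equal(pattern.begin(), pattern.end(), nodes.begin() + start);
    }

On a match, evaluate_and_capture_cuda_graph runs the fused kernel once and skips the nodes it replaced: i += 8 for the 9-op with-norm pattern, i += 4 for the 5-op plain one.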

ggml/src/ggml-cuda/topk-moe.cu

Lines changed: 75 additions & 15 deletions
@@ -2,15 +2,18 @@
 #include "ggml.h"
 #include "topk-moe.cuh"
 
+#include <initializer_list>
+
 /*
     This kernel does the following:
     1. softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
+    4. optionally normalize the weights
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <size_t n_experts>
+template <size_t n_experts, bool with_norm>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float * weights,
                                                                   int32_t * ids,

@@ -68,6 +71,11 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     //we do the argmax reduce over n_expert_used, each time marking
     //the expert weight as -inf to exclude from the next iteration
 
+    float wt_sum = 0.f;
+
+    extern __shared__ float data_topk_shared[];
+    float * wt_shared_ptr = data_topk_shared + row * n_expert_used;
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;

@@ -94,12 +102,33 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
             wt[max_expert / WARP_SIZE] = -INFINITY;
 
-            weights[k] = max_val;
-            ids[k]     = max_expert;
+            wt_shared_ptr[k] = max_val;
+            ids[k]           = max_expert;
+            if constexpr (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if constexpr (with_norm) {
+        wt_sum              = warp_reduce_sum(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        if (threadIdx.x == 0) {
+            for (int i = 0; i < n_expert_used; i++) {
+                wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
+            }
+        }
+    }
+
+    if (threadIdx.x == 0) {
+        for (int i = 0; i < n_expert_used; i++) {
+            weights[i] = wt_shared_ptr[i];
         }
     }
 }
 
+template <bool with_norm>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float * logits,
                                  float * weights,

@@ -112,36 +141,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
     dim3 block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
+    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);
+
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<1, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<2, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<4, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<8, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 16:
-            topk_moe_cuda<16><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<16, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 32:
-            topk_moe_cuda<32><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<32, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 64:
-            topk_moe_cuda<64><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<64, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 128:
-            topk_moe_cuda<128><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<128, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 256:
-            topk_moe_cuda<256><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<256, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        case 512:
-            topk_moe_cuda<512><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
+            topk_moe_cuda<512, with_norm>
+                <<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
        default:
            GGML_ASSERT(false && "fatal error");

@@ -152,7 +193,8 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * ids) {
+                           ggml_tensor * ids,
+                           const bool with_norm) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);

@@ -170,7 +212,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     const int n_expert_used = weights->ne[1];
 
-    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    if (with_norm) {
+        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    } else {
+        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+    }
 }
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {

@@ -201,3 +247,17 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights) {
 
     return true;
 }
+
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+    static std::initializer_list<enum ggml_op> norm_ops    = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                               GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+
+    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
+                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+
+    if (norm) {
+        return norm_ops;
+    }
+    return no_norm_ops;
+}
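
Per token, the kernel computes a softmax over the expert logits, iteratively argmax-reduces to pick the top n_expert_used weights (masking each winner with -INFINITY), and with with_norm set rescales the selected weights by their sum. A CPU reference sketch of the same computation for a single token (plain C++, not the CUDA code; topk_moe_reference is a name made up here):

    // CPU reference for one token: softmax -> iterative top-k -> optional renorm.
    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    static void topk_moe_reference(const float * logits, float * weights, int32_t * ids,
                                   int n_expert, int n_expert_used, bool with_norm) {
        // numerically stable softmax over the expert logits
        const float max_logit = *std::max_element(logits, logits + n_expert);
        std::vector<float> p(n_expert);
        float sum = 0.0f;
        for (int i = 0; i < n_expert; i++) {
            p[i] = std::exp(logits[i] - max_logit);
            sum += p[i];
        }
        for (int i = 0; i < n_expert; i++) {
            p[i] /= sum;
        }

        // argmax reduce over n_expert_used, marking each winner as -inf
        // to exclude it from the next iteration (mirrors the kernel's loop)
        float wt_sum = 0.0f;
        for (int k = 0; k < n_expert_used; k++) {
            const int best = (int)(std::max_element(p.begin(), p.end()) - p.begin());
            weights[k] = p[best];
            ids[k]     = (int32_t)best;
            wt_sum    += p[best];
            p[best]    = -INFINITY;
        }

        // optionally normalize so the selected weights sum to 1
        if (with_norm) {
            for (int k = 0; k < n_expert_used; k++) {
                weights[k] /= wt_sum;
            }
        }
    }

The shared-memory staging (wt_shared_ptr) exists because the normalized path needs all selected weights before any can be written out, which is why the launch now reserves n_expert_used * rows_per_block floats of dynamic shared memory per block.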

ggml/src/ggml-cuda/topk-moe.cuh

Lines changed: 7 additions & 1 deletion
@@ -1,8 +1,14 @@
 #include "common.cuh"
+#include "ggml.h"
+
+#include <initializer_list>
 
 void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor * logits,
                            ggml_tensor * weights,
-                           ggml_tensor * top_k);
+                           ggml_tensor * top_k,
+                           const bool with_norm);
 
 bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);
+
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
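
The header takes with_norm as a runtime bool, which topk-moe.cu immediately folds into the compile-time template parameter via the if/else around launch_topk_moe_cuda<true>/<false>. A minimal sketch of that runtime-to-template dispatch pattern (demo names, not the ggml API):

    // Sketch: turn a runtime flag into a template parameter, so the
    // kernel-side branch is resolved at compile time via if constexpr.
    #include <cstdio>

    template <bool with_norm>
    static void launch_demo() {
        if constexpr (with_norm) {
            std::printf("normalized path\n"); // compiled only into <true>
        } else {
            std::printf("plain path\n");      // compiled only into <false>
        }
    }

    static void dispatch_demo(bool with_norm) {
        // one runtime branch selects one of two fully specialized instantiations
        if (with_norm) {
            launch_demo<true>();
        } else {
            launch_demo<false>();
        }
    }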

tests/test-backend-ops.cpp

Lines changed: 17 additions & 7 deletions
@@ -4421,13 +4421,14 @@ struct test_argsort : public test_case {
 struct test_topk_moe: public test_case {
     const std::array<int64_t, 4> ne;
     const int n_expert_used;
-    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1)
-        : ne(ne), n_expert_used(n_expert_used) {
+    const bool with_norm;
+    test_topk_moe(std::array<int64_t, 4> ne = {10, 5, 1, 1}, int n_expert_used = 1, bool with_norm = false)
+        : ne(ne), n_expert_used(n_expert_used), with_norm(with_norm) {
         GGML_ASSERT(n_expert_used <= ne[0]);
     }
 
     std::string vars() override {
-        return VARS_TO_STR2(ne, n_expert_used);
+        return VARS_TO_STR3(ne, n_expert_used, with_norm);
     }
 
     std::string op_desc(ggml_tensor * t) override {

@@ -4447,6 +4448,14 @@ struct test_topk_moe: public test_case {
     ggml_tensor * out = ggml_get_rows(ctx, ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
 
+    if (with_norm) {
+        out = ggml_reshape_2d(ctx, out, n_expert_used, n_tokens);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx, out); // [1, n_tokens]
+
+        out = ggml_div(ctx, out, weights_sum); // [n_expert_used, n_tokens]
+        out = ggml_reshape_3d(ctx, out, 1, n_expert_used, n_tokens);
+    }
+
     ggml_set_name(out, "out");
     return out;
 }

@@ -6622,10 +6631,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
     test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, {10, 5, 4, 3}));
 
-
-    test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4));
-    test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8));
-    test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128));
+    for (bool with_norm : {false, true}) {
+        test_cases.emplace_back(new test_topk_moe({8, 22, 1, 1}, 4, with_norm));
+        test_cases.emplace_back(new test_topk_moe({32, 22, 1, 1}, 8, with_norm));
+        test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm));
+    }
 
 #if 0
     // these tests are disabled to save execution time, but they can be handy for debugging
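
The with_norm reference graph built by the test appends exactly the reshape -> sum_rows -> div -> reshape tail that ggml_cuda_topk_moe_ops(true) lists, keeping the fused CUDA path and the unfused reference comparable. As a quick sanity check of the normalization semantics, a small driver reusing the topk_moe_reference sketch from above (assumed in scope; values are arbitrary):

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    int main() {
        const int n_expert = 8, n_expert_used = 4;
        std::vector<float>   logits = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, -0.3f, 0.0f, 0.9f};
        std::vector<float>   weights(n_expert_used);
        std::vector<int32_t> ids(n_expert_used);

        topk_moe_reference(logits.data(), weights.data(), ids.data(),
                           n_expert, n_expert_used, /*with_norm=*/true);

        float sum = 0.0f;
        for (float w : weights) sum += w;
        assert(std::fabs(sum - 1.0f) < 1e-5f); // normalized weights sum to 1
        return 0;
    }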
