
Commit e772b28

Use smem for final write
1 parent 240b2c1 · commit e772b28

3 files changed: 5 additions, 5 deletions

ggml/src/ggml-cuda/ggml-cuda.cu (0 additions, 1 deletion)

```diff
@@ -2953,7 +2953,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 }
 
                 if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
-
                     ggml_tensor * weights = cgraph->nodes[i+4];
                     ggml_tensor * selected_experts = cgraph->nodes[i+3];
                     ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
```

ggml/src/ggml-cuda/topk-moe.cu (2 additions, 2 deletions)

```diff
@@ -74,7 +74,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     float wt_sum = 0.f;
 
     extern __shared__ float data_topk_shared[];
-    float * wt_shared_ptr = data_topk_shared + row * n_expert_used;
+    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;
 
     for (int k = 0; k < n_expert_used; k++) {
         float max_val = wt[0];
```
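The indexing fix above matters because `data_topk_shared` is dynamic shared memory, which is allocated per block: with the `(WARP_SIZE, 4)` block shape implied by `__launch_bounds__(4 * WARP_SIZE, 1)`, the buffer only holds four `n_expert_used`-sized slices, so the offset has to be the warp's index within the block (`threadIdx.y`), not the global row. A minimal sketch of the staging pattern, with illustrative names rather than the real kernel:

```cuda
// Sketch only: one warp per row, four warps per block, dynamic shared memory
// sized for the block (4 * n_expert_used floats). Names are illustrative.
#define WARP_SIZE 32

__global__ void staged_row_write(float * dst, const int n_expert_used) {
    extern __shared__ float smem[];

    const int row = blockIdx.x * blockDim.y + threadIdx.y;  // global row this warp owns

    // Correct: offset by the warp's index within the block.
    float * slice = smem + threadIdx.y * n_expert_used;
    // Offsetting by `row` instead would walk past the block's shared-memory
    // allocation for every block with blockIdx.x > 0.

    // Stage: e.g. one value per iteration of a selection loop (lane 0 writes).
    for (int k = 0; k < n_expert_used; k++) {
        if (threadIdx.x == 0) {
            slice[k] = (float) k;  // placeholder for the k-th selected weight
        }
    }
    __syncwarp();

    // Final write: all lanes of the warp flush the staged slice to global memory.
    for (int k = threadIdx.x; k < n_expert_used; k += WARP_SIZE) {
        dst[row * n_expert_used + k] = slice[k];
    }
}
```

Such a kernel would be launched as `staged_row_write<<<n_rows / 4, dim3(WARP_SIZE, 4), 4 * n_expert_used * sizeof(float)>>>(dst, n_expert_used)`, assuming `n_rows` is a multiple of 4.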
```diff
@@ -83,7 +83,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 #pragma unroll
         for (int i = 1; i < experts_per_thread; i++) {
             const int expert = threadIdx.x + i * WARP_SIZE;
-            if (expert < n_experts && wt[i] > max_val) {
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                 max_val = wt[i];
                 max_expert = expert;
             }
```
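The second hunk rewrites the bounds check so it can disappear at compile time: assuming `n_experts` is a compile-time constant here (the fully unrolled loop over `experts_per_thread` suggests it is), the compiler evaluates `n_experts % WARP_SIZE == 0`, and for expert counts that are multiples of the warp size the whole guard short-circuits to true and no per-iteration branch is emitted; for other counts the runtime `expert < n_experts` check still protects the tail. A hedged sketch of the pattern, using a hypothetical helper rather than the actual kernel:

```cuda
// Sketch of the guard pattern, assuming the expert count is a template
// parameter (hypothetical helper, not the actual topk_moe_cuda kernel).
#define WARP_SIZE 32

template <int n_experts>
__device__ float lane_max(const float * wt) {
    constexpr int experts_per_thread = (n_experts + WARP_SIZE - 1) / WARP_SIZE;

    float max_val = wt[0];
#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const int expert = threadIdx.x + i * WARP_SIZE;
        // If n_experts % WARP_SIZE == 0, the left operand is a compile-time
        // constant `true`, so the bounds check folds away in every unrolled
        // iteration; otherwise it still guards the final partial iteration.
        if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
            max_val = wt[i];
        }
    }
    return max_val;
}
```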

src/llama-graph.cpp (3 additions, 2 deletions)

```diff
@@ -929,8 +929,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
             ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
-    //call early so that softmax->topk->get_rows can be fused
-    ggml_build_forward_expand(gf, weights);
 
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
@@ -955,6 +953,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }
 
+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
```
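Moving `ggml_build_forward_expand(gf, weights)` below the softmax-weight scaling changes which nodes end up adjacent in the expanded graph: the CUDA backend's topk-moe path (see `ggml_cuda_can_fuse` in the first diff) matches a run of consecutive node ops starting at node `i` and picks up `cgraph->nodes[i+3]` / `nodes[i+4]` as the selected experts and weights, so expanding the graph before the scaling was applied would leave the scaled weights outside the run the matcher inspects. A rough sketch of that kind of sequence match, under the assumption that the internal `ggml_cgraph` layout is visible as it is inside the backend (this is not the real `ggml_cuda_can_fuse`):

```cuda
// Hypothetical helper, for illustration only: check that the ops of the next
// few graph nodes match an expected pattern before dispatching a fused kernel.
#include <vector>

#include "ggml.h"        // ggml_op
#include "ggml-impl.h"   // struct ggml_cgraph (internal layout, visible to backends)

static bool ops_match(const ggml_cgraph * cgraph, int i, const std::vector<ggml_op> & expected) {
    if (i + (int) expected.size() > cgraph->n_nodes) {
        return false;                                    // pattern would run past the graph
    }
    for (size_t k = 0; k < expected.size(); ++k) {
        if (cgraph->nodes[i + (int) k]->op != expected[k]) {
            return false;                                // sequence broken -> no fusion here
        }
    }
    return true;
}
```

The moved comment ("call early so that topk-moe can be used") reflects this: the expand call still happens early relative to the rest of the FFN, but only after every op the fused path is expected to cover has been recorded.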
