|
| 1 | +#include "topk-moe.cuh" |
| 2 | + |
/*
    This kernel does the following:
    1. softmax over the logits per token [n_experts, n_tokens]
    2. argmax reduce over the top-k (n_experts_used) logits
    3. write weights + ids to global memory

    It is intended as a fusion of the softmax->top-k->get_rows pipeline for MoE models
*/
/*
 * Fused softmax + top-k routing kernel for MoE.
 *
 * Launch contract (see launch_topk_moe_cuda): one warp per token row, i.e.
 * blockDim.x == 32, blockDim.y == rows per block, gridDim.x covers n_rows.
 * n_experts is a compile-time power of two (dispatched by the launcher), so
 * each of the 32 lanes owns a contiguous slice of n_experts/32 experts that
 * stays entirely in registers.
 */
template <size_t n_experts>
__global__ void topk_moe_cuda(const float * logits,
                              float * weights,
                              int32_t * ids,
                              const int n_rows,
                              const int n_expert_used) {
    // One warp handles one row (token); blockDim.y rows are packed per block.
    const int row = blockIdx.x * blockDim.y + threadIdx.y;
    if (row >= n_rows) {
        return;
    }
    // Advance all pointers to this row. NOTE(review): ids is strided by
    // n_experts (not n_expert_used) — presumably to match the layout of the
    // un-fused argsort output tensor; confirm against the caller's shapes.
    logits += n_experts * row;
    ids += n_experts * row;
    weights += n_expert_used * row;

    // Registers per lane for its slice of the row (1 when n_experts <= 32).
    constexpr int experts_per_thread = (n_experts > 32) ? n_experts / 32 : 1;

    const int start_expert = threadIdx.x * experts_per_thread;
    const int end_expert = (threadIdx.x + 1) * experts_per_thread;
    float max_val = -INFINITY;

    // Phase 1: row-wide max for numerically stable softmax. Lanes whose slice
    // extends past n_experts pad with -INFINITY so they cannot win.
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int expert = start_expert + i;
        const float val = (expert < n_experts) ? logits[expert] : -INFINITY;
        max_val = max(val, max_val);
    }

    max_val = warp_reduce_max(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

    // Phase 2: exponentiate and accumulate the softmax denominator.
    // Padding slots evaluate expf(-inf) == 0, so they do not bias the sum.
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int expert = start_expert + i;
        const float val = (expert < n_experts) ? logits[expert] : -INFINITY;
        wt[i] = expf(val - max_val);
        tmp += wt[i];
    }

    tmp = warp_reduce_sum(tmp);

    const float inv_sum = 1.0f / tmp;

    // Phase 3: normalize — wt[] now holds this lane's slice of the softmax.
#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    //at this point, each thread holds a portion of softmax,
    //we do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration

    for (int k = 0; k < n_expert_used; k++) {
        // Per-lane argmax over the lane's own slice (shadows the outer max_val).
        float max_val = wt[0];
        int max_expert = start_expert;

#pragma unroll
        for (int i = 1; i < experts_per_thread; i++) {
            const int expert = start_expert + i;
            if (wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

        // Warp-wide butterfly argmax via XOR shuffles: after log2(warpSize)
        // steps every lane holds the row-wide (max_val, max_expert) pair.
#pragma unroll
        for (int mask = warpSize / 2; mask > 0; mask /= 2) {
            const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, warpSize);
            const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, warpSize);
            if (val > max_val) {
                max_val = val;
                max_expert = expert;
            }
        }

        // Exactly one lane owns the winning expert: it retires the winner
        // (so it cannot be re-selected next iteration) and writes the output.
        if (max_expert >= start_expert && max_expert < end_expert) {
            wt[max_expert - start_expert] = -INFINITY;

            weights[k] = max_val;
            ids[k] = max_expert;
        }
    }
}
| 95 | + |
// Dispatches the fused top-k MoE kernel on the context's stream, selecting the
// compile-time expert count. Each block covers rows_per_block rows, one warp
// (32 lanes) per row. Only power-of-two expert counts up to 512 are supported;
// anything else is a hard failure.
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float * weights,
                                 int32_t * ids,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used) {
    constexpr int rows_per_block = 4;
    const dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    const dim3 block_dims(32, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    // The kernel needs n_expert as a template constant, hence the dispatch
    // table; the macro keeps the ten instantiations in one place.
    switch (n_expert) {
#define TOPK_MOE_LAUNCH_CASE(NE)                                                                                   \
        case NE:                                                                                                   \
            topk_moe_cuda<NE><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);  \
            break;
        TOPK_MOE_LAUNCH_CASE(1)
        TOPK_MOE_LAUNCH_CASE(2)
        TOPK_MOE_LAUNCH_CASE(4)
        TOPK_MOE_LAUNCH_CASE(8)
        TOPK_MOE_LAUNCH_CASE(16)
        TOPK_MOE_LAUNCH_CASE(32)
        TOPK_MOE_LAUNCH_CASE(64)
        TOPK_MOE_LAUNCH_CASE(128)
        TOPK_MOE_LAUNCH_CASE(256)
        TOPK_MOE_LAUNCH_CASE(512)
#undef TOPK_MOE_LAUNCH_CASE
        default:
            GGML_ASSERT(false && "fatal error");
            break;
    }
}
| 144 | + |
// ggml op entry point for the fused softmax->top-k->get_rows MoE routing.
// Validates tensor types, unpacks device pointers and shapes, and forwards to
// the templated launcher. Fix: removed a dead `cudaStream_t stream` local that
// was fetched from ctx but never used (the launcher derives the stream itself).
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           ggml_tensor * logits,
                           ggml_tensor * weights,
                           ggml_tensor * ids) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);

    // The fused op reads the raw logits from the softmax node's source tensor.
    const float * logits_d  = (const float *) logits->src[0]->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;

    const int n_experts     = logits->ne[0];   // experts per token (row length)
    const int n_rows        = logits->ne[1];   // number of tokens
    const int n_expert_used = weights->ne[1];  // top-k experts kept per token

    launch_topk_moe_cuda(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
}
0 commit comments