
```
deepseek_v32/
├── README.md                     # This file
├── figures/                      # Figures and diagrams
├── inference/                    # Inference implementation folder
├── fp8_lighting_indexer.py       # FP8 Lightning Indexer kernel
├── sparse_mla_fwd.py             # Sparse MLA forward implementation
├── sparse_mla_fwd_pipelined.py   # Pipelined implementation of sparse MLA forward pass
├── topk_selector.py              # Top-k selector implementation
```

## File Descriptions

### Architecture Overview

The architecture diagram (see `figures/`) highlights three key components (shown in green) that correspond to our kernel implementations:

1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implementation with the sparse MLA (Multi-head Latent Attention) forward pass

### Lightning Indexer

Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.
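
In the notation of the DeepSeek-V3.2 technical report, the index score between query token `t` and key token `s` is, roughly, a weighted sum over indexer heads of ReLU'd dot products (a paraphrase of the report's formula rather than code from this repo):

```math
I_{t,s} = \sum_{j} w^{I}_{t,j} \cdot \mathrm{ReLU}\left( q^{I}_{t,j} \cdot k^{I}_{s} \right)
```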

The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:

```python
T.gemm(
    index_k_shared,
    index_q_shared,
    s,
    transpose_B=True,
    clear_accum=True,
    policy=T.GemmWarpPolicy.FullCol,
)
```

After the matmul, we apply ReLU and aggregate across heads with learned weights:

```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
    s_reshaped[bn_i, bq_i, h_i] = (
        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
    ) * index_k_scale_fragment[bn_i]

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```
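
For intuition, here is a small NumPy reference of what these two steps compute together; array names and shapes are illustrative, and the tile layout is transposed relative to the kernel's:

```python
import numpy as np

def indexer_logits_reference(q, k, weights, k_scale):
    """Reference for one tile: q [block_Q, heads, dim], k [block_N, dim],
    weights [block_Q, heads], k_scale [block_N] -> logits [block_Q, block_N]."""
    # Similarity score for every (query, head, key) triple
    s = np.einsum("qhd,nd->qhn", q, k)
    # ReLU, weight each head, apply the per-key dequantization scale
    s = np.maximum(s, 0.0) * weights[:, :, None] * k_scale[None, None, :]
    # Aggregate over heads
    return s.sum(axis=1)
```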

The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:

```python
for bq_i in T.serial(block_Q):
    cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
    cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```

The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
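
Semantically, the bounds restrict which KV positions may contribute to each query row; the kernel realizes this by skipping whole key blocks, but the effect is equivalent to masking, sketched here in NumPy with hypothetical names:

```python
import numpy as np

def apply_kv_bounds(logits, cu_seqlen_ks, cu_seqlen_ke):
    """Mask logits outside each query token's [ks, ke) KV window.

    logits [seq_len, seq_len_kv], cu_seqlen_ks / cu_seqlen_ke [seq_len]."""
    kv_pos = np.arange(logits.shape[1])[None, :]
    valid = (kv_pos >= cu_seqlen_ks[:, None]) & (kv_pos < cu_seqlen_ke[:, None])
    return np.where(valid, logits, -np.inf)
```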

### Top-k Selector

The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.

The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:

```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
```
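
`convert_to_uint16` maps floats to `uint16` keys so that larger floats map to larger integers. A common way to build such an order-preserving key (shown here as a NumPy illustration of the idea, not the kernel's exact bit manipulation) is to flip the sign bit of non-negative floats and all bits of negative floats, then keep the high 16 bits:

```python
import numpy as np

def convert_to_uint16_reference(x):
    """Order-preserving float32 -> uint16 key (illustrative)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    # Negatives: flip all bits; non-negatives: set the sign bit.
    # Unsigned comparison of the keys then matches float comparison.
    keys = np.where(bits & 0x80000000, ~bits, bits | 0x80000000)
    return (keys >> 16).astype(np.uint16)

vals = np.float32([-3.5, -0.0, 0.0, 1.0, 2.5])
assert np.all(np.diff(convert_to_uint16_reference(vals).astype(np.int32)) >= 0)
```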

After building a histogram over these keys and taking a cumulative sum, we find the threshold bin, i.e. the bin at which the running count crosses the remaining top-k budget:

```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
    s_threshold_bin_id[0] = tx
```

Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:

```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
    index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
    s_input_idx[0, pos] = input_idx
```

Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
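
End to end, the selector's output matches a plain per-row top-k, up to tie-breaking and output ordering (which the kernel leaves unspecified). A NumPy reference for one query row, with hypothetical `start`/`end` bounds for the valid KV range:

```python
import numpy as np

def topk_reference(row_logits, start, end, k):
    """Indices of the k largest logits within [start, end) for one query row."""
    window = row_logits[start:end]
    # argpartition finds the k largest entries without fully sorting the row
    local = np.argpartition(window, -k)[-k:]
    return start + local
```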

### Sparse MLA Forward

The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.

Turning dense MLA into sparse MLA requires surprisingly few changes - essentially just modifying how we iterate and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:

```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
    T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
    # ... compute attention over this block
```

Sparse MLA only loads KV positions selected by the top-k selector:

```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    # ... compute attention over selected tokens
```

This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:

```python
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```

Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
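
For a shape-level picture of what the sparse gather buys, here is a NumPy sketch of one query step (single query group, plain softmax instead of the online variant; names, shapes, and the scale factor are illustrative rather than the kernel's API):

```python
import numpy as np

def sparse_mla_reference(q, kv, indices, dim_v):
    """Attention for one query over its selected KV tokens (illustrative).

    q [heads, dim], kv [seq_len_kv, dim], indices [topk]; the first dim_v
    channels of the latent KV double as the value vectors, as in MLA."""
    kv_sel = kv[indices]                              # gather only topk rows
    scores = q @ kv_sel.T / np.sqrt(q.shape[-1])      # [heads, topk]
    p = np.exp(scores - scores.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)                     # softmax over topk only
    return p @ kv_sel[:, :dim_v]                      # [heads, dim_v]
```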

### Sparse MLA Forward (Pipelined)

The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFLOPS on H800 SXM by carefully orchestrating the memory and compute pipelines.

The key difference is splitting the warp groups into specialized roles:

```python
if tx < 128:
    # Consumer 0: computes left half of output (D//2 dimensions)
    # Handles QK matmul, softmax, and PV for left half
    ...
elif tx >= 128 and tx < 256:
    # Consumer 1: computes right half of output (D//2 dimensions)
    # Only does PV matmul for right half
    ...
elif tx >= 256:
    # Producer: loads KV data from global memory
    # Uses async copy with barriers to feed consumers
    ...
```

The producer thread group (`tx >= 256`) uses double buffering with barriers to keep the consumers fed:

```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
    # Buffer 0
    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 0
    T.cp_async_barrier_noinc(bar_k_0_ready[0])

    # Buffer 1
    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 1
    T.cp_async_barrier_noinc(bar_k_1_ready[0])
```

Consumer threads wait on the `ready` barriers, process each buffer as it becomes available, and release it for refilling. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
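
Splitting the output this way works because the PV matmul is independent along the output feature dimension, so each consumer can own one half of the output without any cross-group communication. A quick NumPy check of that property (shapes here are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
P = rng.random((64, 128))    # softmax probabilities [block_M, block_N]
V = rng.random((128, 512))   # value tile            [block_N, D]

# Each consumer computes one half of the output columns independently.
O_left, O_right = P @ V[:, :256], P @ V[:, 256:]
assert np.allclose(np.concatenate([O_left, O_right], axis=1), P @ V)
```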