
Commit 5d66d32

[Example] Add topk into sparse mla example and append some docs (#901)
* Remove unused `fp8_mqa_logits.py` file and update README.md to reflect new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.
* Update linting configurations and improve code formatting in deepseek_v32 example scripts
  - Added per-file ignores for the inference directory in `pyproject.toml`.
  - Refactored code in `topk_selector.py`, `convert.py`, `generate.py`, `kernel.py`, and `model.py` to enhance readability by adjusting spacing and line breaks.
  - Ensured consistent formatting across function definitions and assertions for better clarity.
* Refactor test functions in deepseek_v32 example scripts for improved clarity and consistency
  - Updated `fp8_lighting_indexer.py` to define a dedicated test function for the lighting indexer.
  - Refactored `sparse_mla_fwd_pipelined.py` and `sparse_mla_fwd.py` to standardize test function parameters and improve readability.
  - Enhanced `topk_selector.py` by introducing a test function with parameters for batch size and sequence length.
  - Ensured all test functions are invoked correctly in the main execution block.
* Enhance test functions in deepseek_v32 example scripts with CUDA requirements and parameterization
  - Added CUDA requirements decorators to `test_example_sparse_mla_fwd` and `test_example_sparse_mla_fwd_pipelined`.
  - Parameterized test functions to use specific small shapes for testing, improving test coverage and clarity.
* lint fix
* Update README.md to correct image path for DeepSeek V3.2 architecture diagram
1 parent 1656115 commit 5d66d32

16 files changed (+2070, -31 lines changed)

examples/deepseek_v32/README.md

Lines changed: 160 additions & 1 deletion
@@ -3,7 +3,166 @@
 ```
 deepseek_v32/
 ├── README.md                     # This file
-├── fp8_mqa_logits.py             # FP8 Indexer
+├── figures/                      # Figures and diagrams
+├── inference/                    # Inference implementation folder
+├── fp8_lighting_indexer.py       # FP8 lighting indexer
 ├── sparse_mla_fwd.py             # Sparse MLA forward implementation
 ├── sparse_mla_fwd_pipelined.py   # Pipelined implementation of sparse MLA forward pass
+├── topk_selector.py              # Top-k selector implementation
 ```

## File Descriptions

### Architecture Overview

![DeepSeek V3.2 Architecture](./figures/v32_arch.png)

The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:

1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implementation with sparse MLA (Multi-head Latent Attention) forward pass

### Lightning Indexer

Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.
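
In plain PyTorch, the score the indexer produces can be written as a ReLU-gated, weighted dot product per head, summed over heads (a simplified sketch that ignores FP8 quantization, scaling factors, and masking; `ref_fp8_mqa_logits` in `fp8_lighting_indexer.py` is the full reference):

```python
import torch

def ref_index_logits(q, k, w):
    # q: [S, H, D] index queries, k: [S_kv, D] index keys, w: [S, H] per-head weights
    scores = torch.einsum("shd,kd->shk", q.float(), k.float())     # per-head similarity -> [S, H, S_kv]
    return (torch.relu(scores) * w.float()[..., None]).sum(dim=1)  # ReLU, weight, sum over heads -> [S, S_kv]
```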

The main kernel `mqa_attn_return_logits_kernel` computes similarity scores between query and key indices:

```python
T.gemm(
    index_k_shared,
    index_q_shared,
    s,
    transpose_B=True,
    clear_accum=True,
    policy=T.GemmWarpPolicy.FullCol,
)
```

After the matmul, we apply ReLU and aggregate across heads with learned weights:

```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
    s_reshaped[bn_i, bq_i, h_i] = (
        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
    ) * index_k_scale_fragment[bn_i]

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```

The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:

```python
for bq_i in T.serial(block_Q):
    cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
    cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```

The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.
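
In reference form, these bounds simply restrict which KV positions may contribute to each query before the top-k step (an illustrative PyTorch sketch, not the kernel code):

```python
import torch

def apply_kv_bounds(logits, cu_seqlen_ks, cu_seqlen_ke):
    # logits:       [S, S_kv] indexer scores
    # cu_seqlen_ks: [S] per-query start of the valid KV range
    # cu_seqlen_ke: [S] per-query end (exclusive) of the valid KV range
    kv_pos = torch.arange(logits.shape[-1], device=logits.device)
    valid = (kv_pos >= cu_seqlen_ks[:, None]) & (kv_pos < cu_seqlen_ke[:, None])
    return logits.masked_fill(~valid, float("-inf"))  # out-of-range positions can never be selected
```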

### Top-k Selector

The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.
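
Functionally, its output matches a dense top-k over the indexer logits, up to ordering and tie-breaking (a reference sketch; the kernel produces the same set of indices without fully sorting each row, and additionally honors per-row valid ranges):

```python
import torch

def ref_topk_indices(logits, topk):
    # logits: [S, S_kv]; returns the positions of the topk largest scores in each query row
    return torch.topk(logits, k=topk, dim=-1).indices
```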

The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:

```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
```

The `convert_to_uint16` function maps floats to uint16 such that larger floats map to larger integers.
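
One standard way to build such an order-preserving key is the sign-bit trick on the raw float bits. The sketch below only illustrates that trick; the kernel's `convert_to_uint16` may differ in details such as rounding or NaN handling:

```python
import numpy as np

def order_preserving_uint16(x):
    # Reinterpret float32 as uint32, set the sign bit for non-negative values and invert all
    # bits for negative values, so the unsigned integers compare in the same order as the
    # original floats; keep the top 16 bits as the radix key.
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    u = np.where(u & 0x80000000, ~u, u | 0x80000000)
    return (u >> 16).astype(np.uint16)
```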

After building a histogram and doing a cumulative sum, we find the threshold bin:

```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
    s_threshold_bin_id[0] = tx
```

Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:

```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
    index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
    s_input_idx[0, pos] = input_idx
```

Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.
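
The refinement can be pictured as a generic radix-select over the converted integer keys (a conceptual NumPy sketch that assumes larger keys are better; the kernel's shared-memory layout, digit order, and tie handling differ):

```python
import numpy as np

def radix_select_topk(keys, k, digit_shifts=(24, 16, 8, 0)):
    # Exact top-k by repeated 8-bit histogram refinement over unsigned integer keys.
    keys = np.asarray(keys, dtype=np.uint32)
    selected = []                          # positions already guaranteed to be in the top-k
    candidates = np.arange(len(keys))      # positions still competing for the remaining slots
    for shift in digit_shifts:
        digits = ((keys[candidates] >> shift) & 0xFF).astype(np.int64)
        counts = np.bincount(digits, minlength=256)
        needed, running = k - len(selected), 0
        for d in range(255, -1, -1):       # walk buckets from the best digit downwards
            if running + counts[d] >= needed:
                threshold = d              # this bucket straddles the k-th element
                break
            running += counts[d]
        selected.extend(candidates[digits > threshold].tolist())  # strictly better buckets are all in
        candidates = candidates[digits == threshold]              # only the boundary bucket is refined
        if len(selected) + len(candidates) == k:                  # boundary bucket fits exactly: done
            return selected + candidates.tolist()
    return selected + candidates.tolist()[:k - len(selected)]     # leftovers share identical keys
```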

### Sparse MLA Forward

The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.

Turning dense MLA into sparse MLA requires surprisingly few changes: essentially just modifying how we iterate over and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:

```python
# Dense MLA: iterate over full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
    T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
    # ... compute attention over this block
```

Sparse MLA only loads KV positions selected by the top-k selector:

```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    # ... compute attention over selected tokens
```

This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:

```python
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```

Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.
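
For intuition, here is a plain-softmax PyTorch reference of the sparse attention step for a single batch and KV group (a simplified sketch: the real kernel uses an online softmax and carries extra batch and group dimensions, and for brevity the latent KV vector is used here as both key and value):

```python
import torch

def ref_sparse_attention(q, kv, indices, sm_scale):
    # q:       [S, H, D]   query heads
    # kv:      [S_kv, D]   latent KV cache (used as both key and value in this sketch)
    # indices: [S, topk]   top-k KV positions chosen for each query token
    kv_sel = kv[indices]                                          # gather only the selected tokens -> [S, topk, D]
    scores = torch.einsum("shd,std->sht", q, kv_sel) * sm_scale   # scores over the selected set only
    probs = torch.softmax(scores, dim=-1)
    return torch.einsum("sht,std->shd", probs, kv_sel)            # weighted sum over the selected tokens
```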

### Sparse MLA Forward (Pipelined)

The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFlops on H800 SXM by carefully orchestrating memory and compute pipelines.

The key difference is splitting the warp groups into specialized roles:

```python
if tx < 128:
    # Consumer 0: computes left half of output (D//2 dimensions)
    # Handles QK matmul, softmax, and PV for left half

elif tx >= 128 and tx < 256:
    # Consumer 1: computes right half of output (D//2 dimensions)
    # Only does PV matmul for right half

elif tx >= 256:
    # Producer: loads KV data from global memory
    # Uses async copy with barriers to feed consumers
```

The producer thread group (`tx >= 256`) uses double buffering with barriers to keep consumers fed:

```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
    # Buffer 0
    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 0
    T.cp_async_barrier_noinc(bar_k_0_ready[0])

    # Buffer 1
    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 1
    T.cp_async_barrier_noinc(bar_k_1_ready[0])
```

Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.
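
The half-split of the output dimension works because, once the softmax probabilities are shared, the PV matmul is independent across output columns. A tiny PyTorch sketch of the idea (illustrative only, not the kernel's code path):

```python
import torch

def split_pv(probs, v_sel):
    # probs: [block_q, topk] softmax probabilities computed by consumer 0
    # v_sel: [topk, D] value vectors for the selected KV tokens
    D = v_sel.shape[-1]
    out_left = probs @ v_sel[:, :D // 2]    # accumulated by consumer warp group 0
    out_right = probs @ v_sel[:, D // 2:]   # accumulated by consumer warp group 1
    return torch.cat([out_left, out_right], dim=-1)  # equals probs @ v_sel
```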
(binary file, 241 KB; content not shown)

examples/deepseek_v32/fp8_mqa_logits.py renamed to examples/deepseek_v32/fp8_lighting_indexer.py

Lines changed: 4 additions & 4 deletions
@@ -258,10 +258,7 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
     cost = mask.sum()
     return logits, cost

-
-if __name__ == "__main__":
-    torch.manual_seed(0)
-    S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
+def test_fp8_lighting_indexer(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1):
     q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
@@ -304,3 +301,6 @@ def logits_fn():
     logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12
     print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}")
     print(f"cost_ref: {cost_ref}")
+
+
+if __name__ == "__main__":
+    test_fp8_lighting_indexer()
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{
  "permissions": {
    "allow": [
      "Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_v32/**)",
      "Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_mla/**)"
    ],
    "deny": [],
    "ask": []
  }
}
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# DeepSeek V3.2

First convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:
```bash
cd inference
export EXPERTS=256
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```

Launch the interactive chat interface and start exploring DeepSeek's capabilities:
```bash
export CONFIG=config_671B_v3.2.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
{
  "vocab_size": 129280,
  "dim": 7168,
  "inter_dim": 18432,
  "moe_inter_dim": 2048,
  "n_layers": 61,
  "n_dense_layers": 3,
  "n_heads": 128,
  "n_routed_experts": 256,
  "n_shared_experts": 1,
  "n_activated_experts": 8,
  "n_expert_groups": 8,
  "n_limited_groups": 4,
  "route_scale": 2.5,
  "score_func": "sigmoid",
  "q_lora_rank": 1536,
  "kv_lora_rank": 512,
  "qk_nope_head_dim": 128,
  "qk_rope_head_dim": 64,
  "v_head_dim": 128,
  "dtype": "fp8",
  "scale_fmt": "ue8m0",
  "index_n_heads": 64,
  "index_head_dim": 128,
  "index_topk": 2048
}
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange

import torch
from safetensors.torch import safe_open, save_file

mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
    "wq_b": ("wq_b", None),
    "wk": ("wk", None),
    "k_norm": ("k_norm", None),
    "weights_proj": ("weights_proj", None),
}


def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]

    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                if "model.layers.61" in name:
                    continue
                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"
                new_key, dim = mapping[key]
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        # Routed experts: each rank keeps only its contiguous slice of experts.
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        # Tensor-parallel weights: shard along `dim` across the mp ranks.
                        assert param.size(
                            dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True)

    for i in trange(mp):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
