
Commit 453d442

Remove unused fp8_mqa_logits.py file and update README.md to reflect new directory structure and file descriptions for deepseek_v32 example. Added sections for architecture overview, Lightning Indexer, Top-k Selector, and Sparse MLA Forward implementations.
1 parent d19fe1a commit 453d442

File tree

12 files changed: +1926 −1 lines

examples/deepseek_v32/README.md

Lines changed: 160 additions & 1 deletion
@@ -3,7 +3,166 @@
 ```
 deepseek_v32/
 ├── README.md                     # This file
-├── fp8_mqa_logits.py             # FP8 Indexer
+├── figures/                      # Figures and diagrams
+├── inference/                    # Inference implementation folder
+├── fp8_lighting_indexer.py       # FP8 Lightning Indexer
 ├── sparse_mla_fwd.py             # Sparse MLA forward implementation
 ├── sparse_mla_fwd_pipelined.py   # Pipelined implementation of sparse MLA forward pass
+├── topk_selector.py              # Top-k selector implementation
 ```

## File Descriptions

### Architecture Overview

![DeepSeek V3.2 Architecture](figures/v32_arch.png)

The architecture diagram above highlights three key components (shown in green) that correspond to our kernel implementations:

1. **Lightning Indexer** (`fp8_lighting_indexer.py`) - Efficiently indexes and processes sparse attention patterns using FP8 precision
2. **Top-k Selector** (`topk_selector.py`) - Selects the top-k most relevant tokens for sparse attention computation
3. **Multi-Query Attention** (`sparse_mla_fwd.py` and `sparse_mla_fwd_pipelined.py`) - Core attention mechanism implemented as a sparse MLA (Multi-head Latent Attention) forward pass

### Lightning Indexer

Looking at the architecture diagram, the Lightning Indexer sits at the bottom right. It takes the input hidden states and produces compressed representations `{q^A_{t,i}}`, `{k^R_t}`, and `{w^I_{t,j}}`. These FP8-quantized index vectors are what feed into the top-k selector.

The main kernel, `mqa_attn_return_logits_kernel`, computes similarity scores between the query and key index vectors:

```python
T.gemm(
    index_k_shared,
    index_q_shared,
    s,
    transpose_B=True,
    clear_accum=True,
    policy=T.GemmWarpPolicy.FullCol,
)
```

After the matmul, we apply ReLU, weight each head with the learned weights, apply the per-key scale factor, and aggregate across heads:

```python
for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
    s_reshaped[bn_i, bq_i, h_i] = (
        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
    ) * index_k_scale_fragment[bn_i]

T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
```

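For intuition, the same per-tile math corresponds to the following dense PyTorch reference. This is a minimal sketch with assumed shapes; `q`, `k`, `weights`, and `k_scale` are illustrative names rather than the kernel's buffers:

```python
import torch

def indexer_logits_ref(q, k, weights, k_scale):
    """Dense reference: q [T, H, D], k [S, D], weights [T, H], k_scale [S]."""
    # ReLU(q_{t,h} . k_s) for every query token, head, and key token
    s = torch.einsum("thd,sd->tsh", q.float(), k.float()).relu()
    # weight each head, sum over heads, and apply the per-key scale
    return (s * weights[:, None, :]).sum(dim=-1) * k_scale[None, :]  # [T, S]
```
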
The result is a `[seq_len, seq_len_kv]` logits matrix. For long sequences, the kernel uses per-token bounds (`CuSeqLenKS`, `CuSeqLenKE`) to skip irrelevant KV positions:

```python
for bq_i in T.serial(block_Q):
    cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv))
for bq_i in T.serial(block_Q):
    cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv))
```

The pipelined loop then only processes keys in the `[cu_k_s_min, cu_k_e_max)` range, which is crucial for handling variable-length sequences in distributed training.

### Top-k Selector

The Top-k Selector takes the logits matrix from the indexer and picks the top-k indices for each query. In the architecture diagram, this sits between the Lightning Indexer and the Multi-Query Attention block. The output indices tell the attention layer which KV tokens to actually load and process.

The implementation uses a radix-sort-based approach that processes floats as unsigned integers. Stage 1 does a quick 8-bit pass over the whole sequence:

```python
for s in T.serial(T.ceildiv(seq_len, BLOCK_SIZE)):
    input_idx = s * BLOCK_SIZE + tx
    if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len:
        inval_int16 = convert_to_uint16(input[bx, input_idx])
        T.atomic_add(s_histogram[inval_int16], 1)
```

The `convert_to_uint16` function maps floats to `uint16` values such that larger floats map to larger integers, so comparing the converted keys as unsigned integers preserves the original float ordering; a sketch of such a mapping is shown below.

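A minimal NumPy sketch of an order-preserving mapping of this kind (the standard radix-sort bit trick; illustrative only, the kernel's `convert_to_uint16` may keep different bits in each pass):

```python
import numpy as np

def convert_to_uint16_ref(x: np.ndarray) -> np.ndarray:
    """Map float32 values to uint16 keys whose unsigned order matches the float order."""
    bits = x.astype(np.float32).view(np.uint32)
    sign = np.uint32(0x80000000)
    # Negative floats: flip every bit; non-negative floats: set the sign bit.
    bits = np.where(bits & sign, ~bits, bits | sign)
    return (bits >> np.uint32(16)).astype(np.uint16)  # keep the top 16 bits as the bucket key
```
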
After building a histogram and taking a cumulative sum over it, we find the threshold bin:

```python
if s_histogram[tx] > l_new_topk and s_histogram[tx + 1] <= l_new_topk:
    s_threshold_bin_id[0] = tx
```

Elements above the threshold go directly to the output. Elements in the threshold bin get collected for further processing:

```python
if l_bin_id32 > l_threshold_bin_id:
    pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True)
    index[bx, pos] = input_idx
elif l_bin_id32 == l_threshold_bin_id and l_new_topk > 0:
    pos = T.atomic_add(s_num_input[0], 1, return_prev=True)
    s_input_idx[0, pos] = input_idx
```

Stage 2 refines the threshold bin with up to 4 rounds of 8-bit radix sort, processing progressively higher bits. This gives exact top-k selection without sorting the entire sequence.

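The two-stage idea can be illustrated end to end with a small NumPy reference. This is a sketch only, not the kernel code; it reuses the same order-preserving bit trick at full 32-bit width and peels off 8 bits per round:

```python
import numpy as np

def _ordered_u32(x: np.ndarray) -> np.ndarray:
    bits = x.astype(np.float32).view(np.uint32)
    sign = np.uint32(0x80000000)
    return np.where(bits & sign, ~bits, bits | sign)

def radix_topk_indices(x: np.ndarray, k: int) -> np.ndarray:
    """Indices of the k largest entries of x, selected by successive 8-bit refinement."""
    keys = _ordered_u32(x)
    cand = np.arange(x.shape[0])           # candidate positions still in play
    out = []
    for shift in (24, 16, 8, 0):           # high bits first, 8 bits per round
        buckets = ((keys[cand] >> np.uint32(shift)) & np.uint32(0xFF)).astype(np.intp)
        counts = np.bincount(buckets, minlength=256)
        above = np.cumsum(counts[::-1])[::-1] - counts  # elements in strictly higher buckets
        t = int(np.argmax(above < k))      # threshold bucket holding the k-th largest
        out.append(cand[buckets > t])      # definitely in the top-k
        k -= int(above[t])
        cand = cand[buckets == t]          # refine only the threshold bucket
    out.append(cand[:k])                   # remaining keys are identical; take any k
    return np.concatenate(out)
```

For a logits row `x`, `radix_topk_indices(x, 2048)` selects the same index set as `np.argsort(x)[-2048:]` (up to tie-breaking at the boundary) without ever sorting the full sequence.
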
### Sparse MLA Forward

The Sparse MLA kernel is where the actual attention computation happens. In the architecture diagram, this is the large "Multi-Query Attention (Core Attention)" block at the top. It takes the selected top-k indices and computes attention only over those tokens.

Turning dense MLA into sparse MLA requires surprisingly few changes: essentially just changing how we iterate over and load KV tokens. The key difference from dense MLA (see `../deepseek_mla/example_mla_decode.py`) is the iteration pattern. Dense MLA iterates over all KV positions:

```python
# Dense MLA: iterate over the full sequence
loop_range = T.ceildiv(seqlen_kv, block_N)
for k in T.Pipelined(loop_range, num_stages=2):
    T.copy(KV[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
    # ... compute attention over this block
```

Sparse MLA only loads the KV positions selected by the top-k selector:

```python
# Sparse MLA: iterate over selected indices only
for i_i in T.Pipelined(NI, num_stages=num_stages):
    for bi_i, d_i in T.Parallel(BI, D):
        KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i]
    # ... compute attention over the selected tokens
```

This reduces compute from O(seq_len * seq_len_kv) to O(seq_len * topk). The causal mask is enforced by checking whether each index position is valid:

```python
for bi_i in T.Parallel(BI):
    mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i
```

Beyond this sparse indexing, the rest of the attention computation (online softmax, output accumulation) follows the same pattern as dense MLA.

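Functionally, each query ends up doing ordinary attention over its gathered top-k latent KV rows. A PyTorch sketch for a single query position (shapes and names are assumptions, not the kernel's API; in MLA the value is the first `dv` dimensions of the same latent vector):

```python
import torch

def sparse_mla_ref(q, kv, indices, dv, sm_scale):
    """q: [heads, D], kv: [seq_len_kv, D], indices: [topk] for this query position."""
    kv_sel = kv[indices].float()                # gather only the selected latent rows
    scores = (q.float() @ kv_sel.T) * sm_scale  # [heads, topk]
    p = torch.softmax(scores, dim=-1)
    return p @ kv_sel[:, :dv]                   # [heads, dv]
```

The kernel streams `BI`-sized tiles of these gathered rows through shared memory inside the pipelined loop shown above instead of materializing the whole `[topk, D]` slice at once.
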
### Sparse MLA Forward (Pipelined)

The pipelined version (`sparse_mla_fwd_pipelined.py`) is a manual pipeline implementation designed to match the schedule of [FlashMLA](https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/sm90/prefill/sparse/fwd.cu). It achieves close to 600 TFLOPS on H800 SXM by carefully orchestrating the memory and compute pipelines.

The key difference is splitting the warp groups into specialized roles:

```python
if tx < 128:
    # Consumer 0: computes the left half of the output (D//2 dimensions)
    # Handles the QK matmul, softmax, and the PV matmul for the left half
    ...
elif tx >= 128 and tx < 256:
    # Consumer 1: computes the right half of the output (D//2 dimensions)
    # Only does the PV matmul for the right half
    ...
elif tx >= 256:
    # Producer: loads KV data from global memory
    # Uses async copies with barriers to feed the consumers
    ...
```

The producer thread group (`tx >= 256`) uses double buffering with barriers to keep the consumers fed:

```python
# Producer alternates between two buffers
for i_i in T.serial(T.ceildiv(NI, 2)):
    # Buffer 0
    T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 0
    T.cp_async_barrier_noinc(bar_k_0_ready[0])

    # Buffer 1
    T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1))
    # ... load KV into buffer 1
    T.cp_async_barrier_noinc(bar_k_1_ready[0])
```

Consumer threads wait on barriers and process buffers as they become ready. This manual orchestration hides memory latency behind compute, which is why it outperforms the simpler auto-pipelined version. The output dimension is also split in half so that the two consumer groups can work in parallel on different parts of the matmul.

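A rough sketch of the consumer side of this handshake (illustrative only; the exact barrier phases and compute steps are in `sparse_mla_fwd_pipelined.py`):

```python
# Consumer-side sketch: wait for a buffer to be filled, consume it, then release it.
for i_i in T.serial(T.ceildiv(NI, 2)):
    T.barrier_wait(bar_k_0_ready[0], (i_i & 1))   # producer has filled buffer 0
    # ... QK matmul, online softmax update, and PV matmul on buffer 0
    T.barrier_arrive(bar_k_0_free[0])             # hand buffer 0 back to the producer

    T.barrier_wait(bar_k_1_ready[0], (i_i & 1))   # then the same for buffer 1
    # ... QK matmul, online softmax update, and PV matmul on buffer 1
    T.barrier_arrive(bar_k_1_free[0])
```
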
examples/deepseek_v32/figures/v32_arch.png (binary image, 241 KB)
Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{
  "permissions": {
    "allow": [
      "Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_v32/**)",
      "Read(//weka-hg/prod/deepseek/permanent/wanglei/tilelang/examples/deepseek_mla/**)"
    ],
    "deny": [],
    "ask": []
  }
}
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# DeepSeek V3.2

First convert the Hugging Face model weights to the format required by our inference demo. Set `MP` to match your available GPU count:

```bash
cd inference
export EXPERTS=256
python convert.py --hf-ckpt-path ${HF_CKPT_PATH} --save-path ${SAVE_PATH} --n-experts ${EXPERTS} --model-parallel ${MP}
```

Launch the interactive chat interface and start exploring DeepSeek's capabilities:

```bash
export CONFIG=config_671B_v3.2.json
torchrun --nproc-per-node ${MP} generate.py --ckpt-path ${SAVE_PATH} --config ${CONFIG} --interactive
```
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
{
  "vocab_size": 129280,
  "dim": 7168,
  "inter_dim": 18432,
  "moe_inter_dim": 2048,
  "n_layers": 61,
  "n_dense_layers": 3,
  "n_heads": 128,
  "n_routed_experts": 256,
  "n_shared_experts": 1,
  "n_activated_experts": 8,
  "n_expert_groups": 8,
  "n_limited_groups": 4,
  "route_scale": 2.5,
  "score_func": "sigmoid",
  "q_lora_rank": 1536,
  "kv_lora_rank": 512,
  "qk_nope_head_dim": 128,
  "qk_rope_head_dim": 64,
  "v_head_dim": 128,
  "dtype": "fp8",
  "scale_fmt": "ue8m0",
  "index_n_heads": 64,
  "index_head_dim": 128,
  "index_topk": 2048
}
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange

import torch
from safetensors.torch import safe_open, save_file


mapping = {
    "embed_tokens": ("embed", 0),
    "input_layernorm": ("attn_norm", None),
    "post_attention_layernorm": ("ffn_norm", None),
    "q_proj": ("wq", 0),
    "q_a_proj": ("wq_a", None),
    "q_a_layernorm": ("q_norm", None),
    "q_b_proj": ("wq_b", 0),
    "kv_a_proj_with_mqa": ("wkv_a", None),
    "kv_a_layernorm": ("kv_norm", None),
    "kv_b_proj": ("wkv_b", 0),
    "o_proj": ("wo", 1),
    "gate": ("gate", None),
    "gate_proj": ("w1", 0),
    "down_proj": ("w2", 1),
    "up_proj": ("w3", 0),
    "norm": ("norm", None),
    "lm_head": ("head", 0),
    "scale": ("scale", None),
    "wq_b": ("wq_b", None),
    "wk": ("wk", None),
    "k_norm": ("k_norm", None),
    "weights_proj": ("weights_proj", None),
}


def main(hf_ckpt_path, save_path, n_experts, mp):
    """
    Converts and saves model checkpoint files into a specified format.

    Args:
        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
        save_path (str): Path to the directory where the converted checkpoint files will be saved.
        n_experts (int): Total number of experts in the model.
        mp (int): Model parallelism factor.

    Returns:
        None
    """
    torch.set_num_threads(8)
    n_local_experts = n_experts // mp
    state_dicts = [{} for _ in range(mp)]

    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                if "model.layers.61" in name:
                    continue
                param: torch.Tensor = f.get_tensor(name)
                if name.startswith("model."):
                    name = name[len("model."):]
                name = name.replace("self_attn", "attn")
                name = name.replace("mlp", "ffn")
                name = name.replace("weight_scale_inv", "scale")
                name = name.replace("e_score_correction_bias", "bias")
                key = name.split(".")[-2]
                assert key in mapping, f"Key {key} not found in mapping"
                new_key, dim = mapping[key]
                name = name.replace(key, new_key)
                for i in range(mp):
                    new_param = param
                    if "experts" in name and "shared_experts" not in name:
                        idx = int(name.split(".")[-3])
                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                            continue
                    elif dim is not None:
                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                        shard_size = param.size(dim) // mp
                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                    state_dicts[i][name] = new_param

    os.makedirs(save_path, exist_ok=True)

    for i in trange(mp):
        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))

    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        new_file_path = os.path.join(save_path, os.path.basename(file_path))
        shutil.copyfile(file_path, new_file_path)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--hf-ckpt-path", type=str, required=True)
    parser.add_argument("--save-path", type=str, required=True)
    parser.add_argument("--n-experts", type=int, required=True)
    parser.add_argument("--model-parallel", type=int, required=True)
    args = parser.parse_args()
    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
