This project:
- works out the parallelization logic of the MoE FFN layers in the Qwen1.5-MoE-A2.7B model
- implements and optimizes CUDA kernels for the MoE FFN layer
- verifies numerical accuracy and profiles the kernels
In a Mixture-of-Experts (MoE) model, each expert is an FFN and every token is routed to its top-K experts. The FFN part of the Qwen source code is written serially: it loops over the expert indices, slices the tokens assigned to the current expert out of the input, and then runs that expert's FFN on the slice.
To parallelize this, every token is replicated once per selected expert, and the copies are permuted and grouped by expert so that all experts can be computed together.
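A minimal PyTorch sketch of this regrouping step, under assumed shapes; the names (group_tokens_by_expert, hidden_states, router_logits, top_k) are illustrative and not the actual kernel interface:

import torch

def group_tokens_by_expert(hidden_states, router_logits, top_k):
    # hidden_states: [num_tokens, hidden_dim]; router_logits: [num_tokens, num_experts]
    routing_weights = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)  # [num_tokens, top_k]

    # Replicate every token once per selected expert, then sort the copies by expert id
    # so that each expert sees one contiguous block of rows.
    flat_expert_ids = topk_ids.reshape(-1)                     # [num_tokens * top_k]
    sort_order = torch.argsort(flat_expert_ids, stable=True)   # permutation into expert order
    token_ids = torch.arange(hidden_states.size(0), device=hidden_states.device)
    src_rows = token_ids.repeat_interleave(top_k)[sort_order]  # source token of each copy
    permuted = hidden_states[src_rows]                         # [num_tokens * top_k, hidden_dim]

    # Per-expert row counts give the block offsets needed for a grouped/batched GEMM.
    counts = torch.bincount(flat_expert_ids, minlength=router_logits.size(-1))
    permuted_weights = topk_weights.reshape(-1)[sort_order]
    return permuted, src_rows, permuted_weights, counts

Each expert then runs its FFN on its contiguous block, and the results are scaled by the routing weights and scattered back to the original token positions via src_rows.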
- csrc/moe_kernels.cu: the key kernels of the MoE FFN, adapted from TensorRT-LLM-v0.8.0
- qwen_moe_block.py: the Qwen reference inference code, plus a PyTorch pipeline that mirrors the CUDA kernel logic and is used to verify correctness; see the accompanying documentation
- bench.py: accuracy check and profiling of the CUDA kernels (a sketch of this pattern follows the list)
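The accuracy check and profiling in bench.py roughly follow the pattern below; this is a sketch, and ref_moe_block / cuda_moe_block are placeholder callables rather than the actual module names:

import torch
from torch.profiler import profile, ProfilerActivity

def compare_and_profile(ref_moe_block, cuda_moe_block, hidden_states):
    # Numerical agreement: maximum absolute element-wise difference.
    ref_out = ref_moe_block(hidden_states)
    cuda_out = cuda_moe_block(hidden_states)
    print("max_diff =", (ref_out - cuda_out).abs().max().item())

    # Operator-level timing on CPU and CUDA, producing tables like the ones below.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        cuda_moe_block(hidden_states)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))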
Requirements:
- CUDA >= 11.0
- CMake >= 3.18
- Torch >= 2.3
- Python >= 3.10
Build and run:
# build
mkdir build && cd build
cmake .. -DCMAKE_PREFIX_PATH=$CONDA_PREFIX/lib/python3.10/site-packages/torch/share/cmake/Torch
make
# run
cd ..
python bench.py
Sample output:
=== profiling qwen moe block ===
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::index_add_ 0.04% 1.851ms 63.32% 2.795s 43.665ms 409.000us 0.01% 2.791s 43.612ms 64
aten::scatter_add_ 63.27% 2.792s 63.27% 2.792s 43.632ms 2.791s 63.13% 2.791s 43.605ms 64
aten::linear 0.15% 6.418ms 28.05% 1.238s 6.413ms 3.811ms 0.09% 1.247s 6.461ms 193
aten::matmul 0.04% 1.928ms 27.81% 1.227s 6.359ms 1.509ms 0.03% 1.240s 6.424ms 193
aten::mm 27.75% 1.225s 27.76% 1.225s 6.347ms 1.233s 27.90% 1.238s 6.417ms 193
aten::index 8.22% 362.845ms 8.26% 364.369ms 2.847ms 362.791ms 8.21% 363.705ms 2.841ms 128
aten::resolve_conj 0.00% 1.000us 0.00% 1.000us 0.003us 5.064ms 0.11% 5.064ms 13.119us 386
aten::mul 0.06% 2.646ms 0.06% 2.646ms 20.672us 4.886ms 0.11% 4.886ms 38.172us 128
aten::softmax 0.00% 21.000us 0.04% 1.909ms 1.909ms 189.000us 0.00% 3.312ms 3.312ms 1
aten::t 0.05% 2.118ms 0.08% 3.700ms 19.171us 1.983ms 0.04% 3.239ms 16.782us 193
------------------------------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 4.413s
Self CUDA time total: 4.420s
=== profiling manual qwen moe block ===
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::matmul 0.12% 802.000us 85.79% 587.121ms 4.482ms 1.044ms 0.15% 592.546ms 4.523ms 131
aten::mm 85.65% 586.158ms 85.65% 586.171ms 4.475ms 590.234ms 84.37% 591.502ms 4.515ms 131
aten::copy_ 3.45% 23.577ms 3.45% 23.577ms 14.126us 31.107ms 4.45% 31.107ms 18.638us 1669
aten::select 2.70% 18.510ms 2.88% 19.729ms 3.425us 19.367ms 2.77% 26.617ms 4.621us 5760
aten::zeros 0.01% 57.000us 2.78% 19.001ms 3.800ms 19.000us 0.00% 16.658ms 3.332ms 5
aten::zero_ 0.01% 41.000us 2.76% 18.919ms 3.784ms 10.000us 0.00% 16.634ms 3.327ms 5
aten::fill_ 2.76% 18.867ms 2.76% 18.867ms 4.717ms 16.624ms 2.38% 16.624ms 4.156ms 4
aten::mul 1.11% 7.609ms 1.11% 7.609ms 14.689us 8.197ms 1.17% 8.197ms 15.824us 518
aten::as_strided 0.00% 19.000us 0.00% 19.000us 0.003us 7.536ms 1.08% 7.536ms 1.249us 6036
aten::item 0.68% 4.657ms 0.79% 5.402ms 3.517us 3.271ms 0.47% 7.035ms 4.580us 1536
----------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 684.377ms
Self CUDA time total: 699.556ms
tensor(1.7695e-08, grad_fn=<MaxBackward1>)
=== profiling cuda qwen moe block ===
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
aten::to 0.02% 21.000us 88.42% 119.210ms 29.802ms 23.000us 0.02% 118.703ms 29.676ms 4
aten::_to_copy 0.06% 75.000us 88.40% 119.186ms 29.797ms 30.000us 0.02% 118.680ms 29.670ms 4
aten::copy_ 0.03% 38.000us 88.06% 118.719ms 29.680ms 118.517ms 94.61% 118.517ms 29.629ms 4
aten::zeros 0.01% 15.000us 9.29% 12.522ms 12.522ms 4.000us 0.00% 6.566ms 6.566ms 1
aten::zero_ 0.01% 17.000us 9.23% 12.440ms 12.440ms 3.000us 0.00% 6.561ms 6.561ms 1
aten::fill_ 0.02% 22.000us 9.21% 12.422ms 12.422ms 6.558ms 5.24% 6.558ms 6.558ms 1
aten::empty_strided 0.04% 56.000us 0.29% 390.000us 97.500us 133.000us 0.11% 133.000us 33.250us 4
aten::empty 0.01% 20.000us 0.05% 67.000us 67.000us 1.000us 0.00% 1.000us 1.000us 1
cudaEventRecord 0.01% 12.000us 0.01% 12.000us 0.300us 0.000us 0.00% 0.000us 0.000us 40
cudaStreamIsCapturing 0.00% 1.000us 0.00% 1.000us 0.250us 0.000us 0.00% 0.000us 0.000us 4
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 134.822ms
Self CUDA time total: 125.269ms
max_diff = 7.636845111846924e-08
python output:
size(torch.Size([128, 2048]))
tensor([[-0.0030, 0.0013, 0.0059, ..., -0.0081, -0.0178, 0.0036],
[ 0.0002, 0.0057, -0.0034, ..., -0.0059, -0.0025, 0.0034],
[ 0.0081, -0.0054, 0.0090, ..., 0.0034, -0.0032, 0.0165],
...,
[-0.0002, -0.0139, -0.0023, ..., -0.0036, 0.0087, 0.0041],
[-0.0096, 0.0013, -0.0102, ..., 0.0015, 0.0367, -0.0135],
[ 0.0061, 0.0124, 0.0078, ..., -0.0023, -0.0076, -0.0083]])
cuda output:
size(torch.Size([128, 2048]))
tensor([[-0.0030, 0.0013, 0.0059, ..., -0.0081, -0.0178, 0.0036],
[ 0.0002, 0.0057, -0.0034, ..., -0.0059, -0.0025, 0.0034],
[ 0.0081, -0.0054, 0.0090, ..., 0.0034, -0.0032, 0.0165],
...,
[-0.0002, -0.0139, -0.0023, ..., -0.0036, 0.0087, 0.0041],
[-0.0096, 0.0013, -0.0102, ..., 0.0015, 0.0367, -0.0135],
[ 0.0061, 0.0124, 0.0078, ..., -0.0023, -0.0076, -0.0083]])