MoE_inference

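This project:

  • Analyzes the parallelization logic of the MoE FFN layers in the Qwen1.5-MoE-A2.7B model
  • Implements and optimizes CUDA kernels for the MoE FFN layer
  • Aligns numerical accuracy with the PyTorch reference and profiles performance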

MoE FFN overview

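In a mixture-of-experts (MoE) model, each expert is an FFN, and each token is assigned to its top-K experts for processing. In the Qwen source code the FFN part is written serially: it loops over the expert indices, slices the tokens assigned to each expert out of the original input, and then computes that expert's FFN.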
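A minimal sketch of that serial logic in plain Python (lists instead of tensors, so it runs without PyTorch or a GPU). The routing rule, softmax over all experts followed by top-K without renormalizing the selected weights, is an assumption for illustration; any shared expert is left out, the experts are toy scaling functions rather than real FFNs, and all function names are hypothetical.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def route(router_logits, top_k):
    """Return (expert_ids, weights) for each token's top_k experts.

    Assumption: softmax over all experts, then keep the top_k probabilities
    without renormalizing them.
    """
    routes = []
    for logits in router_logits:
        probs = softmax(logits)
        # Expert indices sorted by probability, highest first; keep the first top_k
        ids = sorted(range(len(probs)), key=lambda e: probs[e], reverse=True)[:top_k]
        routes.append((ids, [probs[e] for e in ids]))
    return routes

def moe_serial(hidden, routes, experts):
    """Serial reference: loop over expert indices, slice out that expert's tokens, run its FFN."""
    out = [[0.0] * len(h) for h in hidden]
    for e, expert in enumerate(experts):
        # Tokens routed to expert e, each with the routing weight it gave e
        picked = [(t, ws[ids.index(e)]) for t, (ids, ws) in enumerate(routes) if e in ids]
        for t, w in picked:
            y = expert(hidden[t])
            # Weighted accumulation back into the token's output row
            out[t] = [o + w * v for o, v in zip(out[t], y)]
    return out

# Toy "FFN" experts: expert e just scales its input by e + 1
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0, 4.0)]
hidden = [[1.0, -1.0], [0.5, 2.0], [3.0, 0.0]]  # 3 tokens, hidden size 2
router_logits = [[2.0, 1.0, 0.0, -1.0], [0.0, 3.0, 3.5, 0.0], [1.0, 1.0, 5.0, 0.0]]
routes = route(router_logits, top_k=2)
out = moe_serial(hidden, routes, experts)
```

Each loop iteration only touches one expert's tokens, so the experts are processed one after another even though their work is independent.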

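To parallelize this, each token is duplicated once for each of its assigned experts, and the copies are reordered and grouped by expert so that each group can be processed as a batch.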
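The duplicate-and-group step can be sketched the same way: expand each token into one row per selected expert, stable-sort the rows by expert id so each expert's rows are contiguous, run each expert on its block, then scatter the results back and sum them with the routing weights. This is a plain-Python sketch with hypothetical names and hand-written routing, not the code in qwen_moe_block.py.

```python
def permute(routes, num_experts):
    """Expand every token into one row per selected expert, then group rows by expert.

    Returns:
      src_rows: for each permuted row, the (token, k) slot it was copied from
      offsets:  expert e owns permuted rows offsets[e] .. offsets[e + 1] - 1
    """
    expanded = [(ids[k], t, k) for t, (ids, _) in enumerate(routes) for k in range(len(ids))]
    # Stable sort by expert id: within a group, tokens keep their original order
    expanded.sort(key=lambda row: row[0])
    src_rows = [(t, k) for _, t, k in expanded]
    # Prefix sum of per-expert row counts gives each expert's segment boundaries
    counts = [0] * num_experts
    for e, _, _ in expanded:
        counts[e] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)
    return src_rows, offsets

def moe_grouped(hidden, routes, experts):
    src_rows, offsets = permute(routes, len(experts))
    # Gather: copy each token's hidden state into its permuted rows
    permuted_in = [hidden[t] for t, _ in src_rows]
    permuted_out = [None] * len(permuted_in)
    for e, expert in enumerate(experts):
        # Each expert's rows form one contiguous block; the blocks are independent,
        # so a GPU implementation can process all of them concurrently
        for r in range(offsets[e], offsets[e + 1]):
            permuted_out[r] = expert(permuted_in[r])
    # Un-permute and combine: add each row back into its token, scaled by its routing weight
    out = [[0.0] * len(h) for h in hidden]
    for r, (t, k) in enumerate(src_rows):
        w = routes[t][1][k]
        out[t] = [o + w * v for o, v in zip(out[t], permuted_out[r])]
    return out

# Toy experts (expert e scales its input by e + 1) and hand-written top-2 routing
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0, 4.0)]
hidden = [[1.0, -1.0], [0.5, 2.0], [3.0, 0.0]]
routes = [([0, 1], [0.6, 0.3]), ([2, 1], [0.5, 0.4]), ([2, 0], [0.9, 0.05])]
src_rows, offsets = permute(routes, len(experts))
out = moe_grouped(hidden, routes, experts)
```

Here `offsets` is `[0, 2, 4, 6, 6]`: experts 0, 1 and 2 each receive two rows and expert 3 none, and `out` equals what a per-expert loop over the same routing would produce.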

Code files

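  • csrc/moe_kernels.cu: the core kernels of the MoE FFN, adapted from TensorRT-LLM v0.8.0
  • qwen_moe_block.py: contains the Qwen inference source code and a PyTorch pipeline written to follow the CUDA kernel logic, used to verify that the logic is correct; see the accompanying documentation for details
  • bench.py: numerical accuracy alignment and profiling of the CUDA kernels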
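Once the expanded rows are sorted by expert id, a kernel still needs each expert's row range before it can run the per-expert matrix multiplies. Counting rows and taking a prefix sum is one way; the sketch below assumes another, one binary search per expert over the sorted ids, because each search is independent and maps naturally to one GPU thread per expert. This is an illustration with hypothetical names; the code in csrc/moe_kernels.cu may compute the ranges differently.

```python
def count_leq(sorted_ids, target):
    """Binary search: number of entries <= target in an ascending list."""
    lo, hi = 0, len(sorted_ids)
    while lo < hi:
        mid = (lo + hi) // 2
        if sorted_ids[mid] <= target:
            lo = mid + 1
        else:
            hi = mid
    return lo

def expert_segment_ends(sorted_ids, num_experts):
    # One independent search per expert; a CUDA kernel could assign one thread per expert
    return [count_leq(sorted_ids, e) for e in range(num_experts)]

# Expert id of each expanded (token, k) row, sorted by expert id
sorted_ids = sorted([0, 1, 2, 1, 2, 0])  # -> [0, 0, 1, 1, 2, 2]
ends = expert_segment_ends(sorted_ids, num_experts=4)
# Expert e owns rows starts[e] .. ends[e] - 1
starts = [0] + ends[:-1]
```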

Testing the CUDA kernels

Requirements:

  • CUDA >= 11.0
  • CMake >= 3.18
  • Torch >= 2.3
  • Python >= 3.10

Build and run:

# Build (torch.utils.cmake_prefix_path locates Torch's CMake config for any Python version)
mkdir build && cd build
cmake .. -DCMAKE_PREFIX_PATH=$(python -c 'import torch; print(torch.utils.cmake_prefix_path)')
make

# Run
cd ..
python bench.py

Output:

=== profiling qwen moe block ===
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                    aten::index_add_         0.04%       1.851ms        63.32%        2.795s      43.665ms     409.000us         0.01%        2.791s      43.612ms            64  
                  aten::scatter_add_        63.27%        2.792s        63.27%        2.792s      43.632ms        2.791s        63.13%        2.791s      43.605ms            64  
                        aten::linear         0.15%       6.418ms        28.05%        1.238s       6.413ms       3.811ms         0.09%        1.247s       6.461ms           193  
                        aten::matmul         0.04%       1.928ms        27.81%        1.227s       6.359ms       1.509ms         0.03%        1.240s       6.424ms           193  
                            aten::mm        27.75%        1.225s        27.76%        1.225s       6.347ms        1.233s        27.90%        1.238s       6.417ms           193  
                         aten::index         8.22%     362.845ms         8.26%     364.369ms       2.847ms     362.791ms         8.21%     363.705ms       2.841ms           128  
                  aten::resolve_conj         0.00%       1.000us         0.00%       1.000us       0.003us       5.064ms         0.11%       5.064ms      13.119us           386  
                           aten::mul         0.06%       2.646ms         0.06%       2.646ms      20.672us       4.886ms         0.11%       4.886ms      38.172us           128  
                       aten::softmax         0.00%      21.000us         0.04%       1.909ms       1.909ms     189.000us         0.00%       3.312ms       3.312ms             1  
                             aten::t         0.05%       2.118ms         0.08%       3.700ms      19.171us       1.983ms         0.04%       3.239ms      16.782us           193  
------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 4.413s
Self CUDA time total: 4.420s

=== profiling manual qwen moe block ===
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                 aten::matmul         0.12%     802.000us        85.79%     587.121ms       4.482ms       1.044ms         0.15%     592.546ms       4.523ms           131  
                     aten::mm        85.65%     586.158ms        85.65%     586.171ms       4.475ms     590.234ms        84.37%     591.502ms       4.515ms           131  
                  aten::copy_         3.45%      23.577ms         3.45%      23.577ms      14.126us      31.107ms         4.45%      31.107ms      18.638us          1669  
                 aten::select         2.70%      18.510ms         2.88%      19.729ms       3.425us      19.367ms         2.77%      26.617ms       4.621us          5760  
                  aten::zeros         0.01%      57.000us         2.78%      19.001ms       3.800ms      19.000us         0.00%      16.658ms       3.332ms             5  
                  aten::zero_         0.01%      41.000us         2.76%      18.919ms       3.784ms      10.000us         0.00%      16.634ms       3.327ms             5  
                  aten::fill_         2.76%      18.867ms         2.76%      18.867ms       4.717ms      16.624ms         2.38%      16.624ms       4.156ms             4  
                    aten::mul         1.11%       7.609ms         1.11%       7.609ms      14.689us       8.197ms         1.17%       8.197ms      15.824us           518  
             aten::as_strided         0.00%      19.000us         0.00%      19.000us       0.003us       7.536ms         1.08%       7.536ms       1.249us          6036  
                   aten::item         0.68%       4.657ms         0.79%       5.402ms       3.517us       3.271ms         0.47%       7.035ms       4.580us          1536  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 684.377ms
Self CUDA time total: 699.556ms

tensor(1.7695e-08, grad_fn=<MaxBackward1>)
=== profiling cuda qwen moe block ===
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                               aten::to         0.02%      21.000us        88.42%     119.210ms      29.802ms      23.000us         0.02%     118.703ms      29.676ms             4  
                                         aten::_to_copy         0.06%      75.000us        88.40%     119.186ms      29.797ms      30.000us         0.02%     118.680ms      29.670ms             4  
                                            aten::copy_         0.03%      38.000us        88.06%     118.719ms      29.680ms     118.517ms        94.61%     118.517ms      29.629ms             4  
                                            aten::zeros         0.01%      15.000us         9.29%      12.522ms      12.522ms       4.000us         0.00%       6.566ms       6.566ms             1  
                                            aten::zero_         0.01%      17.000us         9.23%      12.440ms      12.440ms       3.000us         0.00%       6.561ms       6.561ms             1  
                                            aten::fill_         0.02%      22.000us         9.21%      12.422ms      12.422ms       6.558ms         5.24%       6.558ms       6.558ms             1  
                                    aten::empty_strided         0.04%      56.000us         0.29%     390.000us      97.500us     133.000us         0.11%     133.000us      33.250us             4  
                                            aten::empty         0.01%      20.000us         0.05%      67.000us      67.000us       1.000us         0.00%       1.000us       1.000us             1  
                                        cudaEventRecord         0.01%      12.000us         0.01%      12.000us       0.300us       0.000us         0.00%       0.000us       0.000us            40  
                                  cudaStreamIsCapturing         0.00%       1.000us         0.00%       1.000us       0.250us       0.000us         0.00%       0.000us       0.000us             4  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 134.822ms
Self CUDA time total: 125.269ms

max_diff = 7.636845111846924e-08
python output:
size(torch.Size([128, 2048]))
tensor([[-0.0030,  0.0013,  0.0059,  ..., -0.0081, -0.0178,  0.0036],
        [ 0.0002,  0.0057, -0.0034,  ..., -0.0059, -0.0025,  0.0034],
        [ 0.0081, -0.0054,  0.0090,  ...,  0.0034, -0.0032,  0.0165],
        ...,
        [-0.0002, -0.0139, -0.0023,  ..., -0.0036,  0.0087,  0.0041],
        [-0.0096,  0.0013, -0.0102,  ...,  0.0015,  0.0367, -0.0135],
        [ 0.0061,  0.0124,  0.0078,  ..., -0.0023, -0.0076, -0.0083]])

cuda output:
size(torch.Size([128, 2048]))
tensor([[-0.0030,  0.0013,  0.0059,  ..., -0.0081, -0.0178,  0.0036],
        [ 0.0002,  0.0057, -0.0034,  ..., -0.0059, -0.0025,  0.0034],
        [ 0.0081, -0.0054,  0.0090,  ...,  0.0034, -0.0032,  0.0165],
        ...,
        [-0.0002, -0.0139, -0.0023,  ..., -0.0036,  0.0087,  0.0041],
        [-0.0096,  0.0013, -0.0102,  ...,  0.0015,  0.0367, -0.0135],
        [ 0.0061,  0.0124,  0.0078,  ..., -0.0023, -0.0076, -0.0083]])
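bench.py prints `max_diff` between the PyTorch and CUDA outputs. Assuming it is the largest element-wise absolute difference (the printed `grad_fn=<MaxBackward1>` points to a max reduction), the check can be sketched in plain Python as follows; the function names, the tolerance and the toy data are hypothetical and not taken from bench.py.

```python
def max_abs_diff(a, b):
    """Largest element-wise |a - b| between two equally shaped 2-D lists."""
    if len(a) != len(b) or any(len(x) != len(y) for x, y in zip(a, b)):
        raise ValueError("shape mismatch")
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

def outputs_match(a, b, atol=1e-6):
    # Float results computed in a different order differ by rounding, so compare
    # against a tolerance instead of exact equality (atol is an illustrative choice)
    return max_abs_diff(a, b) <= atol

# Toy data for illustration only
reference = [[-0.0030, 0.0013], [0.0002, 0.0057]]
candidate = [[-0.0030 + 5e-8, 0.0013], [0.0002, 0.0057 - 2e-8]]
diff = max_abs_diff(reference, candidate)
```

A tolerance check is used rather than exact equality because the CUDA kernels and the PyTorch reference accumulate the expert outputs in different orders.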