feat: fix fp8 for MLA and support bmm fp8 for DeepSeek V2 #1285

Merged · 2 commits merged into main on Sep 1, 2024
Conversation

@zhyncs (Member) commented Sep 1, 2024

Motivation

Pair programming with @ispobock, we completed this feature. The gsm8k results for fp16 and fp8 are both normal. We will continue to use nvtx and nsys for performance analysis and optimization.

python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --enable-mla --trust-remote-code --disable-radix --mem-frac 0.85
python3 -m sglang.launch_server --model neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8 --enable-mla --trust-remote-code --quantization fp8 --disable-radix --mem-frac 0.85

lm_eval --model local-completions --tasks gsm8k --model_args model=deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False
lm_eval --model local-completions --tasks gsm8k --model_args model=neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=128,max_retries=3,tokenized_requests=False

# fp16
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7741|±  |0.0115|
|     |       |strict-match    |     5|exact_match|↑  |0.7635|±  |0.0117|

# fp8
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.7809|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.7672|±  |0.0116|

Modifications

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@zhyncs (Member, Author) commented Sep 1, 2024

ref #1156

zhyncs merged commit 54772f7 into main on Sep 1, 2024 (3 of 8 checks passed)
zhyncs deleted the change_bmm branch on Sep 1, 2024 at 07:28
@zhyncs (Member, Author) commented Sep 1, 2024

ref flashinfer-ai/flashinfer#469

zhyncs mentioned this pull request on Sep 1, 2024
@fengyang95 commented Sep 1, 2024

@zhyncs May I ask whether this supports the A100? I encountered the following error when using an A100:

          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/deepseek_http/vllm_source/sglang/python/sglang/srt/models/deepseek_v2.py", line 576, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/tiger/deepseek_http/vllm_source/sglang/python/sglang/srt/models/deepseek_v2.py", line 453, in forward
    q_nope_out = bmm_fp8(
                 ^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/flashinfer/gemm.py", line 266, in bmm_fp8
    _kernels.bmm_fp8(A, B, out, A_scale, B_scale)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(), d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result, &returned_result)`
[rank0]:W0901 22:42:24.599000 139974644569856 torch/_inductor/compile_worker/subproc_pool.py:126] SubprocPool unclean exit
Killed

@zhyncs (Member, Author) commented Sep 1, 2024

@fengyang95 It only supports sm89+, such as the 4090 and H100, and I have only tested on the H100.
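
For readers hitting the CUBLAS_STATUS_NOT_SUPPORTED error above, here is a minimal sketch (not part of this PR) of how one might check the compute capability before taking the fp8 bmm path; the variable names are illustrative:

import torch

# Hypothetical guard (not from this PR): flashinfer's bmm_fp8 relies on cuBLASLt
# fp8 GEMM, which needs compute capability 8.9 or newer (e.g. L40, 4090, H100).
major, minor = torch.cuda.get_device_capability()
use_bmm_fp8 = (major, minor) >= (8, 9)

if not use_bmm_fp8:
    # On pre-sm89 GPUs such as the A100, a bf16 fallback path would be needed.
    print("bmm_fp8 is unsupported on this GPU; a bf16 fallback would be required")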

@zhyncs (Member, Author) commented Sep 1, 2024

In order to support fp8 on the A100, I think we should use torch.bmm with a cast. @ispobock

@fengyang95 commented Sep 1, 2024

@fengyang95 It only supports sm89+, such as the 4090 and H100, and I have only tested on the H100.

Is it for the same reason as this? vllm-project/vllm#7322 @zhyncs

@zhyncs (Member, Author) commented Sep 1, 2024

@fengyang95 Currently, fp8's dependency on sm89+ appears in two places: the torch.bmm in the MLA implementation and FusedMoE. The former can be resolved by converting fp8 to bf16 and multiplying by w_scale; the latter requires sm80 kernel support in FusedMoE, which we have no plans to add in the short term.
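
A rough sketch of that fallback (not the code in this PR), assuming bf16 activations and a per-tensor weight scale; the function name and shapes are illustrative:

import torch

def bmm_fp8_fallback(a: torch.Tensor,       # [B, M, K] bf16 activations
                     w_fp8: torch.Tensor,   # [B, K, N] fp8 weight (e.g. float8_e4m3fn)
                     w_scale: torch.Tensor, # per-tensor scale of the fp8 weight
                     ) -> torch.Tensor:
    # Dequantize the fp8 weight to bf16, run a plain torch.bmm (available on sm80),
    # then apply the weight scale.
    w = w_fp8.to(torch.bfloat16)
    return torch.bmm(a, w) * w_scale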

zhyncs added a commit to flashinfer-ai/flashinfer that referenced this pull request Sep 1, 2024
@fengyang95 commented

@fengyang95 Currently, fp8's dependency on sm89+ appears in two places: the torch.bmm in the MLA implementation and FusedMoE. The former can be resolved by converting fp8 to bf16 and multiplying by w_scale; the latter requires sm80 kernel support in FusedMoE, which we have no plans to add in the short term.

So, does this mean that solving the MLA issue still doesn't allow the A100 to run properly?

@zhyncs (Member, Author) commented Sep 1, 2024

Yes, we optimized it for the H100.

@fengyang95 commented Sep 1, 2024

Yes, we optimized it for the H100.

@zhyncs By the way, may I ask for some advice? Our main requirement now is to run DeepSeek-Coder-V2 236B on 8x L40, since the L40 is the only GPU we have ample resources for. Do you have any good technical suggestions? Thank you very much.

@zhyncs (Member, Author) commented Sep 1, 2024

@fengyang95

The L40 is sm89, so it supports bmm fp8. However, we haven't adjusted the Triton kernel block sizes for the L40's shared memory, so it might not work. Verification and adaptation for sm89 will be done when there is time, but it is not a high priority. Maybe you can try this:

git clone https://github.com/sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"

pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/

# use MLA, W8A8(FP8) and FP8 E5M2 KV Cache
python3 -m sglang.launch_server --model neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --enable-mla --quantization fp8 --kv-cache-dtype fp8_e5m2 --disable-radix --tp 8 --trust-remote-code

If you encounter any problems, you can directly resolve them and submit a PR. We highly welcome contributions!

We will subsequently implement CUDA Graph and Quant MoE on this basis, and performance will continue to improve. Please stay tuned.

@fengyang95 commented Sep 1, 2024

@fengyang95

The L40 is sm89, so it supports bmm fp8. However, we haven't adjusted the Triton kernel block sizes for the L40's shared memory, so it might not work. Verification and adaptation for sm89 will be done when there is time, but it is not a high priority. Maybe you can try this:

git clone https://github.com/sgl-project/sglang.git
cd sglang

pip install --upgrade pip
pip install -e "python[all]"

pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/

# use MLA, W8A8(FP8) and FP8 E5M2 KV Cache
python3 -m sglang.launch_server --model neuralmagic/DeepSeek-Coder-V2-Instruct-FP8 --enable-mla --quantization fp8 --kv-cache-dtype fp8_e5m2 --disable-radix --tp 8 --trust-remote-code

If you encounter any problems, you can directly resolve them and submit a PR. We highly welcome contributions!

We will subsequently implement CUDA Graph and Quant MoE on this basis, and performance will continue to improve. Please stay tuned.

@zhyncs That's great. I ran it and encountered the following error. It looks like there is indeed an issue with Triton, and I am not familiar with Triton. Could you provide some directions or suggestions for fixing this?

File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 896, in run_tp_server
    model_server.exposed_step(recv_reqs)
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 244, in exposed_step
    self.forward_step()
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 260, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/sglang/python/sglang/srt/managers/tp_worker.py", line 507, in forward_prefill_batch
    sample_output, logits_output = self.model_runner.forward(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 584, in forward
    return self.forward_extend(batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 548, in forward_extend
    return self.model.forward(
           ^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 660, in forward
    hidden_states = self.model(input_ids, positions, input_metadata)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 629, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 576, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 471, in forward
    attn_output = self.attn(q_input, k_input, v_input, input_metadata)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/radix_attention.py", line 201, in forward
    return self.extend_forward(q, k, v, input_metadata)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/sglang/python/sglang/srt/layers/radix_attention.py", line 73, in extend_forward_triton
    extend_attention_fwd(
  File "/sglang/python/sglang/srt/layers/extend_attention.py", line 298, in extend_attention_fwd
    _fwd_kernel[grid](
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
    ^^^^^^^^^^
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/compiler/compiler.py", line 381, in __getattribute__
    self._init_handles()
  File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 106496, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
Killed

@zhyncs (Member, Author) commented Sep 1, 2024

Could you provide some directions or suggestions for fixing this?

@fengyang95 You can focus on this part, but the specific modifications and attempts will require you to take action yourself.

if CUDA_CAPABILITY[0] >= 9:
    if Lq <= 256:
        BLOCK_M, BLOCK_N = (128, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
elif CUDA_CAPABILITY[0] >= 8:
    if Lq <= 128:
        BLOCK_M, BLOCK_N = (128, 128)
    elif Lq <= 256:
        BLOCK_M, BLOCK_N = (64, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
else:
    BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
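
As a rough sketch only (not merged code), a dedicated sm89 branch with smaller tiles could be slotted into the same selection logic; the helper name is hypothetical, and the (32, 32) choice is the setting reported to compile on the L40 later in this thread:

import torch

def pick_extend_attention_blocks(Lq: int) -> tuple:
    # Mirror of the selection above, with a hypothetical dedicated sm89 branch.
    cap = torch.cuda.get_device_capability()
    if cap[0] >= 9:                       # H100 and newer
        return (128, 64) if Lq <= 256 else (32, 64)
    if cap == (8, 9):                     # L40 / 4090: less shared memory per SM
        return (32, 32)
    if cap[0] >= 8:                       # A100
        if Lq <= 128:
            return (128, 128)
        return (64, 64) if Lq <= 256 else (32, 64)
    return (64, 64) if Lq <= 128 else (32, 32)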

@fengyang95 commented Sep 2, 2024

Could you provide some directions or suggestions for fixing this?

@fengyang95 You can focus on this part, but the specific modifications and attempts will require you to take action yourself.

if CUDA_CAPABILITY[0] >= 9:
    if Lq <= 256:
        BLOCK_M, BLOCK_N = (128, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
elif CUDA_CAPABILITY[0] >= 8:
    if Lq <= 128:
        BLOCK_M, BLOCK_N = (128, 128)
    elif Lq <= 256:
        BLOCK_M, BLOCK_N = (64, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
else:
    BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

@zhyncs Thank you very much. I will try it asap.

@fengyang95 commented Sep 2, 2024

Could you provide some directions or suggestions for fixing this?

@fengyang95 You can focus on this part, but the specific modifications and attempts will require you to take action yourself.

if CUDA_CAPABILITY[0] >= 9:
    if Lq <= 256:
        BLOCK_M, BLOCK_N = (128, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
elif CUDA_CAPABILITY[0] >= 8:
    if Lq <= 128:
        BLOCK_M, BLOCK_N = (128, 128)
    elif Lq <= 256:
        BLOCK_M, BLOCK_N = (64, 64)
    else:
        BLOCK_M, BLOCK_N = (32, 64)
else:
    BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)

@zhyncs
I managed to get Triton to compile by adjusting the settings,

BLOCK_M, BLOCK_N = (32, 32)

but there are issues with fused MoE, so it doesn't seem to run on the L40 either.

File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 499, in apply
   return fused_experts(x,
          ^^^^^^^^^^^^^^^^
 File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 544, in fused_experts
   moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 228, in moe_align_block_size
   ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
 File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/_custom_ops.py", line 28, in wrapper
   return fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^
 File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/_custom_ops.py", line 485, in moe_align_block_size
   torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
 File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/_ops.py", line 1061, in __call__
   return self_._op(*args, **(kwargs or {}))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

So for the L40, running with FP8 precision is fine for the MLA part because it uses the native Triton kernel; however, there are issues with Fused MoE because it uses vLLM's Fused MoE implementation?

@xiaoqi35 commented Sep 4, 2024

Is there a plan for sglang to support DeepSeek-V2 4-bit quantization?
From my runs, the DeepSeek-V2 modeling code doesn't support int4 quantization. My patch to support AWQ quantization runs inference successfully without --enable-mla, but there are some bugs in the MLA attention path that I can't resolve.

@zhyncs (Member, Author) commented Sep 4, 2024

Can you share the HuggingFace URL of the checkpoint you used? We plan to take a look after supporting CUDA Graph and Quant MoE.

@Xu-Chen (Contributor) commented Sep 9, 2024

Is there a plan for sglang to support DeepSeek-V2 4-bit quantization? From my runs, the DeepSeek-V2 modeling code doesn't support int4 quantization. My patch to support AWQ quantization runs inference successfully without --enable-mla, but there are some bugs in the MLA attention path that I can't resolve.

As you mentioned in vllm-project/vllm#7494 (comment), can you share the code or the HuggingFace URL of the checkpoint you used?
