feat: fix fp8 for MLA and support bmm fp8 for DeepSeek V2 #1285
Conversation
ref #1156 |
@zhyncs May I ask whether this supports the A100? I encountered the following error when using the A100:
File "/opt/tiger/deepseek_http/vllm_source/sglang/python/sglang/srt/models/deepseek_v2.py", line 576, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/tiger/deepseek_http/vllm_source/sglang/python/sglang/srt/models/deepseek_v2.py", line 453, in forward
q_nope_out = bmm_fp8(
^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/flashinfer/gemm.py", line 266, in bmm_fp8
_kernels.bmm_fp8(A, B, out, A_scale, B_scale)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling `cublasLtMatmulAlgoGetHeuristic( lt_handle, matmul_desp.descriptor(), a_desp.descriptor(), b_desp.descriptor(), d_desp.descriptor(), d_desp.descriptor(), preference.descriptor(), 1, &heuristic_result, &returned_result)`
[rank0]:W0901 22:42:24.599000 139974644569856 torch/_inductor/compile_worker/subproc_pool.py:126] SubprocPool unclean exit
Killed |
@fengyang95 It only supports sm89+, such as the 4090 and H100, and I have only tested it on the H100. |
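For reference, a minimal PyTorch sketch (not part of this PR) to check whether the current GPU meets the sm89 requirement:

import torch

# fp8 (e4m3) GEMM through cuBLASLt requires sm89+ (Ada/Hopper); the A100 is sm80.
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: sm{major}{minor}")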
In order to support fp8 on A100, I think we should use the |
Is this due to the same reason as vllm-project/vllm#7322? @zhyncs |
@fengyang95 Currently, the fp8 dependency on sm89+ appears in two places: the torch.bmm in the MLA implementation and FusedMoE. The former can be resolved by converting fp8 to bf16 and multiplying by w_scale, while the latter requires sm80 kernel support in FusedMoE, which we have no plans to add in the short term. |
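A minimal sketch of the bf16 fallback described above for the MLA torch.bmm path (names such as w_fp8 and w_scale are illustrative, not the actual deepseek_v2.py variables):

import torch

def bmm_fp8_fallback(a_bf16: torch.Tensor, w_fp8: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # Dequantize the fp8 weight to bf16 and fold in the per-tensor scale,
    # then use a regular bf16 batched matmul, which runs on sm80 (A100).
    w_bf16 = w_fp8.to(torch.bfloat16) * w_scale.to(torch.bfloat16)
    return torch.bmm(a_bf16, w_bf16)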
fp8 scale per tensor ref sgl-project/sglang#1285
So, does this mean that solving the MLA issue still doesn't allow the A100 to run properly? |
Yes, we optimized it for the H100. |
The L40 is sm89, so it supports bmm fp8. However, we haven't tuned the Triton kernel block sizes for the L40's shared memory, so it might not work. Verification and adaptation for sm89 will be done when there is time, but it is not a high priority. Maybe you can try this
If you encounter any problems, feel free to fix them directly and submit a PR; we highly welcome contributions! We will subsequently implement CUDA Graph and quantized MoE on top of this, and performance will continue to improve. Please stay tuned. |
@zhyncs That's great. I ran it and encountered the following error. It looks like there is indeed an issue with Triton, but I am not familiar with Triton. Could you provide some directions or suggestions for fixing this?
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
self.run()
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/sglang/python/sglang/srt/managers/tp_worker.py", line 896, in run_tp_server
model_server.exposed_step(recv_reqs)
File "/sglang/python/sglang/srt/managers/tp_worker.py", line 244, in exposed_step
self.forward_step()
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/managers/tp_worker.py", line 260, in forward_step
self.forward_prefill_batch(new_batch)
File "/sglang/python/sglang/srt/managers/tp_worker.py", line 507, in forward_prefill_batch
sample_output, logits_output = self.model_runner.forward(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 584, in forward
return self.forward_extend(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/model_executor/model_runner.py", line 548, in forward_extend
return self.model.forward(
^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 660, in forward
hidden_states = self.model(input_ids, positions, input_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 629, in forward
hidden_states, residual = layer(
^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 576, in forward
hidden_states = self.self_attn(
^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/models/deepseek_v2.py", line 471, in forward
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/layers/radix_attention.py", line 201, in forward
return self.extend_forward(q, k, v, input_metadata)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/sglang/python/sglang/srt/layers/radix_attention.py", line 73, in extend_forward_triton
extend_attention_fwd(
File "/sglang/python/sglang/srt/layers/extend_attention.py", line 298, in extend_attention_fwd
_fwd_kernel[grid](
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/runtime/jit.py", line 691, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/compiler/compiler.py", line 381, in __getattribute__
self._init_handles()
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/triton/compiler/compiler.py", line 374, in _init_handles
raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 106496, Hardware limit: 101376. Reducing block sizes or `num_stages` may help.
Killed |
@fengyang95 You can focus on this part, but the specific modifications and experiments will need to be done on your side: sglang/python/sglang/srt/layers/extend_attention.py, lines 275 to 288 at 6487ef6
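One possible direction, purely as a sketch (the real tile sizes live in the extend_attention.py lines referenced above and may differ; the values below are illustrative): choose smaller Triton tiles or fewer pipeline stages on sm89 parts, whose per-SM shared memory (~100 KB) is well below the H100's.

import torch

def pick_extend_attn_config():
    major, minor = torch.cuda.get_device_capability()
    sm = major * 10 + minor
    if sm >= 90:
        # Hopper (H100): keep the larger tiles the kernel was tuned for.
        return {"BLOCK_M": 128, "BLOCK_N": 128, "num_stages": 2}
    # sm89 (L40 / RTX 4090): shrink tiles/stages so the kernel's shared-memory
    # footprint stays under the ~100 KB hardware limit reported in the error.
    return {"BLOCK_M": 64, "BLOCK_N": 64, "num_stages": 1}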
|
@zhyncs Thank you very much. I will try it asap. |
But there are issues with fused MoE, so it doesn't seem to run on the L40 either.
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/fp8.py", line 499, in apply
return fused_experts(x,
^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 544, in fused_experts
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/model_executor/layers/fused_moe/fused_moe.py", line 228, in moe_align_block_size
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/_custom_ops.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/vllm/_custom_ops.py", line 485, in moe_align_block_size
torch.ops._C.moe_align_block_size(topk_ids, num_experts, block_size,
File "/home/tiger/.pyenv/versions/3.11.2/lib/python3.11/site-packages/torch/_ops.py", line 1061, in __call__
return self_._op(*args, **(kwargs or {}))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
So for the L40, running with FP8 precision is fine for the MLA part because it uses the native Triton kernels, but there are issues with fused MoE because it uses vLLM's fused MoE implementation? |
Is there a plan for sglang to support 4-bit quantization of DeepSeek-V2? |
Can you share the HuggingFace URL of the checkpoint you used? We plan to take a look after supporting CUDA Graph and Quant MoE. |
As you mentioned in vllm-project/vllm#7494 (comment). |
Motivation
Pair-programmed with @ispobock to complete this feature. GSM8K results are normal for both fp16 and fp8. We will continue to use NVTX and Nsight Systems (nsys) for performance analysis and optimization.
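As a brief illustration of the per-tensor fp8 scaling involved (a minimal sketch, not the exact code in this PR):

import torch

def per_tensor_quant_fp8(x: torch.Tensor):
    # Per-tensor scale: map the tensor's max |value| onto the e4m3 range.
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = x.abs().max().clamp(min=1e-12).float() / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return x_fp8, scale  # the scale is carried alongside the fp8 tensor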
Modifications
Checklist