Conversation

@LucasWilkinson (Collaborator) commented Feb 24, 2025

Integrate: https://github.com/deepseek-ai/FlashMLA

currently requires

export VLLM_ATTENTION_BACKEND=FLASHMLA

and

block_size=64
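
For example, a minimal offline sketch of trying the backend out (not part of this PR; the model name is just the one used in the benchmarks below):

```python
# Minimal sketch: opt into the FlashMLA backend via the env var and set the
# KV cache block size it currently requires. Model name is only an example.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"

from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,
    block_size=64,  # FlashMLA currently only supports block size 64
)
out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```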

TODO:

  • CUDA graphs are broken
  • future PR: enforce block_size 64 gracefully

Closes #13735

@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Comment on lines +167 to +179
elif block_size != 64:
    logger.warning(
        "FlashMLA backend is not supported for block size %d"
        " (currently only supports block size 64).",
        block_size)
Member

You can update the config here:

def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

Check the env var there and change the block size (with an info-level logging message).
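
Roughly, the kind of override being suggested could look like this (a sketch only: the class name, attribute paths, and logger are assumptions, not the exact code that landed in the PR):

```python
# Sketch of the suggested platform hook: if the user opted into FlashMLA via
# the env var, force the KV cache block size it requires and log the change.
import logging
import os

logger = logging.getLogger(__name__)


class CudaPlatformSketch:  # hypothetical stand-in for the real platform class
    @classmethod
    def check_and_update_config(cls, vllm_config) -> None:
        cache_config = vllm_config.cache_config
        if (os.environ.get("VLLM_ATTENTION_BACKEND") == "FLASHMLA"
                and cache_config is not None
                and cache_config.block_size != 64):
            logger.info(
                "Forcing KV cache block size from %s to 64 for the FlashMLA "
                "backend.", cache_config.block_size)
            cache_config.block_size = 64
```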

Collaborator Author

Done, this helps, thanks! We will still need to figure out a better solution if we ever want to make it the default, though, since this relies on the env var.

@youkaichao (Member) Feb 25, 2025

Theoretically, you can set the env var inside this function, and it should be respected later.

Member

For now, maybe using the env var is fine for people to try it out, before turning it on by default.

Collaborator Author

Ya, I'd like to do some cursory benchmarking for a few different workloads before turning it on by default 👍, but I suspect we will ultimately turn it on by default in the next couple of days since it should be much faster than Triton.

@LucasWilkinson LucasWilkinson marked this pull request as ready for review February 25, 2025 05:40
@LucasWilkinson (Collaborator Author) commented Feb 25, 2025

1xH100

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --input-len 1000 --output-len 1000 --num-prompts 100 --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code

Throughput: 3.54 requests/s, 7078.60 total tokens/s, 3539.30 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --input-len 1000 --output-len 1000 --num-prompts 100 --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code

Throughput: 3.72 requests/s, 7446.49 total tokens/s, 3723.25 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_latency.py --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code

Avg latency: 1.4260522789011398 seconds
10% percentile latency: 1.394141173362732 seconds
25% percentile latency: 1.4116658233106136 seconds
50% percentile latency: 1.427198606543243 seconds
75% percentile latency: 1.4392954176291823 seconds
90% percentile latency: 1.4475165858864785 seconds
99% percentile latency: 1.4744242035970092 seconds


VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_latency.py --model deepseek-ai/DeepSeek-V2-Lite --trust-remote-code

Avg latency: 1.4195543638120094 seconds
10% percentile latency: 1.3920048194006085 seconds
25% percentile latency: 1.4075597953051329 seconds
50% percentile latency: 1.416732276789844 seconds
75% percentile latency: 1.4300497379153967 seconds
90% percentile latency: 1.4401984374970198 seconds
99% percentile latency: 1.4675243362039327 seconds

@Stonesjtu

(quoting the 1xH100 benchmark results above)

Looks like FlashMLA has higher throughput (5%-10%) but trades off a bit of latency (~1%). Should be a solid performance improvement.

BTW, can you post the GPU model used in this test?

@youkaichao (Member)

Looks like FlashMLA has higher throughput (5%-10%) but trades off a bit of latency (~1%). Should be a solid performance improvement.

Note that this is a small model (deepseek-ai/DeepSeek-V2-Lite). We should wait for an 8xH200 benchmark on an R1-sized model.

@leonzy commented Feb 25, 2025

Looks like FlashMLA has higher throughput (5%-10%) but trades off a bit of latency (~1%). Should be a solid performance improvement.

Note that this is a small model (deepseek-ai/DeepSeek-V2-Lite). We should wait for an 8xH200 benchmark on an R1-sized model.

I'd like to know whether this patch already works on 2x 8xH100. Can anybody run an R1 671B benchmark on it?

@mgoin (Member) left a comment

Nice work!! I think the performance benefit should be greatest at very large sequence lengths.

@mgoin mgoin moved this to In progress in DeepSeek V3/R1 Feb 25, 2025
@hmellor hmellor moved this from In progress to In review in DeepSeek V3/R1 Feb 25, 2025
@LucasWilkinson (Collaborator Author)

8xH200

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 100 -tp 8

Throughput: 0.63 requests/s, 2539.60 total tokens/s, 634.90 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --input-len 3000 --output-len 1000 --num-prompts 100 -tp 8

Throughput: 0.69 requests/s, 2769.37 total tokens/s, 692.34 output tokens/s


python benchmarks/benchmark_serving.py --dataset-name random --random-input 32000 --random-output 100 --request-rate 1 --num-prompt 60

VLLM_ATTENTION_BACKEND=TRITON_MLA vllm serve /home/vllm-dev/DeepSeek-R1/ --trust-remote-code -tp 8


============ Serving Benchmark Result ============
Successful requests:                     60
Benchmark duration (s):                  278.85
Total input tokens:                      1920000
Total generated tokens:                  5474
Request throughput (req/s):              0.22
Output token throughput (tok/s):         19.63
Total Token throughput (tok/s):          6904.94
---------------Time to First Token----------------
Mean TTFT (ms):                          114926.74
Median TTFT (ms):                        119686.01
P99 TTFT (ms):                           216243.30
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          214.40
Median TPOT (ms):                        191.36
P99 TPOT (ms):                           328.92
---------------Inter-token Latency----------------
Mean ITL (ms):                           218.03
Median ITL (ms):                         133.12
P99 ITL (ms):                            2747.58
==================================================


VLLM_ATTENTION_BACKEND=FLASHMLA vllm serve /home/vllm-dev/DeepSeek-R1/ --trust-remote-code -tp 8


============ Serving Benchmark Result ============
Successful requests:                     60
Benchmark duration (s):                  287.80
Total input tokens:                      1920000
Total generated tokens:                  5734
Request throughput (req/s):              0.21
Output token throughput (tok/s):         19.92
Total Token throughput (tok/s):          6691.14
---------------Time to First Token----------------
Mean TTFT (ms):                          113298.95
Median TTFT (ms):                        115367.26
P99 TTFT (ms):                           219875.67
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          296.58
Median TPOT (ms):                        210.34
P99 TPOT (ms):                           2002.50
---------------Inter-token Latency----------------
Mean ITL (ms):                           227.77
Median ITL (ms):                         127.52
P99 ITL (ms):                            2765.75
==================================================

@LucasWilkinson LucasWilkinson added the ready label (ONLY add when PR is ready to merge/full CI is needed) Feb 26, 2025
@mergify bot commented Feb 26, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@LucasWilkinson LucasWilkinson force-pushed the lwilkinson/flashmla-integration branch from 2fa62a9 to 02a46a3 on February 26, 2025 03:39
@mergify mergify bot removed the needs-rebase label Feb 26, 2025
@billishyahao (Contributor) commented Feb 26, 2025

Hi @LucasWilkinson, great PR! I am trying to reproduce the numbers in a local environment but hit the following installation issue:

(VllmWorkerProcess pid=22390) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22396) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22393) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22392) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22394) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22395) WARNING 02-25 23:14:50 [flashmla.py:17] Failed to import from vllm._flashmla_C with ModuleNotFoundError("No module named 'vllm._flashmla_C'")                                                                                                                         
(VllmWorkerProcess pid=22390) WARNING 02-25 23:14:50 [cuda.py:172] FlashMLA backend is not supported due to vllm._flashmla_C is not available, likely was not compiled due to insufficient nvcc version or a supported archwas not in the list of target arches to compile for.                     
(VllmWorkerProcess pid=22390) INFO 02-25 23:14:50 [cuda.py:184] Using Triton MLA backend. 

I am using the image vllm/vllm-openai:v0.7.3 and the command `pip install .`. Could you shed some light on this issue?

@LucasWilkinson (Collaborator Author)

(quoting the installation issue above)

Can you please provide the output of `pip install . -v`? I'd be curious to see what the CMake step is outputting.
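
In the meantime, a quick diagnostic sketch to check whether the compiled extension made it into the installed package (module name taken from the warning above):

```python
# Diagnostic sketch: report whether the compiled FlashMLA extension is present
# in the installed vllm package.
import importlib.util

spec = importlib.util.find_spec("vllm._flashmla_C")
print("vllm._flashmla_C found:", spec is not None)
if spec is not None:
    print("loaded from:", spec.origin)
```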

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@fan-niu commented Feb 26, 2025

(quoting the 8xH200 benchmark results above)

@LucasWilkinson thanks for the great work! I found that FlashMLA improved by about 10% in the throughput test, but why did the output token throughput of TRITON_MLA and FLASHMLA not improve in the serving (latency) test? Thanks a lot!

@billishyahao (Contributor)

(quoting the 8xH200 benchmark results and the question above)

I think there is a minor issue in the command: we should use --ignore-eos so each test case generates the full requested output length. https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py#L1175

@billishyahao (Contributor)

(quoting the installation issue and the request for `pip install . -v` output above)

I observed the error:

Running command Building wheel for vllm (pyproject.toml)                                                                                                                                                                                                                                          
  /tmp/pip-build-env-mjbgpq4o/overlay/local/lib/python3.12/dist-packages/torch/_subclasses/functional_tensor.py:295: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)                                            
    cpu = _conversion_method_template(device=torch.device("cpu"))                                                                                                                                                                                                                                   
  /bin/sh: 1: lsmod: not found                                                                                                                                                                                                                                                                      
  /bin/sh: 1: lsmod: not found                                                                                                                                                                                                                                                                      
  /bin/sh: 1: lsmod: not found                                                                                                                                                                                                                                                                      
  /bin/sh: 1: lsmod: not found                                                                                                                                                                                                                                                                      
  /bin/sh: 1: lsmod: not found 

I then reset to 145944c:

git reset --hard 145944cb94a6fc663c05451763315d45e771a285
HEAD is now at 145944cb Improve pipeline partitioning (#13839)

but still saw the issue, so it is not related to this PR. See the similar issue #5587.

@zeroorhero

Hi @LucasWilkinson, when prefix caching is enabled and I send two identical requests, this error is reported once the cache hits:

vllm serve /data00/models/DeepSeek-V2-Lite-Chat --trust-remote-code -tp 1 --host 127.0.0.1 --port 9001 --enable-prefix-caching

(screenshot of the error)
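
A minimal repro sketch against that server (assumptions: the `vllm serve` command above is running and the served model name matches its path; uses the OpenAI-compatible API via the openai client):

```python
# Repro sketch: send the same prompt twice so the second request hits the
# prefix cache. Endpoint and model values are taken from the serve command above.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:9001/v1", api_key="EMPTY")
prompt = "Explain multi-head latent attention in one short paragraph."

for i in range(2):
    completion = client.completions.create(
        model="/data00/models/DeepSeek-V2-Lite-Chat",
        prompt=prompt,
        max_tokens=32,
    )
    print(f"request {i}:", completion.choices[0].text.strip())
```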

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
@LucasWilkinson (Collaborator Author)

(quoting the prefix-caching error report above)

I think that's related to PR #12639, not the FlashMLA one, but I'll investigate and open a PR to disable prefix caching + MLA in the meantime. Thanks for the report!

@zeroorhero

(quoting the exchange above)

Thanks a lot for your reply! Looking forward to having prefix caching supported with MLA!

@youkaichao youkaichao merged commit f959039 into vllm-project:main Feb 27, 2025
67 of 69 checks passed
@github-project-automation github-project-automation bot moved this from In review to Done in DeepSeek V3/R1 Feb 27, 2025
@simon-mo (Collaborator)

My results on H200:

VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 2000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 1.35 requests/s, 4041.21 total tokens/s, 1347.07 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 2000 --output-len 1000 --num-prompts 60 -tp 8

Throughput: 1.41 requests/s, 4235.21 total tokens/s, 1411.74 output tokens/s


VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 5000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 0.56 requests/s, 3355.99 total tokens/s, 559.33 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 5000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 0.65 requests/s, 3920.04 total tokens/s, 653.34 output tokens/s


VLLM_ATTENTION_BACKEND=TRITON_MLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 10000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 0.13 requests/s, 1424.01 total tokens/s, 129.46 output tokens/s

VLLM_ATTENTION_BACKEND=FLASHMLA python benchmarks/benchmark_throughput.py --model /home/vllm-dev/DeepSeek-R1 --trust-remote-code --load-format dummy --input-len 10000 --output-len 1000 --num-prompts 60 -tp 8
Throughput: 0.13 requests/s, 1463.90 total tokens/s, 133.08 output tokens/s

Akshat-Tripathi pushed a commit to krai/vllm that referenced this pull request Mar 3, 2025
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
@ZhongYingMatrix (Contributor)

I am using the image vllm/vllm-openai:v0.7.3 and the command `pip install .`. Could you shed some light on this issue?

Hi @billishyahao, have you successfully run FlashMLA in vllm/vllm-openai:v0.7.3? I built from source within the same image but found that _flashmla_C is unavailable. I noticed that the CUDA version in that image is 12.1, which does not meet FlashMLA's requirement of CUDA 12.3+ (12.8+ is better). I am wondering if there is an official Docker image where FlashMLA is available.

@billishyahao (Contributor)

(quoting the question above)

Hi @ZhongYingMatrix, we observe the same symptom. The vLLM OpenAI image is buggy: CUDA nvcc is unintentionally downgraded to CUDA 12.1 rather than the base CUDA image's 12.4. As a quick workaround, I would recommend trying lmsysorg/sglang:v0.4.3.post2-cu125.
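
A quick sketch to confirm which CUDA toolkit a source build will actually see, versus the CUDA version PyTorch was built with (standard tooling only):

```python
# Sketch: compare the CUDA version PyTorch was built with against the nvcc
# on PATH, to confirm the downgrade described above before rebuilding vLLM.
import shutil
import subprocess

import torch

print("torch built with CUDA:", torch.version.cuda)
nvcc = shutil.which("nvcc")
if nvcc is None:
    print("nvcc not found on PATH")
else:
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True)
    print(result.stdout.strip())
```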

@tlrmchlsmth (Member)

FYI, vLLM is upgrading to CUDA 12.4 as the default in the next release (v0.8.0): #12098

lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>

Labels

ci/build, force-merge, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

[Feature]: Support for FlashMLA