-
-
Notifications
You must be signed in to change notification settings - Fork 11k
[BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 #27128
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 #27128
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request aims to fix an illegal memory access error in Flash Attention MLA with full CUDA graph support by ensuring get_scheduler_metadata and FlashAttnMLAMetadata receive the same max_num_splits value. The changes correctly refactor the logic to calculate max_num_splits before it's used. However, I've identified a remaining logic issue where a similar discrepancy can occur when vllm_is_batch_invariant() is true, which could lead to the same bug under different conditions. I've provided a suggestion to fully resolve this.
…25490 Signed-off-by: qqma <qqma@amazon.com>
257c4e8 to
a8cdaba
Compare
Signed-off-by: qqma <qqma@amazon.com>
Signed-off-by: qqma <qqma@amazon.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM; thanks!
|
seems like the failed tests are unrelated, is it fine to still merge it? |
…owing pr-25490 (vllm-project#27128) Signed-off-by: qqma <qqma@amazon.com> Co-authored-by: qqma <qqma@amazon.com>
…owing pr-25490 (vllm-project#27128) Signed-off-by: qqma <qqma@amazon.com> Co-authored-by: qqma <qqma@amazon.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
…o step_forward * 'step_forward' of https://github.com/raindaywhu/vllm: (148 commits) [Model] Add MoE support for NemotronH (vllm-project#25863) [Metrics] [KVConnector] Add connector prefix cache hit rate stats (vllm-project#26245) [CI] Reorganize entrypoints tests (vllm-project#27403) add SLA information into comparison graph for vLLM Benchmark Suite (vllm-project#25525) [CI/Build] Fix AMD CI: test_cpu_gpu.py (vllm-project#27388) [Bugfix] Fix args settings for guided decoding args (vllm-project#27375) [CI/Build] Fix Prithvi plugin test (vllm-project#27393) [Chore] Remove duplicate `has_` functions in vllm.utils (vllm-project#27372) [Model] Add num_cached_tokens for PoolingRequestOutput (vllm-project#27378) [V1][spec decode] return logprobs for spec decoding (vllm-project#26060) [CORE] Support Prefix Caching with Prompt Embeds (vllm-project#27219) [Bugfix][Core] running queue index leakage exception (vllm-project#26754) [Bugfix] Fix incorrect kv cache metrics in grafana.json (vllm-project#27133) [Bugfix] Fix SLA tuner initialization (vllm-project#27355) [Bugfix] Fix deepseek-ocr multi-image inference and add `merge_by_field_config=True` with tensor schema support (vllm-project#27361) [MLA] Bump FlashMLA (vllm-project#27354) [Chore] Separate out system utilities from vllm.utils (vllm-project#27201) [BugFix] bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490 (vllm-project#27128) [Feature] publisher default set zmq in kv_event config (vllm-project#26915) [Prefix Cache] Use LoRA name for consistent KV-cache block hashing (vllm-project#27211) ...
…owing pr-25490 (vllm-project#27128) Signed-off-by: qqma <qqma@amazon.com> Co-authored-by: qqma <qqma@amazon.com>
…owing pr-25490 (vllm-project#27128) Signed-off-by: qqma <qqma@amazon.com> Co-authored-by: qqma <qqma@amazon.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
…owing pr-25490 (vllm-project#27128) Signed-off-by: qqma <qqma@amazon.com> Co-authored-by: qqma <qqma@amazon.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Bugfix for Flash Attention MLA with full cuda graph IMA following pr-25490
Run into illegal memory access error when testing some prompts with prefix caching enabled on Flash Attention MLA backend
Log below is generated with CUDA_LAUNCH_BLOCKING=1 which indicating it's flash attn mla.
And realized it's the same root cause as #25490 where
get_scheduler_metadatawas being called with a differentmax_num_splitsthan what was being passed toFlashAttnMLAMetadata.