
Conversation


@GITHUBear GITHUBear commented Jun 3, 2025

Env Info

  • Python: 3.11.11
  • vLLM: 0.9.1
  • CUDA: 12.8
  • GPU: NVIDIA L20
  • torch: 2.7.0+cu126

Problem Description

DualChunkFlashAttention fails to handle short prompts correctly. This issue can be reproduced by modifying the qwen_1m.py example as follows:

def main():
    llm = initialize_engine()
    # A short prompt first: this is the request that trips the assertion.
    process_requests(llm, ["Hello, world!"])
    # Then the long prompt the example normally uses.
    prompt = load_prompt()
    process_requests(llm, [prompt])
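
The failure reproduces by running the modified example directly. The path examples/offline_inference/qwen_1m.py is assumed from the vLLM examples tree; the stock example already selects the DUAL_CHUNK_FLASH_ATTN backend itself, so the env var below should be redundant but harmless:

VLLM_ATTENTION_BACKEND=DUAL_CHUNK_FLASH_ATTN python examples/offline_inference/qwen_1m.py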

It fails on the simple prompt Hello, world!; the following assertion error is raised during execution:

(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]   File "/data/shanhaikang.shk/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 493, in forward
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]     self._dual_chunk_flash_attn_prefill(
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]   File "/data/shanhaikang.shk/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 673, in _dual_chunk_flash_attn_prefill
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]     current_out = self._dual_chunk_flash_attn_prefill_func(
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]   File "/data/shanhaikang.shk/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 1055, in _dual_chunk_flash_attn_prefill_func
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]     flash_result = self._do_flash_attn(
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]                    ^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]   File "/data/shanhaikang.shk/vllm/vllm/attention/backends/dual_chunk_flash_attn.py", line 1207, in _do_flash_attn
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]     output, softmax_lse = flash_attn_varlen_func(
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]                           ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]   File "/data/shanhaikang.shk/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 204, in flash_attn_varlen_func
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]     assert block_table is None or seqused_k is not None, \
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=3970554) ERROR 06-03 20:09:38 [multiproc_worker_utils.py:238] AssertionError: seqused_k must be provided if block_table is provided

Proposed Fix

This issue was introduced in #11844 .

Since key_states and value_states have already been gathered from the KV cache via the block_table, also passing block_table to flash_attn_varlen_func is both wrong and unnecessary.
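
For illustration, here is the contract the wrapper enforces (the assert is quoted from the traceback above) together with a sketch of the corrected call. The parameter list below is abbreviated and illustrative, not the exact vLLM code:

# From vllm/vllm_flash_attn/flash_attn_interface.py (line 204 in the trace):
# block_table may only be passed together with seqused_k.
assert block_table is None or seqused_k is not None, \
    "seqused_k must be provided if block_table is provided"

# Sketch of the corrected call in _do_flash_attn: key_states/value_states
# were already gathered from the paged KV cache, so they are ordinary
# contiguous tensors and no block_table needs to (or should) be passed.
output, softmax_lse = flash_attn_varlen_func(
    q=query_states,
    k=key_states,
    v=value_states,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=max_seqlen_q,
    max_seqlen_k=max_seqlen_k,
    causal=True,
    return_softmax_lse=True,
)

With block_table omitted, the assertion above is satisfied regardless of seqused_k, which is exactly the short-prompt case that previously crashed.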

Signed-off-by: sa-buc <shanhaikang.shk@oceanbase.com>
@mergify mergify bot added the documentation Improvements or additions to documentation label Jun 3, 2025
Signed-off-by: sa-buc <shanhaikang.shk@oceanbase.com>

github-actions bot commented Jun 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: sa-buc <shanhaikang.shk@oceanbase.com>

ExtReMLapin commented Jun 4, 2025

With this fix I get another error when I use FP8 quantization (--quantization fp8):

ERROR 06-04 14:07:19 [engine.py:164] RuntimeError('query and key must have the same dtype')
ERROR 06-04 14:07:19 [engine.py:164] Traceback (most recent call last):
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/engine/multiprocessing/engine.py", line 162, in start
ERROR 06-04 14:07:19 [engine.py:164]     self.run_engine_loop()
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/engine/multiprocessing/engine.py", line 225, in run_engine_loop
ERROR 06-04 14:07:19 [engine.py:164]     request_outputs = self.engine_step()
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/engine/multiprocessing/engine.py", line 251, in engine_step
ERROR 06-04 14:07:19 [engine.py:164]     raise e
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/engine/multiprocessing/engine.py", line 234, in engine_step
ERROR 06-04 14:07:19 [engine.py:164]     return self.engine.step()
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 1393, in step
ERROR 06-04 14:07:19 [engine.py:164]     outputs = self.model_executor.execute_model(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/executor/executor_base.py", line 140, in execute_model
ERROR 06-04 14:07:19 [engine.py:164]     output = self.collective_rpc("execute_model",
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/executor/uniproc_executor.py", line 56, in collective_rpc
ERROR 06-04 14:07:19 [engine.py:164]     answer = run_method(self.driver_worker, method, args, kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/utils.py", line 2605, in run_method
ERROR 06-04 14:07:19 [engine.py:164]     return func(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 420, in execute_model
ERROR 06-04 14:07:19 [engine.py:164]     output = self.model_runner.execute_model(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 06-04 14:07:19 [engine.py:164]     return func(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/worker/model_runner.py", line 1843, in execute_model
ERROR 06-04 14:07:19 [engine.py:164]     hidden_or_intermediate_states = model_executable(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return self._call_impl(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return forward_call(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 481, in forward
ERROR 06-04 14:07:19 [engine.py:164]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/compilation/decorators.py", line 172, in __call__
ERROR 06-04 14:07:19 [engine.py:164]     return self.forward(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 358, in forward
ERROR 06-04 14:07:19 [engine.py:164]     hidden_states, residual = layer(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return self._call_impl(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return forward_call(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 257, in forward
ERROR 06-04 14:07:19 [engine.py:164]     hidden_states = self.self_attn(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return self._call_impl(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return forward_call(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/model_executor/models/qwen2.py", line 187, in forward
ERROR 06-04 14:07:19 [engine.py:164]     attn_output = self.attn(q, k, v)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return self._call_impl(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
ERROR 06-04 14:07:19 [engine.py:164]     return forward_call(*args, **kwargs)
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/layer.py", line 237, in forward
ERROR 06-04 14:07:19 [engine.py:164]     return torch.ops.vllm.unified_attention(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 06-04 14:07:19 [engine.py:164]     return self._op(*args, **(kwargs or {}))
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/layer.py", line 386, in unified_attention
ERROR 06-04 14:07:19 [engine.py:164]     output = self.impl.forward(self, query, key, value, kv_cache,
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 493, in forward
ERROR 06-04 14:07:19 [engine.py:164]     self._dual_chunk_flash_attn_prefill(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 673, in _dual_chunk_flash_attn_prefill
ERROR 06-04 14:07:19 [engine.py:164]     current_out = self._dual_chunk_flash_attn_prefill_func(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 1055, in _dual_chunk_flash_attn_prefill_func
ERROR 06-04 14:07:19 [engine.py:164]     flash_result = self._do_flash_attn(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/attention/backends/dual_chunk_flash_attn.py", line 1207, in _do_flash_attn
ERROR 06-04 14:07:19 [engine.py:164]     output, softmax_lse = flash_attn_varlen_func(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/vllm/vllm_flash_attn/flash_attn_interface.py", line 227, in flash_attn_varlen_func
ERROR 06-04 14:07:19 [engine.py:164]     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
ERROR 06-04 14:07:19 [engine.py:164]   File "/home/pierre/idextend/venv/lib/python3.10/site-packages/torch/_ops.py", line 1158, in __call__
ERROR 06-04 14:07:19 [engine.py:164]     return self._op(*args, **(kwargs or {}))
ERROR 06-04 14:07:19 [engine.py:164] RuntimeError: query and key must have the same dtype

@GITHUBear (Author)

With this fix I get another error when I use FP8 quantization (--quantization fp8):

ERROR 06-04 14:07:19 [engine.py:164] RuntimeError('query and key must have the same dtype')
[...]

This issue is likely caused by DCA being implemented on top of FlashAttention, whose v2 kernels do not support FP8 quantization.
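
For illustration only, a guard along these lines before the varlen call shows the constraint the kernel enforces. This is a hypothetical sketch, not the vLLM implementation; it assumes the mismatch comes from fp8 keys/values in the cache while the query stays bf16/fp16, and a real fix would also need to apply the FP8 scale factors rather than blindly upcasting:

# Hypothetical guard (illustration only): the FA2 varlen kernel requires
# q/k/v to share one dtype, which is what raises the RuntimeError above.
if key_states.dtype != query_states.dtype:
    key_states = key_states.to(query_states.dtype)      # ignores kv scales
    value_states = value_states.to(query_states.dtype)  # ignores kv scales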

@LucasWilkinson (Collaborator)

@sighingnow can you please review this PR since you are the most familiar with DCA?

@GITHUBear (Author)

Hello, @sighingnow! This is a small fix. Could you take a look when you have time?

@mergify mergify bot added the qwen Related to Qwen models label Jun 19, 2025
@ExtReMLapin (Contributor)

It's a simple fix and the original author is not responding. Could it be merged?

@ExtReMLapin (Contributor)

ref #21364


microcomet commented Jul 28, 2025

ref #21364

I'm using this fixed version on an RTX 5090 with CUDA 12.8, but token generation is very slow, about 12 tokens/s. It's even slower than the same model running on an RTX 4090. Is something wrong? Everything else works fine, but the speed is extremely slow.

command:

VLLM_ATTENTION_BACKEND=DUAL_CHUNK_FLASH_ATTN CUDA_VISIBLE_DEVICES=1 vllm serve /data1/fanwei/models/qwen2.5-14b-instruct-1m-gptq-int4 \
    --served-model-name qwen2.5-14b-instruct-1m-gptq-int4 \
    --host 0.0.0.0 \
    --port 9998 \
    --trust-remote-code \
    --max-model-len 32768 \
    --max-num-batched-tokens 32768 \
    --max-num-seqs 15 \
    --gpu-memory-utilization 0.8 \
    --enable-prefix-caching \
    --quantization gptq \
    --enforce-eager

env info:

  • Python: 3.12.3
  • vLLM: 0.9.2.dev210+gb82e0f82c.d20250624
  • torch: 2.8.0.dev20250622+cu128
  • GPU: RTX 5090
  • CUDA: 12.8.93


ExtReMLapin commented Jul 28, 2025

Please don't comment on this outdated PR; go to the main one.

vLLM's flash-attn needs a patch to work on the RTX 5090. Did you apply the sm120 patch cited in the original PR?

@sighingnow (Collaborator)

DualChunkFlashAttention fails to handle short prompts correctly. [...] Since key_states and value_states have already been gathered from the KV cache via the block_table, also passing block_table to flash_attn_varlen_func is both wrong and unnecessary.

This issue has already been fixed on main by #21364, and I can confirm that the original case works on main. Closing.

@sighingnow sighingnow closed this Jul 30, 2025
@sighingnow sighingnow added bug Something isn't working and removed documentation Improvements or additions to documentation labels Jul 30, 2025
@mergify mergify bot added the documentation Improvements or additions to documentation label Jul 30, 2025