[Bug]: Failed to generate normal outputs on deepseek-vl2-tiny's MoE LM backbone #12015

Closed
Isotr0py opened this issue Jan 13, 2025 · 0 comments · Fixed by #12067
Labels
bug Something isn't working

Comments

@Isotr0py
Collaborator

Your current environment

The output of `python collect_env.py`
(not provided)

Model Input Dumps

No response

🐛 Describe the bug

While developing support for deepseek-vl2-tiny, I noticed that the model generated gibberish for every prompt in the batch except the first one, even with max_num_seqs=1 set:

The image features a view of cherry blossoms in the foreground with a prominent tower in the background. The sky is clear and blue, providing a vibrant backdrop to the pink blossoms. The tower appears to be a modern structure, possibly a communications or observation tower.
Theo بتكون antidepressoNames sesu Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginhawaitself-sight
The Tokyo Drama Drama Drama Drama Businesses the Tokyo Tower Records Tower Records Tower Records Businesses the Tokyo Tower of the DramaThe Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama the Drama
Theo بتكون antidepressoNames sesu Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginh Ginhawaitself-sight

After further investigation, I found that the problem occurs in the attention decode output of the MoE LM. Here is code to reproduce the issue with the extracted MoE LM:

from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "<|begin▁of▁sentence|>The future of AI is",
    "<|begin▁of▁sentence|>The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0, max_tokens=2)

# Create an LLM.
llm = LLM(
    model="estrogen/DeepSeekMoE-3B",
    dtype="half",
    enforce_eager=True,
    trust_remote_code=True,
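    # Force loading the checkpoint as the text-only DeepSeek MoE LM (DeepseekForCausalLM).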
    hf_overrides={"architectures": ["DeepseekForCausalLM"]},
)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Running this code with the xformers backend, the q and k inputs to Attention in the first self-attention layer are identical across the batch during decoding. However, the decode output becomes all zeros for every prompt after the first one in the batch:

q_after_rope:  (tensor([[ 0.7090, -0.3503, -0.2996,  ..., -0.0652, -1.4297,  1.3496]],
       device='cuda:0', dtype=torch.float16), tensor([[ 0.7090, -0.3503, -0.2996,  ..., -0.0652, -1.4297,  1.3496]],
       device='cuda:0', dtype=torch.float16))
k_after_rope:  (tensor([[ 0.4038, -0.1843, -0.5962,  ..., -0.2625,  1.2324, -0.1292]],
       device='cuda:0', dtype=torch.float16), tensor([[ 0.4038, -0.1843, -0.5962,  ..., -0.2625,  1.2324, -0.1292]],
       device='cuda:0', dtype=torch.float16))
decode_output:  torch.Size([2, 10, 128]) (tensor([[[ 0.0354, -0.0139,  0.0257,  ...,  0.0233, -0.0036, -0.0231],
         [-0.0368,  0.0082, -0.0165,  ...,  0.0092,  0.0364, -0.0365],
         [ 0.1521, -0.0135, -0.0307,  ..., -0.0020,  0.0911,  0.0950],
         ...,
         [-0.2415,  0.0032, -0.0009,  ...,  0.0008, -0.0039, -0.0035],
         [-0.0546, -0.0158,  0.0126,  ..., -0.0031,  0.0112, -0.0668],
         [-0.0004,  0.0056, -0.0066,  ..., -0.0038, -0.0072, -0.0069]]],
       device='cuda:0', dtype=torch.float16), tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0',
       dtype=torch.float16))
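
The tensors above were presumably collected by printing the attention inputs and outputs during decoding. Below is a minimal, generic sketch of how such a check could be done with a plain PyTorch forward hook that flags all-zero attention output rows; the stand-in nn.MultiheadAttention module and the hook name are illustrative only, not vLLM's actual modules.

import torch
import torch.nn as nn

def make_zero_check_hook(name):
    # Forward hook: report rows of the attention output that are entirely zero,
    # which is the symptom observed for the second request in the batch.
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        zero_rows = out.abs().sum(dim=-1) == 0
        if zero_rows.any():
            print(f"{name}: {int(zero_rows.sum())} all-zero output rows")
    return hook

# Stand-in module; in practice the hook would be registered on the model's
# attention layers inside vLLM (module paths omitted here).
attn = nn.MultiheadAttention(embed_dim=128, num_heads=8, batch_first=True)
handle = attn.register_forward_hook(make_zero_check_hook("layer0.attn"))
x = torch.randn(2, 10, 128)
attn(x, x, x)
handle.remove()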

When using the flash-attention backend on V0, a RuntimeError about an incorrect k_cache shape occurs:

Error on V0 with FA
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner_base.py", line 115, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner.py", line 1700, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:                                     ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek_vl2.py", line 632, in forward
[rank0]:     hidden_states = self.language_model(input_ids,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 431, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 387, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states,
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 321, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 255, in forward
[rank0]:     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/layer.py", line 157, in forward
[rank0]:     torch.ops.vllm.unified_attention_with_output(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/layer.py", line 279, in unified_attention_with_output
[rank0]:     self.impl.forward(query,
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/backends/flash_attn.py", line 810, in forward
[rank0]:     flash_attn_with_kvcache(
[rank0]:   File "/root/autodl-tmp/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 411, in flash_attn_with_kvcache
[rank0]:     out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: kcache must have shape (num_blocks, page_block_size, num_heads_k, head_size_og)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/vllm/examples/offline_inference/vision_language.py", line 706, in <module>
[rank0]:     main(args)
[rank0]:   File "/root/autodl-tmp/vllm/examples/offline_inference/vision_language.py", line 657, in main
[rank0]:     outputs = llm.generate(inputs, sampling_params=sampling_params)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/utils.py", line 1079, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/entrypoints/llm.py", line 455, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/entrypoints/llm.py", line 1237, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/engine/llm_engine.py", line 1394, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/executor/gpu_executor.py", line 88, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/worker_base.py", line 345, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner_base.py", line 151, in _wrapper
[rank0]:     raise type(err)(
[rank0]: RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250114-004323.pkl): kcache must have shape (num_blocks, page_block_size, num_heads_k, head_size_og)

A similar error also occurs if I switch to V1:

Error on V1 with FA
[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner_base.py", line 115, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner.py", line 1700, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:                                     ^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek_vl2.py", line 632, in forward
[rank0]:     hidden_states = self.language_model(input_ids,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 431, in forward
[rank0]:     hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 387, in forward
[rank0]:     hidden_states, residual = layer(positions, hidden_states,
[rank0]:                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 321, in forward
[rank0]:     hidden_states = self.self_attn(
[rank0]:                     ^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/model_executor/models/deepseek.py", line 255, in forward
[rank0]:     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/layer.py", line 157, in forward
[rank0]:     torch.ops.vllm.unified_attention_with_output(
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/layer.py", line 279, in unified_attention_with_output
[rank0]:     self.impl.forward(query,
[rank0]:   File "/root/autodl-tmp/vllm/vllm/attention/backends/flash_attn.py", line 810, in forward
[rank0]:     flash_attn_with_kvcache(
[rank0]:   File "/root/autodl-tmp/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 411, in flash_attn_with_kvcache
[rank0]:     out, softmax_lse = torch.ops.vllm_flash_attn_c.fwd_kvcache(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: kcache must have shape (num_blocks, page_block_size, num_heads_k, head_size_og)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-tmp/vllm/examples/offline_inference/vision_language.py", line 706, in <module>
[rank0]:     main(args)
[rank0]:   File "/root/autodl-tmp/vllm/examples/offline_inference/vision_language.py", line 657, in main
[rank0]:     outputs = llm.generate(inputs, sampling_params=sampling_params)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/utils.py", line 1079, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/entrypoints/llm.py", line 455, in generate
[rank0]:     outputs = self._run_engine(use_tqdm=use_tqdm)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/entrypoints/llm.py", line 1237, in _run_engine
[rank0]:     step_outputs = self.llm_engine.step()
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/engine/llm_engine.py", line 1394, in step
[rank0]:     outputs = self.model_executor.execute_model(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/executor/gpu_executor.py", line 88, in execute_model
[rank0]:     output = self.driver_worker.execute_model(execute_model_req)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/worker_base.py", line 345, in execute_model
[rank0]:     output = self.model_runner.execute_model(
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/miniconda3/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/root/autodl-tmp/vllm/vllm/worker/model_runner_base.py", line 151, in _wrapper
[rank0]:     raise type(err)(
[rank0]: RuntimeError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250114-004323.pkl): kcache must have shape (num_blocks, page_block_size, num_heads_k, head_size_og)
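
For context on the error message: the flash-attention kernel expects the key cache in a paged layout of (num_blocks, page_block_size, num_heads_k, head_size_og). The sketch below only illustrates that expectation with hypothetical dimensions; it is not vLLM's allocation code, just a shape check showing what a mismatch (for example from an unexpected head size) would look like before the kernel call.

import torch

# Hypothetical dimensions, for illustration only.
num_blocks, block_size, num_kv_heads, head_size = 8, 16, 10, 128

# Layout named in the error message: (num_blocks, page_block_size, num_heads_k, head_size_og).
expected = (num_blocks, block_size, num_kv_heads, head_size)
k_cache = torch.empty(*expected, dtype=torch.float16)
assert k_cache.shape == expected

# A cache allocated with a different trailing dimension (e.g. a padded head size)
# would fail the same check, mirroring the RuntimeError above.
padded = torch.empty(num_blocks, block_size, num_kv_heads, 192, dtype=torch.float16)
assert padded.shape != expected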

Note that this issue only occurs on deepseek-vl2-tiny's DeepSeek-V1-style MoE LM backbone; other DeepSeek-V1 checkpoints don't have this issue.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.