[Bug]: Error in how HiddenStates are handled for speculative decoding #7505

Closed
abhigoyal1997 opened this issue Aug 14, 2024 · 0 comments · Fixed by #7508
Labels
bug Something isn't working

Comments

@abhigoyal1997
Contributor

abhigoyal1997 commented Aug 14, 2024

Your current environment

The output of `python collect_env.py`
Your output of `python collect_env.py` here

🐛 Describe the bug

With draft models like Medusa, MLPSpeculator, etc., when speculative decoding is disabled for a step (e.g. when num_tokens + spec_tokens > the model's max_len), HiddenStates are not handled properly, which causes an invalid shape error.
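
For context, a minimal sketch of the failing reshape (not vLLM code; the sizes are illustrative and taken from the tracebacks below, i.e. hidden size 768 and num_speculative_tokens=3). The "input of size 768" in the error suggests that when speculative decoding is skipped for a step, only a single hidden state comes back, while _verify_tokens still reshapes as if there were max_proposal_len + 1 positions per sequence:

import torch

hs_size = 768          # hidden size of the target model (illustrative)
max_proposal_len = 3   # num_speculative_tokens (illustrative)

# Normal speculative step: (k + 1) hidden states per sequence, reshape succeeds.
hidden_states = torch.zeros(max_proposal_len + 1, hs_size)
hidden_states.reshape(-1, max_proposal_len + 1, hs_size)  # -> shape (1, 4, 768)

# Step where spec. decode is disabled: only one hidden state is present,
# but the same reshape is attempted.
hidden_states = torch.zeros(1, hs_size)
hidden_states.reshape(-1, max_proposal_len + 1, hs_size)
# RuntimeError: shape '[-1, 4, 768]' is invalid for input of size 768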

How to reproduce?

Code:

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="ibm-fms/llama-160m-accelerator",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
)

prompt = "The president of the United States is"

output = llm.generate(prompt, SamplingParams(max_tokens=2048, ignore_eos=True))

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-dfa52d56a4c5> in <cell line: 12>()
     10 prompt = "The president of the United States is"
     11 
---> 12 output = llm.generate(prompt, SamplingParams(max_tokens=2048, ignore_eos=True))

(10 frames hidden)
/usr/local/lib/python3.10/dist-packages/vllm/spec_decode/spec_decode_worker.py in _verify_tokens(self, seq_group_metadata_list, proposal_scores, proposals, max_proposal_len)
    645             # Contract hidden states based on accepted tokens
    646             hs_size = hidden_states.shape[1]
--> 647             hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
    648                                                   hs_size)
    649             accepted_index = accepted_token_ids + 1  # Convert -1 to 0

RuntimeError: shape '[-1, 4, 768]' is invalid for input of size 768

Code:

from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="abhigoyal/vllm-medusa-llama-68m-random",
    num_speculative_tokens=3,
    use_v2_block_manager=True,
    enforce_eager=True,
)

prompt = "The president of the United States is"

output = llm.generate(prompt, SamplingParams(max_tokens=2048, ignore_eos=True))

Output:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-1-415db326cfe4> in <cell line: 12>()
     10 prompt = "The president of the United States is"
     11 
---> 12 output = llm.generate(prompt, SamplingParams(max_tokens=2048, ignore_eos=True))

(10 frames hidden)
/usr/local/lib/python3.10/dist-packages/vllm/spec_decode/spec_decode_worker.py in _verify_tokens(self, seq_group_metadata_list, proposal_scores, proposals, max_proposal_len)
    645             # Contract hidden states based on accepted tokens
    646             hs_size = hidden_states.shape[1]
--> 647             hidden_states = hidden_states.reshape(-1, max_proposal_len + 1,
    648                                                   hs_size)
    649             accepted_index = accepted_token_ids + 1  # Convert -1 to 0

RuntimeError: shape '[-1, 4, 768]' is invalid for input of size 768