[BugFix] Fix cuda graph for MLPSpeculator (vllm-project#5875)
Co-authored-by: Abhinav Goyal <abhinav.goyal@flipkart.com>
2 people authored and jimpang committed Jul 24, 2024
1 parent 9400cc4 · commit d9925c2
Showing 2 changed files with 6 additions and 4 deletions.
examples/offline_inference_mlpspeculator.py (1 change: 0 additions & 1 deletion)

@@ -52,7 +52,6 @@ def time_generation(llm: LLM, prompts: List[str],
         speculative_model="ibm-fms/llama-13b-accelerator",
         # These are currently required for MLPSpeculator decoding
         use_v2_block_manager=True,
-        enforce_eager=True,
     )

     print("With speculation")
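
With CUDA graphs now working for MLPSpeculator, the example no longer needs to force eager mode. A minimal sketch of the resulting setup, assuming a hypothetical base model and prompt (only speculative_model and use_v2_block_manager come from the diff above; everything else is illustrative):

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-13b-chat-hf",  # assumed base model, not shown in the diff
    speculative_model="ibm-fms/llama-13b-accelerator",
    # Still required for MLPSpeculator decoding
    use_v2_block_manager=True,
    # enforce_eager=True is no longer needed; decode can run through CUDA graphs
)

outputs = llm.generate(["The future of AI is"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
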
vllm/worker/model_runner.py (9 changes: 6 additions & 3 deletions)

@@ -1020,10 +1020,13 @@ def execute_model(

         if self.return_hidden_states:
             # we only need to pass hidden states of most recent token
+            assert model_input.sampling_metadata is not None
+            indices = model_input.sampling_metadata.selected_token_indices
             if model_input.is_prompt:
-                assert model_input.sampling_metadata is not None
-                hidden_states = hidden_states.index_select(
-                    0, model_input.sampling_metadata.selected_token_indices)
+                hidden_states = hidden_states.index_select(0, indices)
+            elif decode_meta.use_cuda_graph:
+                hidden_states = hidden_states[:len(indices)]
+
             output.hidden_states = hidden_states

         return output
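
Why the new slice is needed: a CUDA graph is captured for a fixed batch size, so on replay the decode batch (and therefore hidden_states) is padded out to that size, while selected_token_indices only covers the real sequences. The prompt path gathers possibly scattered rows with index_select; the graph path can simply keep the leading rows. A purely illustrative sketch of that padding-and-trim behavior (shapes and variable names are invented for the example, not taken from vLLM internals):

import torch

# Pretend the CUDA graph was captured for a padded batch of 8 decode
# sequences, of which only 5 are real in this step.
padded_batch, hidden_size = 8, 4096
num_real_seqs = 5

# Output of the graph replay: one hidden-state row per padded slot.
hidden_states = torch.randn(padded_batch, hidden_size)

# selected_token_indices covers only the real sequences, so its length
# says how many leading rows are meaningful.
indices = torch.arange(num_real_seqs)

# Prompt path: the rows of interest may be scattered, so gather them.
prompt_hidden = hidden_states.index_select(0, indices)

# CUDA-graph decode path: the real rows come first, so a cheap slice
# suffices; this is what the fix adds.
decode_hidden = hidden_states[:len(indices)]

assert decode_hidden.shape == (num_real_seqs, hidden_size)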
