clean up and bash format
Signed-off-by: NickLucche <nlucches@redhat.com>
NickLucche committed Nov 27, 2024
1 parent ade42dc commit 88bba37
Showing 12 changed files with 182 additions and 172 deletions.
2 changes: 1 addition & 1 deletion tests/spec_decode/e2e/conftest.py
@@ -1,8 +1,8 @@
from itertools import cycle
from typing import List, Optional, Sequence, Tuple, Union
import torch

import pytest
import torch

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
7 changes: 2 additions & 5 deletions tests/spec_decode/e2e/test_logprobs.py
@@ -22,14 +22,11 @@
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": False,
},
{
# TODO HERE fails when disabling logprobs and still requesting them
}, {
"speculative_model": "JackFram/llama-160m",
"num_speculative_tokens": 3,
"disable_logprobs_during_spec_decoding": True,
}
])
}])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize(
"output_len",
26 changes: 11 additions & 15 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -25,8 +25,8 @@

from vllm.model_executor.layers.vocab_parallel_embedding import pad_vocab_size

from .conftest import run_equality_correctness_test
from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-160m"
@@ -74,7 +74,7 @@ def test_mlp_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
seed: int, prefill_chunk_size:int):
seed: int, prefill_chunk_size: int):
"""Verify greedy equality with different batch size."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
run_equality_correctness_test(vllm_runner,
@@ -130,7 +130,7 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# NOTE Test is sensitive enough st if we don't enable chunked prefill
# scheduling on baseline too, we get slightly different logprobs, ending
# up sampling different tokens at the tail (ie top tokens don't change).
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
# TL;DR: sd+cp == org+cp but sd+cp != org..is this expected?
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
run_equality_correctness_test(vllm_runner,
common_llm_kwargs,
@@ -176,7 +176,7 @@ def test_mlp_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
def test_mlp_e2e_acceptance_rate(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
batch_size: int, output_len: int,
prefill_chunk_size: int, seed: int):
"""Verify acceptance rate with different batch size and large output
length."""
@@ -223,7 +223,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs,
batch_size: int, output_len: int,
temperature: float,
temperature: float,
prefill_chunk_size: int, seed: int):
"""Verify seeded runs produce the same output."""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
@@ -288,8 +288,7 @@ def test_mlp_e2e_seeded_correctness(vllm_runner, common_llm_kwargs,
def test_mlp_e2e_greedy_correctness_with_preemption(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
@@ -341,8 +340,7 @@ def test_mlp_e2e_greedy_correctness_with_preemption(
def test_mlp_e2e_greedy_correctness_with_padding(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
prefill_chunk_size: int,
seed: int):
prefill_chunk_size: int, seed: int):
"""Verify greedy equality when the vocab dimension is padded
"""
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
@@ -400,9 +398,8 @@ def patched_pad_vocab_size(vocab_size, pad_to=None):
@pytest.mark.parametrize("seed", [1])
def test_mlp_different_k(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, prefill_chunk_size: int,
seed: int,
output_len: int):
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int, output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode with different values of num_speculative_tokens.
"""
@@ -449,9 +446,8 @@ def test_mlp_different_k(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("seed", [1])
def test_mlp_disable_queue(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int,
seed: int,
test_llm_kwargs, batch_size: int,
prefill_chunk_size: int, seed: int,
output_len: int):
"""Verify that mlp speculative decoding produces exact equality
to without spec decode when speculation is disabled for large
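Aside: the NOTE in test_mlp_e2e_greedy_logprobs above explains why chunked prefill must be enabled on the baseline kwargs as well as the test kwargs, otherwise the baseline samples from slightly different tail logprobs. A minimal sketch of that call pattern, assuming the vLLM test tree is importable; the kwargs values here are illustrative only:

```python
# Illustrative only: mirrors the call pattern used by the logprobs test above.
from tests.spec_decode.utils import maybe_enable_chunked_prefill

prefill_chunk_size = 4  # hypothetical parametrized value; <= 0 disables chunking
baseline_llm_kwargs: dict = {}
test_llm_kwargs: dict = {
    "speculative_model": "JackFram/llama-160m",
    "num_speculative_tokens": 3,
    "disable_logprobs_during_spec_decoding": False,
}

# Apply the same chunked-prefill scheduling to BOTH configs so the baseline and
# the spec-decode run sample from matching logprobs at the sequence tail.
maybe_enable_chunked_prefill(prefill_chunk_size, baseline_llm_kwargs)
maybe_enable_chunked_prefill(prefill_chunk_size, test_llm_kwargs)
```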
30 changes: 14 additions & 16 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -149,22 +149,20 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
},
{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}
])
@pytest.mark.parametrize("test_llm_kwargs",
[{
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 5,
"enable_chunked_prefill": False,
"disable_logprobs_during_spec_decoding": False
}, {
"speculative_model": "JackFram/llama-68m",
"num_speculative_tokens": 3,
"enable_chunked_prefill": True,
"max_num_batched_tokens": 4,
"max_num_seqs": 4,
"disable_logprobs_during_spec_decoding": False
}])
@pytest.mark.parametrize(
"output_len",
[
3 changes: 2 additions & 1 deletion tests/spec_decode/utils.py
@@ -275,6 +275,7 @@ def create_batch(batch_size,
prev_output_tokens, seq_ids)
return seq_group_metadata_list, prompts, prev_output_tokens


def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
if prefill_chunk_size > 0:
llm_kwargs.update(
@@ -284,4 +285,4 @@ def maybe_enable_chunked_prefill(prefill_chunk_size, llm_kwargs):
"max_num_seqs": prefill_chunk_size
})
else:
llm_kwargs["enable_chunked_prefill"] = False
llm_kwargs["enable_chunked_prefill"] = False
6 changes: 0 additions & 6 deletions vllm/config.py
@@ -1397,12 +1397,6 @@ def maybe_create_spec_config(
f"num_speculative_tokens={n_predict}, but "
f"{num_speculative_tokens=} was provided.")

if enable_chunked_prefill and draft_hf_config.model_type in (
"medusa", "mlp_speculator", "eagle"):
raise ValueError(
"Chunked prefill and hidden-state based draft models are "
"not compatible.")

speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
8 changes: 5 additions & 3 deletions vllm/engine/llm_engine.py
@@ -1092,18 +1092,20 @@ def _process_model_outputs(self,
# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
# TODO Review when implementing Chunked Prefill+MultiStep
# Decodes are multi-steps while prefills are not, outputting at
# Decodes are multi-steps while prefills are not, outputting at
# most 1 token. Separate them so that we can trigger regular chunk
# processing without having to pad or copy over prompts K times to
# match decodes structure.
num_prefills = sum(sg.is_prompt for sg in seq_group_metadata_list)
prefills, decodes = outputs[:num_prefills], outputs[num_prefills:]
outputs_by_sequence_group = create_output_by_sequence_group(
decodes, num_seq_groups=len(seq_group_metadata_list)-num_prefills)
decodes,
num_seq_groups=len(seq_group_metadata_list) - num_prefills)
# We have outputs for multiple steps submitted in a single burst,
# so invalidate is_first_step_output.
is_first_step_output = None
outputs_by_sequence_group=[p.outputs for p in prefills] + outputs_by_sequence_group
outputs_by_sequence_group = [p.outputs for p in prefills
] + outputs_by_sequence_group
else:
outputs_by_sequence_group = outputs

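The comment block above carries the core idea of the llm_engine.py change: decode outputs span multiple steps while chunked prefills emit at most one token, so the burst is split and only the decode part is regrouped per sequence group before the prefill outputs are prepended. A toy sketch with plain lists under assumed, simplified shapes (vLLM's real objects are SamplerOutput and SequenceGroupMetadata, not lists):

```python
# Toy sketch, not vLLM's actual data layout: illustrates splitting a burst into
# prefills and decode steps, transposing decodes from [step][group] to
# [group][step], then prepending the prefill outputs.
def regroup_burst(outputs, is_prompt_flags):
    num_prefills = sum(is_prompt_flags)
    prefills, decode_steps = outputs[:num_prefills], outputs[num_prefills:]
    num_decode_groups = len(is_prompt_flags) - num_prefills
    # Simplified stand-in for create_output_by_sequence_group.
    decodes_by_group = [[step[g] for step in decode_steps]
                        for g in range(num_decode_groups)]
    return prefills + decodes_by_group


flags = [True, False, False]             # one prefill, two decode groups
burst = [["p_tok"],                      # prefill: at most one token
         ["d0_step0", "d1_step0"],       # decode step 0, both groups
         ["d0_step1", "d1_step1"]]       # decode step 1, both groups
print(regroup_burst(burst, flags))
# [['p_tok'], ['d0_step0', 'd0_step1'], ['d1_step0', 'd1_step1']]
```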
73 changes: 40 additions & 33 deletions vllm/spec_decode/batch_expansion.py
@@ -8,7 +8,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, VLLM_TOKEN_ID_ARRAY_TYPE,
ExecuteModelRequest, SequenceData,
SequenceGroupMetadata, get_all_seq_ids, PromptLogprobs)
SequenceGroupMetadata, get_all_seq_ids)
from vllm.spec_decode.interfaces import (SpeculativeProposals,
SpeculativeScorer, SpeculativeScores)
from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len
@@ -136,11 +136,12 @@ def _expand_batch(
num_scoring_tokens)

def _contract_batch(
self, contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
target_sampler_output: SamplerOutput, proposals: SpeculativeProposals,
num_scoring_tokens: int, non_spec_indices: List[int],
spec_indices: List[int], k: int
) -> SpeculativeScores:
self,
contracted_seq_group_metadata_list: List[SequenceGroupMetadata],
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals, num_scoring_tokens: int,
non_spec_indices: List[int], spec_indices: List[int],
k: int) -> SpeculativeScores:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
@@ -191,25 +192,33 @@ def _contract_batch(
idx for idx in non_spec_indices
if contracted_seq_group_metadata_list[idx].do_sample
]
has_prompt_log = any(((sg.sampling_params.prompt_logprobs and sg.sampling_params.prompt_logprobs>0) for sg in contracted_seq_group_metadata_list))
# When prompt logprobs is enabled, lens of returned tensors go from
has_prompt_log = any((sg.sampling_params.prompt_logprobs
and sg.sampling_params.prompt_logprobs > 0)
for sg in contracted_seq_group_metadata_list)
# When prompt logprobs is enabled, lens of returned tensors go from
# n_sampled (requests with do_sample=True) to n_prompt+n_prefills.
# We adjust stride accordingly to get the generated tokens and
# We adjust stride accordingly to get the generated tokens and
# their probs, but pass on prompt_logprobs as is.
if not has_prompt_log:
# When promptlogs are not be returned, we can ignore non-terminal chunks.
# When prompt logprobs are not to be returned,
# we can ignore non-terminal chunks.
non_spec_indices = regular_sampling_indices

if len(non_spec_indices):
if has_prompt_log:
# Add all terminal chunks sizes as well as decodes with no
# Add all terminal chunks sizes as well as decodes with no
# speculation to get out tokens and skip over prompt ones.
seq_meta = contracted_seq_group_metadata_list
nospec_sizes = torch.tensor([seq_meta[i].token_chunk_size if seq_meta[i].is_prompt else 1 for i in non_spec_indices])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
nospec_sizes = torch.tensor([
seq_meta[i].token_chunk_size
if seq_meta[i].is_prompt else 1 for i in non_spec_indices
])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes,
0).add_(-1)
else:
# In this case only sampled tokens are returned, select all.
nospec_sampled_token_idxs = list(range(len(non_spec_target_token_ids)))
nospec_sampled_token_idxs = list(
range(len(non_spec_target_token_ids)))

all_tokens[non_spec_indices, :1] = \
non_spec_target_token_ids[nospec_sampled_token_idxs].unsqueeze(1)
@@ -221,10 +230,13 @@ def _contract_batch(
assert non_spec_target_hidden_states is not None
all_hidden_states[non_spec_indices, :1, :] = \
non_spec_target_hidden_states[nospec_sampled_token_idxs].unsqueeze(1)

prompt_logprobs = None
if not self._scorer_worker.model_runner.disable_logprobs and has_prompt_log:
prompt_logprobs = [o.prompt_logprobs for o in target_sampler_output.outputs]
if (not self._scorer_worker.model_runner.disable_logprobs\
and has_prompt_log):
prompt_logprobs = [
o.prompt_logprobs for o in target_sampler_output.outputs
]

if spec_indices:
all_tokens[spec_indices] = target_token_ids
@@ -233,20 +245,17 @@ def _contract_batch(
if all_hidden_states is not None:
all_hidden_states[spec_indices] = target_hidden_states

return SpeculativeScores(
probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs
)
return SpeculativeScores(probs=all_probs,
token_ids=all_tokens,
logprobs=all_logprobs,
hidden_states=all_hidden_states,
prompt_logprobs=prompt_logprobs)

def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
) -> SpeculativeScores:
"""Contract the expanded batch back into its original size.
This maps the scores of speculative tokens back to their original
sequences.
@@ -270,13 +279,11 @@ def _contract_batch_all_spec(
target_hidden_states = target_hidden_states.reshape(
*target_token_ids.shape, target_hidden_states.shape[-1])

return SpeculativeScores(
probs=target_probs,
token_ids=target_token_ids,
logprobs=target_logprobs,
hidden_states=target_hidden_states,
prompt_logprobs=None
)
return SpeculativeScores(probs=target_probs,
token_ids=target_token_ids,
logprobs=target_logprobs,
hidden_states=target_hidden_states,
prompt_logprobs=None)

def _create_scoring_model_input(
self,
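The subtlest part of the _contract_batch change is the indexing used when prompt logprobs are returned: the flattened target output then holds token_chunk_size entries for each prompt chunk and one entry for each non-speculative decode, so the sampled token of every group sits at cumsum(sizes) - 1. A standalone illustration with made-up sizes and token ids:

```python
# Illustrates the cumsum(...) - 1 indexing from _contract_batch above;
# sizes and token values are made up for the example.
import torch

# Hypothetical non-speculative groups: a 3-token prompt chunk, a decode
# (size 1), and a 2-token prompt chunk.
nospec_sizes = torch.tensor([3, 1, 2])
nospec_sampled_token_idxs = torch.cumsum(nospec_sizes, 0).add_(-1)
print(nospec_sampled_token_idxs)              # tensor([2, 3, 5])

# Stand-in for the flattened target token ids (prompt + sampled positions).
flat_target_token_ids = torch.arange(100, 106)
print(flat_target_token_ids[nospec_sampled_token_idxs])  # tensor([102, 103, 105])
```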
5 changes: 2 additions & 3 deletions vllm/spec_decode/interfaces.py
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional, Set, List
from typing import List, Optional, Set

import torch

from vllm.sequence import ExecuteModelRequest
from vllm.sequence import ExecuteModelRequest, PromptLogprobs
from vllm.worker.worker_base import WorkerBase
from vllm.sequence import PromptLogprobs


@dataclass
[Diffs for the remaining 3 changed files are not shown.]
