[WIP] Complete v1 logprobs support #9880

Draft · wants to merge 10 commits into base: main
14 changes: 13 additions & 1 deletion tests/samplers/test_logprobs.py
@@ -3,6 +3,7 @@
import pytest
import torch

from tests.kernels.utils import override_backend_env_variable
from vllm import SamplingParams

from ..conftest import VllmRunner
@@ -12,10 +13,11 @@

@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype",
["float"]) # needed for comparing logprobs with HF
["half"]) # needed for comparing logprobs with HF
@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
@pytest.mark.parametrize("num_top_logprobs", [0, 6]) # 32000 == vocab_size
@pytest.mark.parametrize("detokenize", [True, False])
@pytest.mark.parametrize("vllm_use_v1", [True, False])
def test_get_prompt_logprobs(
hf_runner,
vllm_runner,
@@ -24,8 +26,18 @@ def test_get_prompt_logprobs(
chunked_prefill_token_size: int,
num_top_logprobs: int,
detokenize: bool,
vllm_use_v1: bool,
example_prompts,
monkeypatch,
):
if vllm_use_v1:
# LLM engine v1
monkeypatch.setenv("VLLM_USE_V1", "1")
override_backend_env_variable(monkeypatch, "FLASH_ATTN")
else:
# LLM engine v0
monkeypatch.setenv("VLLM_USE_V1", "0")

max_num_seqs = 256
enable_chunked_prefill = False
max_num_batched_tokens = None
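The test above now runs every case against both the v0 and v1 engines by toggling `VLLM_USE_V1`, and pins the FLASH_ATTN backend for v1. As a rough sketch of what the feature under test looks like from the public API (the model name, prompt, and top-k values here are illustrative, not taken from the PR):

```python
import os

from vllm import LLM, SamplingParams

# Select the v1 engine and the FLASH_ATTN backend, mirroring what the test's
# monkeypatch does via environment variables.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"

llm = LLM(model="facebook/opt-125m", dtype="half")
params = SamplingParams(
    max_tokens=8,
    logprobs=6,          # top-6 logprobs for each sampled token
    prompt_logprobs=6,   # top-6 logprobs for each prompt token
)
out = llm.generate(["Hello, my name is"], params)[0]
print(out.prompt_logprobs)      # one {token_id: Logprob} dict per prompt token
print(out.outputs[0].logprobs)  # one {token_id: Logprob} dict per sampled token
```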
17 changes: 17 additions & 0 deletions vllm/v1/core/scheduler.py
@@ -6,6 +6,7 @@
from vllm.logger import init_logger
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -230,6 +231,12 @@ def update_from_output(
) -> List[Tuple[Request, int]]:
# NOTE(woosuk): This method doesn't consider speculative decoding.
sampled_token_ids = model_runner_output.sampled_token_ids_cpu.tolist()
do_logprobs = not (model_runner_output.logprob_token_ids_cpu is None
or model_runner_output.logprobs_cpu is None)
if do_logprobs:
logprob_token_ids_list = (
model_runner_output.logprob_token_ids_cpu.tolist())
logprob_values_list = (model_runner_output.logprobs_cpu.tolist())
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
new_running: List[Request] = []
# (request, num_sampled_tokens)
@@ -246,6 +253,16 @@
# NOTE(woosuk): Currently, we assume that each request
# generates at most one token at each step.
token_id = sampled_token_ids[req_index]
if do_logprobs and request.max_logprobs is not None:
# Construct logprobs, if requested
logprob_token_ids = logprob_token_ids_list[req_index]
logprob_values = logprob_values_list[req_index]
logprobs = {
lpt: Logprob(lpv, (idx + 1), None)
for idx, (lpv, lpt) in enumerate(
zip(logprob_values, logprob_token_ids))
}
request.logprobs.append(logprobs)
request.output_token_ids.append(token_id)
sampled.append((request, 1))
# TODO: Update the KV cache manager for prefix caching.
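For each running request, `update_from_output` now converts one row of the model runner's top-k output into a `{token_id: Logprob}` dict with 1-based ranks. A toy illustration of that dict comprehension (the token ids and values are made up; real values come from `logprob_token_ids_cpu` / `logprobs_cpu`):

```python
from vllm.sequence import Logprob

# One request's top-k output, as it would appear after .tolist().
logprob_token_ids = [42, 7, 13]       # top-k token ids, best first
logprob_values = [-0.1, -1.3, -2.7]   # matching log-probabilities

logprobs = {
    token_id: Logprob(logprob=value, rank=idx + 1, decoded_token=None)
    for idx, (value, token_id) in enumerate(
        zip(logprob_values, logprob_token_ids))
}
# -> {42: Logprob(-0.1, rank=1), 7: Logprob(-1.3, rank=2),
#     13: Logprob(-2.7, rank=3)}
```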
18 changes: 15 additions & 3 deletions vllm/v1/engine/llm_engine.py
@@ -253,7 +253,7 @@ def _add_processed_request(
# TODO(woosuk): Check max_logprobs
# TODO(woosuk): Support encoder-decoder models.
req = Request(request_id, processed_inputs, params, eos_token_id,
arrival_time)
arrival_time, params.logprobs, params.prompt_logprobs)
self.requests[request_id] = req
self.num_lagged_steps[request_id] = 0
self.scheduler.add_request(req)
@@ -393,14 +393,16 @@ def _make_request_output(
finished: bool,
) -> RequestOutput:
req_output = self.request_outputs.get(request.request_id)
do_logprobs = request.max_logprobs is not None
do_prompt_logprobs = request.max_prompt_logprobs is not None
if req_output is None:
# TODO: Support `n` > 1.
completion_output = CompletionOutput(
index=0,
text="",
token_ids=[],
cumulative_logprob=None,
logprobs=None, # TODO
logprobs=[] if do_logprobs else None,
finish_reason=None,
stop_reason=None,
lora_request=None,
@@ -409,7 +411,7 @@
request_id=request.request_id,
prompt=request.prompt,
prompt_token_ids=request.prompt_token_ids,
prompt_logprobs=None, # TODO
prompt_logprobs=[] if do_prompt_logprobs else None,
outputs=[completion_output],
finished=False,
metrics=None,
@@ -424,19 +426,29 @@
completion_output.text += new_output_text
completion_output.token_ids = (
request.output_token_ids[:num_output_tokens])
if do_logprobs:
completion_output.logprobs = (
request.logprobs[:num_output_tokens])
elif request.sampling_params.output_kind == RequestOutputKind.DELTA:
completion_output.text = new_output_text
num_prev_tokens = len(completion_output.token_ids)
completion_output.token_ids = request.output_token_ids[
num_prev_tokens:num_output_tokens]
if do_logprobs:
completion_output.logprobs = (
request.logprobs[num_prev_tokens:num_output_tokens])
elif (request.sampling_params.output_kind ==
RequestOutputKind.FINAL_ONLY):
if finished:
completion_output.text = request.output_text
completion_output.token_ids = request.output_token_ids
if do_logprobs:
completion_output.logprobs = request.logprobs
else:
completion_output.text = ""
completion_output.token_ids = []
if do_logprobs:
completion_output.logprobs = []

if finished:
completion_output.finish_reason = request.get_finished_reason()
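`_make_request_output` slices `request.logprobs` the same way it already slices `output_token_ids`, so the returned logprobs track the requested output kind. A toy sketch of the three slicing modes (plain lists stand in for the real `Request` / `CompletionOutput` objects):

```python
# One {token_id: logprob} dict per generated token, accumulated by the scheduler.
logprobs = [{11: -0.2}, {22: -0.5}, {33: -0.1}, {44: -0.9}]

num_output_tokens = 4   # tokens generated so far
num_prev_tokens = 2     # tokens already sent to the client
finished = False

# CUMULATIVE: everything generated so far, on every step.
cumulative = logprobs[:num_output_tokens]             # 4 entries

# DELTA: only the tokens produced since the previous RequestOutput.
delta = logprobs[num_prev_tokens:num_output_tokens]   # 2 entries

# FINAL_ONLY: empty until the request finishes, then the full list.
final_only = logprobs if finished else []
```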
10 changes: 9 additions & 1 deletion vllm/v1/request.py
@@ -3,7 +3,7 @@

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.sequence import PromptLogprobs, RequestMetrics, SampleLogprobs

if TYPE_CHECKING:
from vllm.inputs import DecoderOnlyInputs
@@ -18,6 +18,8 @@ def __init__(
sampling_params: SamplingParams,
eos_token_id: Optional[int],
arrival_time: float,
max_logprobs: Optional[int],
max_prompt_logprobs: Optional[int],
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
@@ -42,6 +44,12 @@ def __init__(
self.num_prompt_tokens = len(self.prompt_token_ids)
self.output_token_ids: List[int] = []
self.output_text = ""
self.max_logprobs = max_logprobs
self.max_prompt_logprobs = max_prompt_logprobs
self.logprobs: Optional[SampleLogprobs] = (None if max_logprobs is None
else [])
self.prompt_logprobs: Optional[PromptLogprobs] = (
None if max_prompt_logprobs is None else [])
self.num_computed_tokens = 0

@property
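`Request` now records the per-request logprob limits and keeps `logprobs` / `prompt_logprobs` as lists only when the client asked for them, so an `is not None` check doubles as the "logprobs requested" flag elsewhere in the engine. A standalone illustration of that pattern (the `RequestLike` class is a stand-in, not the real `Request`):

```python
from typing import Dict, List, Optional


class RequestLike:
    """Stand-in showing the optional-accumulator pattern used by Request."""

    def __init__(self, max_logprobs: Optional[int],
                 max_prompt_logprobs: Optional[int]) -> None:
        self.max_logprobs = max_logprobs
        self.max_prompt_logprobs = max_prompt_logprobs
        # One {token_id: logprob} dict per generated token, or None if the
        # client did not request sample logprobs.
        self.logprobs: Optional[List[Dict[int, float]]] = (
            None if max_logprobs is None else [])
        # Likewise for prompt logprobs.
        self.prompt_logprobs: Optional[List[Optional[Dict[int, float]]]] = (
            None if max_prompt_logprobs is None else [])


req = RequestLike(max_logprobs=6, max_prompt_logprobs=None)
assert req.logprobs == []           # will accumulate sample logprobs
assert req.prompt_logprobs is None  # prompt logprobs were not requested
```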