89 changes: 88 additions & 1 deletion tests/v1/engine/test_output_processor.py
@@ -11,11 +11,13 @@
STOP_STRINGS,
DummyOutputProcessorTestVectors,
MockEngineCore)
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.metrics.stats import IterationStats


@@ -834,3 +836,88 @@ def test_iteration_stats(dummy_test_vectors):

assert iteration_stats.num_prompt_tokens == 0
assert iteration_stats.num_generation_tokens == num_active


@pytest.mark.asyncio
async def test_request_output_collector():
NUM_REQS = 3
TEXT = "a"

def make_outputs() -> list[RequestOutput]:
return [
RequestOutput(
request_id="my-request-id",
prompt=None,
prompt_token_ids=[1, 2, 3],
prompt_logprobs=None,
outputs=[
CompletionOutput(
index=0,
text=TEXT,
token_ids=[idx],
cumulative_logprob=(idx + 1 * 1.0),
logprobs=[{
"a": idx,
"b": idx
}],
finish_reason="length" if
(idx == NUM_REQS - 1) else None,
)
],
finished=(idx == NUM_REQS - 1),
) for idx in range(NUM_REQS)
]

collector = RequestOutputCollector(RequestOutputKind.DELTA)

# CASE 1: Put then get.
outputs = make_outputs()
collector.put(outputs[0])
output = await collector.get()
assert not collector.ready.is_set()
assert collector.output is None
assert output.outputs[0].text == "a"
assert output.outputs[0].token_ids == [0]

# CASE 2: 2 puts then get.
num_to_put = 2
outputs = make_outputs()
for i in range(num_to_put):
collector.put(outputs[i])
output = await collector.get()
assert not collector.ready.is_set()
assert collector.output is None

assert not output.finished
# Text, token_ids, and logprobs should get merged.
assert output.outputs[0].text == TEXT * num_to_put
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
list(range(num_to_put))):
assert tok_0 == tok_1
assert len(output.outputs[0].logprobs) == num_to_put

# Cumulative logprobs should be the last one.
cumulative_logprob_expected = 1.0 * num_to_put
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected

# CASE 3: Put all 3 (including a finished).
num_to_put = 3
outputs = make_outputs()
for i in range(num_to_put):
collector.put(outputs[i])
output = await collector.get()
assert not collector.ready.is_set()
assert collector.output is None

assert output.finished
assert output.outputs[0].finish_reason == "length"
# Text, token_ids, and logprobs should get merged.
assert output.outputs[0].text == TEXT * num_to_put
for tok_0, tok_1 in zip(output.outputs[0].token_ids,
list(range(num_to_put))):
assert tok_0 == tok_1
assert len(output.outputs[0].logprobs) == num_to_put

# Cumulative logprobs should be the last one.
cumulative_logprob_expected = 1.0 * num_to_put
assert output.outputs[0].cumulative_logprob == cumulative_logprob_expected
34 changes: 14 additions & 20 deletions vllm/v1/engine/async_llm.py
@@ -21,14 +21,15 @@
from vllm.outputs import RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Device, cdiv, kill_process_tree
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.core_client import EngineCoreClient
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.engine.output_processor import (OutputProcessor,
RequestOutputCollector)
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
@@ -176,11 +177,14 @@ async def add_request(
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
) -> asyncio.Queue[RequestOutput]:
) -> RequestOutputCollector:
"""Add new request to the AsyncLLM."""

# Create a new output queue for the request.
queue: asyncio.Queue[RequestOutput] = asyncio.Queue()
assert isinstance(params, SamplingParams), \
"Pooling is not supported in V1"

# Create a new output collector for the request.
queue = RequestOutputCollector(output_kind=params.output_kind)

# Convert Input --> Request.
request = self.processor.process_inputs(request_id, prompt, params,
@@ -189,25 +193,23 @@
prompt_adapter_request,
priority)

n = params.n if isinstance(params, SamplingParams) else 1

if n == 1:
if params.n == 1:
await self._add_request(request, None, 0, queue)
return queue

# Fan out child requests (for n>1).
parent_request = ParentRequest(request_id, params)
for idx in range(n):
for idx in range(params.n):
request_id, params = parent_request.get_child_info(idx)
child_request = request if idx == n - 1 else copy(request)
child_request = request if idx == params.n - 1 else copy(request)
child_request.request_id = request_id
child_request.sampling_params = params
await self._add_request(child_request, parent_request, idx, queue)
return queue

async def _add_request(self, request: EngineCoreRequest,
parent_req: Optional[ParentRequest], index: int,
queue: asyncio.Queue[RequestOutput]):
queue: RequestOutputCollector):

# Add the request to OutputProcessor (this process).
self.output_processor.add_request(request, parent_req, index, queue)
@@ -272,15 +274,7 @@ async def generate(
while not finished:
# Note: drain queue without await if possible (avoids
# task switching under load which helps performance).
out = q.get_nowait() if not q.empty() else await q.get()

# Coalesce any additional queued outputs
while not q.empty():
next_out = q.get_nowait()
if sampling_params.output_kind == RequestOutputKind.DELTA:
out.add(next_out)
else:
out = next_out
out = q.get_nowait() or await q.get()

# Note: both OutputProcessor and EngineCore handle their
# own request cleanup based on finished.
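A note on the simplified drain line above. The expression `q.get_nowait() or await q.get()` works because `get_nowait()` returns None when nothing is buffered, and `or` only falls through to the awaiting path in that case, so a hot consumer stays on the synchronous fast path without a task switch. A minimal standalone illustration of the short-circuit (the dict and helper names below are hypothetical stand-ins, not vLLM API):

```python
import asyncio

# Illustrative stand-ins only (not vLLM API): a dict plays the role of the
# collector's buffered output, and the two helpers mirror get_nowait()/get().

async def main() -> None:
    buffered = {"out": "delta-1"}

    def get_nowait():
        # Synchronous fast path: return whatever is buffered, else None.
        return buffered.pop("out", None)

    async def get():
        # Slow path: the real collector awaits an asyncio.Event here.
        await asyncio.sleep(0)
        return "delta-2"

    out = get_nowait() or await get()  # buffered output present: no await
    print(out)                         # -> delta-1

    out = get_nowait() or await get()  # buffer empty: falls back to awaiting
    print(out)                         # -> delta-2

asyncio.run(main())
```

This mirrors the removed manual-coalescing loop: merging now happens in the collector's put(), so generate() only needs the single drain expression.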
48 changes: 44 additions & 4 deletions vllm/v1/engine/output_processor.py
@@ -17,6 +17,46 @@
RequestStateStats)


class RequestOutputCollector:
"""
Collects streamed RequestOutputs per individual request,
for hand-off to the consuming asyncio generate task.

When streaming deltas, RequestOutputs are merged if the
producer gets ahead of the consumer.
"""

def __init__(self, output_kind: RequestOutputKind):
self.aggregate = output_kind == RequestOutputKind.DELTA
self.output: Optional[RequestOutput] = None
self.ready = asyncio.Event()

def put(self, output: RequestOutput) -> None:
if self.output is None:
self.output = output
self.ready.set()
elif self.aggregate:
# Coalesce the outputs in delta case.
self.output.add(output)
else:
# Just replace latest in non-delta case.
self.output = output

async def get(self) -> RequestOutput:
Review comment (Collaborator):
Do you think we should have an invariant that output is not None if self.ready.wait() is true?

Reply (@njhill, Member, Author, Mar 24, 2025):
That is the case, but I'm not sure what you're suggesting to add here. self.ready.wait() just waits for the event to be set; it can only ever return True (I'm not even sure why it returns that rather than None), and we immediately check self.output again before continuing.

while (output := self.output) is None:
await self.ready.wait()
self.output = None
self.ready.clear()
return output

def get_nowait(self) -> Optional[RequestOutput]:
output = self.output
if output is not None:
self.output = None
self.ready.clear()
return output
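To see the intended hand-off outside the diff, here is a minimal, self-contained sketch of the same pattern with a plain string standing in for RequestOutput; MiniCollector is an illustrative stand-in for the DELTA case, not the vLLM class itself:

```python
import asyncio
from typing import Optional


class MiniCollector:
    """Illustrative stand-in for RequestOutputCollector in the DELTA case."""

    def __init__(self) -> None:
        self.output: Optional[str] = None
        self.ready = asyncio.Event()

    def put(self, output: str) -> None:
        if self.output is None:
            self.output = output
            self.ready.set()
        else:
            # Producer got ahead of the consumer: merge the deltas.
            self.output += output

    async def get(self) -> str:
        while (output := self.output) is None:
            await self.ready.wait()
        self.output = None
        self.ready.clear()
        return output


async def main() -> None:
    q = MiniCollector()
    q.put("a")
    q.put("b")            # coalesced with "a" before the consumer wakes up
    print(await q.get())  # -> ab
    q.put("c")
    print(await q.get())  # -> c


asyncio.run(main())
```

The asyncio.Event plus a single slot replaces the per-request asyncio.Queue, so at most one (merged) RequestOutput is buffered per request.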


@dataclass
class OutputProcessorOutput:

@@ -39,7 +79,7 @@ def __init__(
detokenizer: IncrementalDetokenizer,
max_tokens_param: Optional[int],
arrival_time: float,
queue: Optional[asyncio.Queue[RequestOutput]],
queue: Optional[RequestOutputCollector],
log_stats: bool,
):
self.request_id = request_id
Expand All @@ -66,7 +106,7 @@ def from_new_request(
request: EngineCoreRequest,
parent_req: Optional[ParentRequest],
request_index: int,
queue: Optional[asyncio.Queue[RequestOutput]],
queue: Optional[RequestOutputCollector],
log_stats: bool,
) -> "RequestState":
if not request.sampling_params.detokenize:
@@ -217,7 +257,7 @@ def add_request(
request: EngineCoreRequest,
parent_req: Optional[ParentRequest] = None,
request_index: int = 0,
queue: Optional[asyncio.Queue[RequestOutput]] = None,
queue: Optional[RequestOutputCollector] = None,
) -> None:
request_id = request.request_id
if request_id in self.request_states:
@@ -300,7 +340,7 @@ def process_outputs(
new_token_ids, finish_reason, stop_reason):
if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put_nowait(request_output)
req_state.queue.put(request_output)
else:
# LLMEngine: return list of RequestOutputs.
request_outputs.append(request_output)