5 changes: 5 additions & 0 deletions vllm/tracing.py
@@ -119,6 +119,11 @@ class SpanAttributes:
    # forward, block/sync across workers, cpu-gpu sync time and sampling time.
    GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
        "gen_ai.latency.time_in_model_execute")
    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
        "gen_ai.latency.time_in_model_prefill"
    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
        "gen_ai.latency.time_in_model_inference"


def contains_trace_headers(headers: Mapping[str, str]) -> bool:
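The three attribute keys added above follow the existing `gen_ai.latency.*` naming used by `GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE`. A minimal sketch of how a caller could hand W3C trace headers to these helpers, assuming vLLM is installed with its OpenTelemetry extras (the `traceparent` value is purely illustrative):

```python
# Sketch only: forward W3C trace headers so the new span attributes end up on
# the request span. The traceparent value below is an illustration, not real.
from vllm.tracing import (SpanAttributes, contains_trace_headers,
                          extract_trace_context)

# Headers as they would arrive from an instrumented client (W3C Trace Context).
headers = {
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
}

if contains_trace_headers(headers):
    # The extracted context is what do_tracing() (see output_processor.py
    # below) passes to start_as_current_span.
    ctx = extract_trace_context(headers)
    print(ctx)

# The attribute keys added by this change, as a tracing backend would see them.
print(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL)    # gen_ai.latency.time_in_model_prefill
print(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE)     # gen_ai.latency.time_in_model_decode
print(SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE)  # gen_ai.latency.time_in_model_inference
```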
2 changes: 1 addition & 1 deletion vllm/v1/core/sched/scheduler.py
@@ -860,9 +860,9 @@ def update_from_output(
                        stop_reason=request.stop_reason,
                        events=request.take_events(),
                        kv_transfer_params=kv_transfer_params,
                        trace_headers=request.trace_headers,
                        num_cached_tokens=request.num_cached_tokens,
                    ))

            else:
                # Invariant: EngineCore returns no partial prefill outputs.
                assert not prompt_logprobs_tensors
7 changes: 5 additions & 2 deletions vllm/v1/engine/__init__.py
@@ -3,7 +3,7 @@

import enum
import time
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union

import msgspec
@@ -70,6 +70,8 @@ class EngineCoreRequest(
    current_wave: int = 0
    priority: int = 0

    trace_headers: Optional[Mapping[str, str]] = None


class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""
Expand Down Expand Up @@ -115,6 +117,7 @@ class EngineCoreOutput(
events: Optional[list[EngineCoreEvent]] = None
kv_transfer_params: Optional[dict[str, Any]] = None

trace_headers: Optional[Mapping[str, str]] = None
# The number of tokens with prefix cache hits.
num_cached_tokens: int = 0

@@ -141,7 +144,7 @@ class EngineCoreOutputs(
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False):  # type: ignore[call-arg]

    #NOTE(Nick): We could consider ways to make this more compact,
    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0
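Giving `trace_headers` a `None` default matters because the engine-core messages are msgspec structs (note `omit_defaults=True` on `EngineCoreOutputs` above), so messages encoded without the field keep decoding unchanged. A standalone sketch of that pattern; `DemoRequest` is illustrative, not the real `EngineCoreRequest`:

```python
# Standalone sketch (not the real EngineCoreRequest): an Optional mapping field
# with a None default, mirroring how trace_headers is declared above, so that
# messages encoded without it still decode cleanly.
from collections.abc import Mapping
from typing import Optional

import msgspec


class DemoRequest(msgspec.Struct, array_like=True, omit_defaults=True):
    request_id: str
    priority: int = 0
    trace_headers: Optional[Mapping[str, str]] = None


encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(DemoRequest)

# Without trace headers, the trailing default is omitted from the wire format.
plain = encoder.encode(DemoRequest(request_id="req-1"))
# With headers, the mapping travels alongside the rest of the request.
traced = encoder.encode(
    DemoRequest(request_id="req-2",
                trace_headers={"traceparent": "00-...-01"}))

print(decoder.decode(plain).trace_headers)   # None
print(decoder.decode(traced).trace_headers)  # {'traceparent': '00-...-01'}
```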
83 changes: 78 additions & 5 deletions vllm/v1/engine/output_processor.py
@@ -2,15 +2,19 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import asyncio
import time
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Optional, Union, cast

import torch

from vllm.config import ObservabilityConfig
from vllm.outputs import (CompletionOutput, PoolingOutput,
                          PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                          init_tracer)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -274,16 +278,26 @@ def _new_pooling_output(
class OutputProcessor:
    """Process EngineCoreOutputs into RequestOutputs."""

    def __init__(
        self,
        tokenizer: TokenizerGroup,
        log_stats: bool,
    ):
    def __init__(self,
                 tokenizer: TokenizerGroup,
                 log_stats: bool,
                 observability_config: Optional[ObservabilityConfig] = None):
        self.log_stats = log_stats
        self.tokenizer = tokenizer
        self.request_states: dict[str, RequestState] = {}
        self.parent_requests: dict[str, ParentRequest] = {}
        self.lora_states = LoRARequestStates()
        self.observability_config = observability_config

        self.tracer = None
        if (self.observability_config is not None
                and self.observability_config.otlp_traces_endpoint):
            self.tracer = init_tracer(
                "vllm.llm_engine",
                self.observability_config.otlp_traces_endpoint)

    def is_tracing_enabled(self) -> bool:
        return self.tracer is not None

    def get_num_unfinished_requests(self):
        return len(self.request_states)
@@ -440,6 +454,65 @@ def process_outputs(
            reqs_to_abort=reqs_to_abort,
        )

    def do_tracing(self, engine_core_output: EngineCoreOutput,
                   req_state: RequestState,
                   iteration_stats: Optional[IterationStats]):
        if (engine_core_output.finish_reason is None or iteration_stats is None
                or req_state is None or req_state.stats is None
                or self.tracer is None):
            return
        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)

        trace_context = extract_trace_context(engine_core_output.trace_headers)
        with self.tracer.start_as_current_span(
                "llm_request",
                kind=SpanKind.SERVER,
                context=trace_context,
                start_time=arrival_time_nano_seconds) as span:
            metrics = req_state.stats
            ttft = metrics.first_token_ts - metrics.arrival_time
            e2e_time = time.time() - metrics.arrival_time
            # Queued interval is from first QUEUED event to first SCHEDULED
            queued_time = metrics.scheduled_ts - metrics.queued_ts

            # Prefill interval is from first SCHEDULED to first NEW_TOKEN
            # Any preemptions during prefill are included in the interval
            prefill_time = metrics.first_token_ts - metrics.scheduled_ts

            # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
            # Any preemptions during decode are included
            decode_time = metrics.last_token_ts - metrics.first_token_ts

            # Inference interval is from first SCHEDULED to last NEW_TOKEN
            # Any preemptions during prefill or decode are included
            inference_time = metrics.last_token_ts - metrics.scheduled_ts
            span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
                               self.tokenizer.tokenizer_id)
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                               req_state.request_id)
            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                               req_state.max_tokens_param)
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                               len(req_state.prompt_token_ids))
            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                               metrics.num_generation_tokens)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                               metrics.queued_ts - metrics.arrival_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                               queued_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
                prefill_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
                decode_time)
            span.set_attribute(
                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
                inference_time)
    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
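Tracing stays opt-in: `OutputProcessor` only builds a tracer when it receives an `ObservabilityConfig` with an `otlp_traces_endpoint`. A hedged usage sketch; `tokenizer_group` and the endpoint URL are placeholders, not values from this PR:

```python
# Sketch of wiring tracing into the output processor. Assumes an existing
# TokenizerGroup instance (`tokenizer_group`) and an OTLP collector listening
# on the endpoint below; both are placeholders for illustration only.
from vllm.config import ObservabilityConfig
from vllm.v1.engine.output_processor import OutputProcessor

observability_config = ObservabilityConfig(
    otlp_traces_endpoint="http://localhost:4317")

output_processor = OutputProcessor(
    tokenizer_group,
    log_stats=True,
    observability_config=observability_config,
)

# With an endpoint configured, __init__ creates the tracer and do_tracing()
# emits one "llm_request" span per finished request.
assert output_processor.is_tracing_enabled()
```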
2 changes: 0 additions & 2 deletions vllm/v1/engine/processor.py
@@ -225,8 +225,6 @@ def process_inputs(
        # TODO(woosuk): Support encoder-decoder models.
        self._validate_lora(lora_request)
        self._validate_params(params, lora_request)
        if trace_headers is not None:
            raise ValueError("V1 does not support tracing yet.")
        if prompt_adapter_request is not None:
            raise ValueError("V1 does not support prompt_adapter_request.")

6 changes: 5 additions & 1 deletion vllm/v1/request.py
@@ -3,6 +3,7 @@

import enum
import time
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Optional, Union

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -36,6 +37,7 @@ def __init__(
        structured_output_request: Optional["StructuredOutputRequest"] = None,
        cache_salt: Optional[str] = None,
        priority: int = 0,
        trace_headers: Optional[Mapping[str, str]] = None,
    ) -> None:
        self.request_id = request_id
        self.client_index = client_index
@@ -98,7 +100,8 @@ def __init__(
        # they should also be updated simultaneously.
        self.output_token_ids = ConstantList(self._output_token_ids)
        self.all_token_ids = ConstantList(self._all_token_ids)

        # trace_headers
        self.trace_headers = trace_headers
        # State
        # The number of tokens with prefix cache hits.
        self.num_cached_tokens = -1
@@ -131,6 +134,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
                if request.sampling_params else None,
            cache_salt=request.cache_salt,
            priority=request.priority,
            trace_headers=request.trace_headers,
        )

    def append_output_token_ids(