From e0bb71615008849424bc8a649363c74f88c1b74c Mon Sep 17 00:00:00 2001
From: Mu Huai
Date: Wed, 2 Jul 2025 20:19:27 +0800
Subject: [PATCH] feat:trace v1

Signed-off-by: Mu Huai
---
 vllm/tracing.py                    |  5 ++
 vllm/v1/core/sched/scheduler.py    |  2 +-
 vllm/v1/engine/__init__.py         |  7 ++-
 vllm/v1/engine/output_processor.py | 83 ++++++++++++++++++++++++++++--
 vllm/v1/engine/processor.py        |  2 -
 vllm/v1/request.py                 |  6 ++-
 6 files changed, 94 insertions(+), 11 deletions(-)

diff --git a/vllm/tracing.py b/vllm/tracing.py
index 6a287d82be5f..7537e9901a04 100644
--- a/vllm/tracing.py
+++ b/vllm/tracing.py
@@ -119,6 +119,11 @@ class SpanAttributes:
     # forward, block/sync across workers, cpu-gpu sync time and sampling time.
     GEN_AI_LATENCY_TIME_IN_MODEL_EXECUTE = (
         "gen_ai.latency.time_in_model_execute")
+    GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL = \
+        "gen_ai.latency.time_in_model_prefill"
+    GEN_AI_LATENCY_TIME_IN_MODEL_DECODE = "gen_ai.latency.time_in_model_decode"
+    GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE = \
+        "gen_ai.latency.time_in_model_inference"
 
 
 def contains_trace_headers(headers: Mapping[str, str]) -> bool:
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index fe552db74e2f..23b3ace73c7b 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -860,9 +860,9 @@ def update_from_output(
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
                         kv_transfer_params=kv_transfer_params,
+                        trace_headers=request.trace_headers,
                         num_cached_tokens=request.num_cached_tokens,
                     ))
-
             else:
                 # Invariant: EngineCore returns no partial prefill outputs.
                 assert not prompt_logprobs_tensors
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 921ccd708cdd..58aca430e7ee 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -3,7 +3,7 @@
 
 import enum
 import time
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -70,6 +70,8 @@ class EngineCoreRequest(
     current_wave: int = 0
     priority: int = 0
 
+    trace_headers: Optional[Mapping[str, str]] = None
+
 
 class EngineCoreEventType(enum.IntEnum):
     """The type of engine core request event."""
@@ -115,6 +117,7 @@ class EngineCoreOutput(
 
     events: Optional[list[EngineCoreEvent]] = None
     kv_transfer_params: Optional[dict[str, Any]] = None
+    trace_headers: Optional[Mapping[str, str]] = None
     # The number of tokens with prefix cache hits.
     num_cached_tokens: int = 0
 
@@ -141,7 +144,7 @@ class EngineCoreOutputs(
         omit_defaults=True,  # type: ignore[call-arg]
         gc=False):  # type: ignore[call-arg]
 
-    #NOTE(Nick): We could consider ways to make this more compact,
+    # NOTE(Nick): We could consider ways to make this more compact,
     # e.g. columnwise layout
 
     engine_index: int = 0
diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py
index 2bcd61d1f0aa..b8cfc2c133d8 100644
--- a/vllm/v1/engine/output_processor.py
+++ b/vllm/v1/engine/output_processor.py
@@ -2,15 +2,19 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import asyncio
+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any, Optional, Union, cast
 
 import torch
 
+from vllm.config import ObservabilityConfig
 from vllm.outputs import (CompletionOutput, PoolingOutput,
                           PoolingRequestOutput, RequestOutput)
 from vllm.sampling_params import RequestOutputKind
+from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
+                          init_tracer)
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -274,16 +278,26 @@ def _new_pooling_output(
 class OutputProcessor:
     """Process EngineCoreOutputs into RequestOutputs."""
 
-    def __init__(
-        self,
-        tokenizer: TokenizerGroup,
-        log_stats: bool,
-    ):
+    def __init__(self,
+                 tokenizer: TokenizerGroup,
+                 log_stats: bool,
+                 observability_config: Optional[ObservabilityConfig] = None):
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: dict[str, RequestState] = {}
         self.parent_requests: dict[str, ParentRequest] = {}
         self.lora_states = LoRARequestStates()
+        self.observability_config = observability_config
+
+        self.tracer = None
+        if (self.observability_config is not None
+                and self.observability_config.otlp_traces_endpoint):
+            self.tracer = init_tracer(
+                "vllm.llm_engine",
+                self.observability_config.otlp_traces_endpoint)
+
+    def is_tracing_enabled(self) -> bool:
+        return self.tracer is not None
 
     def get_num_unfinished_requests(self):
         return len(self.request_states)
@@ -440,6 +454,65 @@ def process_outputs(
             reqs_to_abort=reqs_to_abort,
         )
 
+    def do_tracing(self, engine_core_output: EngineCoreOutput,
+                   req_state: RequestState,
+                   iteration_stats: Optional[IterationStats]):
+        if (engine_core_output.finish_reason is None or iteration_stats is None
+                or req_state is None or req_state.stats is None
+                or self.tracer is None):
+            return
+        arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
+
+        trace_context = extract_trace_context(engine_core_output.trace_headers)
+        with self.tracer.start_as_current_span(
+                "llm_request",
+                kind=SpanKind.SERVER,
+                context=trace_context,
+                start_time=arrival_time_nano_seconds) as span:
+            metrics = req_state.stats
+            ttft = metrics.first_token_ts - metrics.arrival_time
+            e2e_time = time.time() - metrics.arrival_time
+            # Queued interval is from first QUEUED event to first SCHEDULED
+            queued_time = metrics.scheduled_ts - metrics.queued_ts
+
+            # Prefill interval is from first SCHEDULED to first NEW_TOKEN
+            # Any preemptions during prefill is included in the interval
+            prefill_time = metrics.first_token_ts - metrics.scheduled_ts
+
+            # Decode interval is from first NEW_TOKEN to last NEW_TOKEN
+            # Any preemptions during decode are included
+            decode_time = metrics.last_token_ts - metrics.first_token_ts
+
+            # Inference interval is from first SCHEDULED to last NEW_TOKEN
+            # Any preemptions during prefill or decode are included
+            inference_time = metrics.last_token_ts - metrics.scheduled_ts
+            span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
+                               self.tokenizer.tokenizer_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
+                               req_state.request_id)
+            span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
+                               req_state.max_tokens_param)
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
+                               len(req_state.prompt_token_ids))
+            span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
+                               metrics.num_generation_tokens)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               metrics.queued_ts - metrics.arrival_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
+            span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
+                               queued_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
+                prefill_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
+                decode_time)
+            span.set_attribute(
+                SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
+                inference_time)
+
     def _update_stats_from_output(self, req_state: RequestState,
                                   engine_core_output: EngineCoreOutput,
                                   engine_core_timestamp: Optional[float],
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 7e7703df2cf1..dafb4bc4a953 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -225,8 +225,6 @@ def process_inputs(
 
         # TODO(woosuk): Support encoder-decoder models.
         self._validate_lora(lora_request)
         self._validate_params(params, lora_request)
-        if trace_headers is not None:
-            raise ValueError("V1 does not support tracing yet.")
         if prompt_adapter_request is not None:
             raise ValueError("V1 does not support prompt_adapter_request.")
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 9b96f4599f92..a78099e3bf66 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -3,6 +3,7 @@
 
 import enum
 import time
+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any, Optional, Union
 
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
@@ -36,6 +37,7 @@ def __init__(
         structured_output_request: Optional["StructuredOutputRequest"] = None,
         cache_salt: Optional[str] = None,
         priority: int = 0,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
         self.request_id = request_id
         self.client_index = client_index
@@ -98,7 +100,8 @@ def __init__(
         # they should also be updated simultaneously.
         self.output_token_ids = ConstantList(self._output_token_ids)
         self.all_token_ids = ConstantList(self._all_token_ids)
-
+        # trace_headers
+        self.trace_headers = trace_headers
         # State
         # The number of tokens with prefix cache hits.
         self.num_cached_tokens = -1
@@ -131,6 +134,7 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
             if request.sampling_params else None,
             cache_salt=request.cache_salt,
             priority=request.priority,
+            trace_headers=request.trace_headers,
         )
 
     def append_output_token_ids(
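
Note on the latency math: do_tracing() above converts the per-request event
timestamps into the new gen_ai.latency.* span attributes. The sketch below
restates that arithmetic as a standalone function so the interval definitions
can be checked in isolation. The timestamp field names mirror the stats object
used in the patch (arrival_time, queued_ts, scheduled_ts, first_token_ts,
last_token_ts); the three *_model_* attribute keys are taken verbatim from the
SpanAttributes additions, while the remaining keys are assumed to follow the
existing gen_ai.latency.* naming in vllm/tracing.py. This is an illustration,
not vLLM API.

    # Standalone sketch of the interval math used by do_tracing() above.
    import time
    from dataclasses import dataclass


    @dataclass
    class RequestTimestamps:
        """Per-request timestamps, mirroring the stats fields in the patch."""
        arrival_time: float    # wall-clock arrival of the request
        queued_ts: float       # first QUEUED event
        scheduled_ts: float    # first SCHEDULED event
        first_token_ts: float  # first NEW_TOKEN event
        last_token_ts: float   # last NEW_TOKEN event


    def latency_breakdown(ts: RequestTimestamps) -> dict[str, float]:
        return {
            # Queued: first QUEUED event to first SCHEDULED event.
            "gen_ai.latency.time_in_queue": ts.scheduled_ts - ts.queued_ts,
            # TTFT: arrival to the first generated token.
            "gen_ai.latency.time_to_first_token":
                ts.first_token_ts - ts.arrival_time,
            # Prefill: first SCHEDULED to first NEW_TOKEN (preemptions included).
            "gen_ai.latency.time_in_model_prefill":
                ts.first_token_ts - ts.scheduled_ts,
            # Decode: first NEW_TOKEN to last NEW_TOKEN (preemptions included).
            "gen_ai.latency.time_in_model_decode":
                ts.last_token_ts - ts.first_token_ts,
            # Inference: first SCHEDULED to last NEW_TOKEN.
            "gen_ai.latency.time_in_model_inference":
                ts.last_token_ts - ts.scheduled_ts,
            # End-to-end: arrival until the finish reason is observed (the
            # patch uses time.time() at that point).
            "gen_ai.latency.e2e": time.time() - ts.arrival_time,
        }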
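
Trace propagation itself follows the standard W3C mechanism: the caller
supplies trace-context headers, process_inputs() now accepts them (the
"V1 does not support tracing yet" guard is removed), and they ride on
EngineCoreRequest / Request / EngineCoreOutput until extract_trace_context()
turns them into the parent context of the llm_request span. A minimal
client-side sketch using the OpenTelemetry propagation API; send_request() is
a hypothetical stand-in for however the headers reach the engine, and the bare
TracerProvider stands in for a real exporter setup.

    # Client-side sketch: producing the trace_headers mapping that the patch
    # threads through EngineCoreRequest -> Request -> EngineCoreOutput.
    from opentelemetry import trace
    from opentelemetry.propagate import inject
    from opentelemetry.sdk.trace import TracerProvider

    # A real setup would also attach an OTLP exporter; a bare TracerProvider
    # is enough to mint valid trace context for this sketch.
    trace.set_tracer_provider(TracerProvider())
    tracer = trace.get_tracer("client")


    def send_request(prompt: str, trace_headers: dict[str, str]) -> None:
        # Hypothetical transport: forward `trace_headers` with the request so
        # the engine can hand them to extract_trace_context() at finish time.
        print(prompt, trace_headers)


    with tracer.start_as_current_span("client_call"):
        headers: dict[str, str] = {}
        inject(headers)  # fills in the W3C traceparent (and tracestate) keys
        send_request("Hello", headers)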
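
For a quick local check that the new attribute names land on a span as
expected, the OpenTelemetry SDK's in-memory exporter can stand in for the
OTLP endpoint configured via otlp_traces_endpoint. This test-style sketch is
not part of the patch; it only exercises the attribute keys added to
SpanAttributes.

    # Record a span with the new latency attributes and read it back through
    # the SDK's in-memory exporter (no OTLP endpoint needed).
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
        InMemorySpanExporter)

    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))
    trace.set_tracer_provider(provider)

    tracer = trace.get_tracer("vllm.llm_engine")
    with tracer.start_as_current_span("llm_request") as span:
        span.set_attribute("gen_ai.latency.time_in_model_prefill", 0.12)
        span.set_attribute("gen_ai.latency.time_in_model_decode", 0.80)
        span.set_attribute("gen_ai.latency.time_in_model_inference", 0.92)

    (finished,) = exporter.get_finished_spans()
    assert finished.attributes["gen_ai.latency.time_in_model_decode"] == 0.80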