6 changes: 0 additions & 6 deletions vllm/engine/arg_utils.py
@@ -1412,12 +1412,6 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
recommend_to_remove=False)
return False

# No OTLP observability so far.
if (self.otlp_traces_endpoint or self.collect_detailed_traces):
_raise_or_fallback(feature_name="--otlp-traces-endpoint",
recommend_to_remove=False)
return False

# V1 supports N-gram, Medusa, and Eagle speculative decoding.
is_ngram_enabled = False
is_eagle_enabled = False
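With the oracle check above removed, an OTLP endpoint no longer forces a fallback from the V1 engine. A minimal sketch of enabling it through the engine arguments; the model name and collector address below are placeholders, not taken from this PR:

    from vllm.engine.arg_utils import AsyncEngineArgs

    # Placeholder values: any model and any OTLP-capable collector endpoint.
    engine_args = AsyncEngineArgs(
        model="facebook/opt-125m",
        otlp_traces_endpoint="http://localhost:4317",
    )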
9 changes: 8 additions & 1 deletion vllm/v1/engine/async_llm.py
@@ -21,6 +21,7 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.tracing import init_tracer
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -91,6 +92,7 @@ def __init__(

self.model_config = vllm_config.model_config
self.vllm_config = vllm_config
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests
self.log_stats = log_stats

@@ -118,6 +120,11 @@ def __init__(
# OutputProcessor (converts EngineCoreOutputs --> RequestOutput).
self.output_processor = OutputProcessor(self.tokenizer,
log_stats=self.log_stats)
if self.observability_config.otlp_traces_endpoint is not None:
tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
self.output_processor.tracer = tracer

# EngineCore (starts the engine in background process).

@@ -539,7 +546,7 @@ async def get_tokenizer(
return self.tokenizer.get_lora_tokenizer(lora_request)

async def is_tracing_enabled(self) -> bool:
return False
return self.observability_config.otlp_traces_endpoint is not None

async def do_log_stats(
self,
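For reference, the same tracer wiring outside the constructor: init_tracer (called with the same signature as above) returns a tracer bound to the configured OTLP endpoint, and is_tracing_enabled() now reports whether such an endpoint was set instead of always returning False. The endpoint value is an assumed placeholder:

    from vllm.tracing import init_tracer

    otlp_traces_endpoint = "http://localhost:4317"  # assumed collector address

    # Same call the constructor makes; spans recorded through this tracer are
    # exported to the configured OTLP endpoint.
    tracer = init_tracer("vllm.llm_engine", otlp_traces_endpoint)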
84 changes: 29 additions & 55 deletions vllm/v1/engine/output_processor.py
@@ -13,8 +13,7 @@
from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sampling_params import RequestOutputKind
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.tracing import (SpanAttributes, SpanKind, Tracer,
                          extract_trace_context)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
@@ -288,13 +287,7 @@ def __init__(self,
self.parent_requests: dict[str, ParentRequest] = {}
self.lora_states = LoRARequestStates()
self.observability_config = observability_config

self.tracer = None
if (self.observability_config is not None
and self.observability_config.otlp_traces_endpoint):
self.tracer = init_tracer(
"vllm.llm_engine",
self.observability_config.otlp_traces_endpoint)
self.tracer: Optional[Tracer] = None

def is_tracing_enabled(self) -> bool:
return self.tracer is not None
@@ -446,72 +439,53 @@ def process_outputs(
# Track per-request stats
self._update_stats_from_finished(req_state, finish_reason,
iteration_stats)

if self.tracer:
self.do_tracing(engine_core_output, req_state, iteration_stats)
self.lora_states.update_iteration_stats(iteration_stats)

return OutputProcessorOutput(
request_outputs=request_outputs,
reqs_to_abort=reqs_to_abort,
)

def do_tracing(self, engine_core_output: EngineCoreOutput,
req_state: RequestState,
iteration_stats: Optional[IterationStats]):
if (engine_core_output.finish_reason is None or iteration_stats is None
or req_state is None or req_state.stats is None
or self.tracer is None):
return
arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
def do_tracing(self,
engine_core_output: EngineCoreOutput,
req_state: RequestState,
iteration_stats: Optional[IterationStats]) -> None:
assert req_state.stats is not None
assert iteration_stats is not None

arrival_time_nano_seconds = int(req_state.stats.arrival_time * 1e9)
trace_context = extract_trace_context(engine_core_output.trace_headers)
with self.tracer.start_as_current_span(
"llm_request",
kind=SpanKind.SERVER,
context=trace_context,
start_time=arrival_time_nano_seconds) as span:
metrics = req_state.stats
ttft = metrics.first_token_ts - metrics.arrival_time
e2e_time = time.time() - metrics.arrival_time
# Queued interval is from first QUEUED event to first SCHEDULED
e2e_time = iteration_stats.iteration_timestamp - metrics.arrival_time
queued_time = metrics.scheduled_ts - metrics.queued_ts

# Prefill interval is from first SCHEDULED to first NEW_TOKEN
# Any preemptions during prefill are included in the interval
prefill_time = metrics.first_token_ts - metrics.scheduled_ts

# Decode interval is from first NEW_TOKEN to last NEW_TOKEN
# Any preemptions during decode are included
decode_time = metrics.last_token_ts - metrics.first_token_ts

# Inference interval is from first SCHEDULED to last NEW_TOKEN
# Any preemptions during prefill or decode are included
inference_time = metrics.last_token_ts - metrics.scheduled_ts
span.set_attribute(SpanAttributes.GEN_AI_RESPONSE_MODEL,
self.tokenizer.tokenizer_id)
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
req_state.request_id)
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
req_state.max_tokens_param)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
len(req_state.prompt_token_ids))
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
metrics.num_generation_tokens)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
metrics.queued_ts - metrics.arrival_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN,
                   metrics.first_token_latency)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_E2E, e2e_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
queued_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
prefill_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
decode_time)
span.set_attribute(
SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
inference_time)
span.set_attribute(SpanAttributes.GEN_AI_LATENCY_TIME_IN_QUEUE,
                   queued_time)
span.set_attribute(SpanAttributes.GEN_AI_USAGE_PROMPT_TOKENS,
                   len(req_state.prompt_token_ids))
span.set_attribute(SpanAttributes.GEN_AI_USAGE_COMPLETION_TOKENS,
                   metrics.num_generation_tokens)
span.set_attribute(
    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_PREFILL,
    prefill_time)
span.set_attribute(
    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_DECODE,
    decode_time)
span.set_attribute(
    SpanAttributes.GEN_AI_LATENCY_TIME_IN_MODEL_INFERENCE,
    inference_time)

# meta
span.set_attribute(SpanAttributes.GEN_AI_REQUEST_ID,
                   req_state.request_id)
if req_state.parent_req and req_state.parent_req.sampling_params:
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TOP_P,
                       req_state.parent_req.sampling_params.top_p)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_MAX_TOKENS,
                       req_state.parent_req.sampling_params.max_tokens)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_TEMPERATURE,
                       req_state.parent_req.sampling_params.temperature)
    span.set_attribute(SpanAttributes.GEN_AI_REQUEST_N,
                       req_state.parent_req.sampling_params.n)

def _update_stats_from_output(self, req_state: RequestState,
engine_core_output: EngineCoreOutput,
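A small worked example of the interval arithmetic in do_tracing above; the timestamp names mirror RequestStateStats, and the numeric values are illustrative only:

    # Illustrative timestamps in seconds, not real measurements.
    arrival_time = 0.000
    queued_ts = 0.001
    scheduled_ts = 0.010
    first_token_ts = 0.150
    last_token_ts = 0.900
    iteration_timestamp = 0.905  # when the finishing iteration was processed

    queued_time = scheduled_ts - queued_ts          # first QUEUED -> first SCHEDULED
    prefill_time = first_token_ts - scheduled_ts    # first SCHEDULED -> first NEW_TOKEN
    decode_time = last_token_ts - first_token_ts    # first NEW_TOKEN -> last NEW_TOKEN
    inference_time = last_token_ts - scheduled_ts   # prefill + decode, preemptions included
    e2e_time = iteration_timestamp - arrival_time   # arrival -> finishing iteration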
1 change: 1 addition & 0 deletions vllm/v1/engine/processor.py
@@ -344,6 +344,7 @@ def process_inputs(
cache_salt=decoder_inputs.get("cache_salt"),
priority=priority,
data_parallel_rank=data_parallel_rank,
trace_headers=trace_headers,
)

def _validate_model_inputs(self,
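The trace_headers forwarded here carry the caller's W3C trace context; extract_trace_context in do_tracing turns them back into a parent context so the llm_request span joins the caller's trace. A sketch of such a mapping, with illustrative IDs:

    # Illustrative W3C trace-context header; the trace and span IDs are example
    # values only.
    trace_headers = {
        "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    }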
4 changes: 4 additions & 0 deletions vllm/v1/metrics/stats.py
@@ -64,6 +64,9 @@ class RequestStateStats:
first_token_ts: float = 0.0
last_token_ts: float = 0.0

# Time from request arrival to the first generated token (seconds)
first_token_latency: float = 0.0


@dataclass
class FinishedRequestStats:
@@ -112,6 +115,7 @@ def update_from_output(self, output: "EngineCoreOutput",

first_token_latency = self._time_since(req_stats.arrival_time)
self.time_to_first_tokens_iter.append(first_token_latency)
req_stats.first_token_latency = first_token_latency

req_stats.num_generation_tokens += num_new_generation_tokens

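A brief sketch of how the new field gets its value: update_from_output records the latency once, when the first generation token of a request is observed. The sketch assumes _time_since is wall-clock "now minus start", consistent with the e2e computation in do_tracing:

    import time

    arrival_time = time.time()  # recorded when the request arrives
    # ... prefill runs until the first generation token appears ...
    first_token_latency = time.time() - arrival_time
    # Stored on RequestStateStats.first_token_latency and later exported via the
    # GEN_AI_LATENCY_TIME_TO_FIRST_TOKEN span attribute in do_tracing.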