From 925cf5ac90e2289ebfa86335667ebd31a3a96cbd Mon Sep 17 00:00:00 2001 From: Zijing Liu Date: Wed, 12 Mar 2025 00:03:48 -0700 Subject: [PATCH 1/4] [V1][Metrics] Allow AsyncLLM engine to use custom stat logger Signed-off-by: Zijing Liu Signed-off-by: Mark McLoughlin Signed-off-by: Zijing Liu --- tests/v1/engine/test_async_llm.py | 33 ++++++++++++++++++++ vllm/v1/engine/async_llm.py | 51 +++++++++++++++++++------------ vllm/v1/engine/llm_engine.py | 16 +++++----- vllm/v1/metrics/loggers.py | 41 ++++++++++++++++++++++++- 4 files changed, 113 insertions(+), 28 deletions(-) diff --git a/tests/v1/engine/test_async_llm.py b/tests/v1/engine/test_async_llm.py index da0639678af8..5d52ad5f5328 100644 --- a/tests/v1/engine/test_async_llm.py +++ b/tests/v1/engine/test_async_llm.py @@ -3,16 +3,19 @@ import asyncio from contextlib import ExitStack from typing import Optional +from unittest.mock import MagicMock import pytest from vllm import SamplingParams from vllm.assets.image import ImageAsset +from vllm.config import VllmConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.inputs import PromptType from vllm.platforms import current_platform from vllm.sampling_params import RequestOutputKind from vllm.v1.engine.async_llm import AsyncLLM +from vllm.v1.metrics.loggers import LoggingStatLogger if not current_platform.is_cuda(): pytest.skip(reason="V1 currently only supported on CUDA.", @@ -216,3 +219,33 @@ async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int, # Assert only the last output has the finished flag set assert all(not out.finished for out in outputs[:-1]) assert outputs[-1].finished + + +class MockLoggingStatLogger(LoggingStatLogger): + + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + super().__init__(vllm_config, engine_index) + self.log = MagicMock() + + +@pytest.mark.asyncio +async def test_customize_loggers(monkeypatch): + """Test that we can customize the loggers. + If a customized logger is provided at the init, it should + be used directly. + """ + + with monkeypatch.context() as m, ExitStack() as after: + m.setenv("VLLM_USE_V1", "1") + + engine = AsyncLLM.from_engine_args( + TEXT_ENGINE_ARGS, + stat_loggers=[MockLoggingStatLogger], + ) + after.callback(engine.shutdown) + + await engine.do_log_stats() + + assert len(engine.stat_loggers) == 1 + assert len(engine.stat_loggers[0]) == 1 + engine.stat_loggers[0][0].log.assert_called_once() diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index c33535b3d360..e643a22fefff 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 import asyncio -import logging from collections.abc import AsyncGenerator, Mapping from copy import copy from typing import Optional, Union @@ -33,8 +32,7 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import (LoggingStatLogger, PrometheusStatLogger, - StatLoggerBase) +from vllm.v1.metrics.loggers import StatLoggerBase, setup_default_loggers from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -52,7 +50,28 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, ) -> None: + """ + Create an AsyncLLM. + + Args: + vllm_config: global configuration. + executor_class: an Executor impl, e.g. MultiprocExecutor. + log_stats: Whether to log stats. + usage_context: Usage context of the LLM. + mm_registry: Multi-modal registry. + use_cached_outputs: Whether to use cached outputs. + log_requests: Whether to log requests. + start_engine_loop: Whether to start the engine loop. + stat_loggers: customized stat loggers for the engine. + If not provided, default stat loggers will be used. + PLEASE BE AWARE THAT STAT LOGGER IS NOT STABLE + IN V1, AND ITS BASE CLASS INTERFACE MIGHT CHANGE. + + Returns: + None + """ if not envs.VLLM_USE_V1: raise ValueError( "Using V1 AsyncLLMEngine, but envs.VLLM_USE_V1=False. " @@ -66,15 +85,12 @@ def __init__( self.log_stats = log_stats # Set up stat loggers; independent set for each DP rank. - self.stat_loggers: list[list[StatLoggerBase]] = [] - if self.log_stats: - for i in range(vllm_config.parallel_config.data_parallel_size): - loggers: list[StatLoggerBase] = [] - if logger.isEnabledFor(logging.INFO): - loggers.append(LoggingStatLogger(engine_index=i)) - loggers.append( - PrometheusStatLogger(vllm_config, engine_index=i)) - self.stat_loggers.append(loggers) + self.stat_loggers: list[list[StatLoggerBase]] = setup_default_loggers( + vllm_config=vllm_config, + log_stats=self.log_stats, + engine_num=vllm_config.parallel_config.data_parallel_size, + custom_stat_loggers=stat_loggers, + ) # Tokenizer (+ ensure liveness if running in another process). self.tokenizer = init_tokenizer_from_configs( @@ -118,7 +134,7 @@ def from_vllm_config( vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, ) -> "AsyncLLM": @@ -129,17 +145,12 @@ def from_vllm_config( "AsyncLLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") - # FIXME(rob): refactor VllmConfig to include the StatLoggers - # include StatLogger in the Oracle decision. - if stat_loggers is not None: - raise ValueError("Custom StatLoggers are not yet supported on V1. " - "Explicitly set VLLM_USE_V1=0 to disable V1.") - # Create the LLMEngine. return cls( vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), start_engine_loop=start_engine_loop, + stat_loggers=stat_loggers, log_requests=not disable_log_requests, log_stats=not disable_log_stats, usage_context=usage_context, @@ -151,6 +162,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" @@ -166,6 +178,7 @@ def from_engine_args( log_stats=not engine_args.disable_log_stats, start_engine_loop=start_engine_loop, usage_context=usage_context, + stat_loggers=stat_loggers, ) def __del__(self): diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index a07595a552af..74811c4d547d 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -43,7 +43,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -55,6 +55,11 @@ def __init__( "LLMEngine.from_vllm_config(...) or explicitly set " "VLLM_USE_V1=0 or 1 and report this issue on Github.") + if stat_loggers is not None: + raise NotImplementedError( + "Passing StatLoggers to LLMEngine in V1 is not yet supported. " + "Set VLLM_USE_V1=0 and file and issue on Github.") + self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -101,14 +106,9 @@ def from_vllm_config( cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": - if stat_loggers is not None: - raise NotImplementedError( - "Passing StatLoggers to V1 is not yet supported. " - "Set VLLM_USE_V1=0 and file and issue on Github.") - return cls(vllm_config=vllm_config, executor_class=Executor.get_class(vllm_config), log_stats=(not disable_log_stats), @@ -121,7 +121,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[dict[str, StatLoggerBase]] = None, + stat_loggers: Optional[list[type[StatLoggerBase]]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 547e60467632..471995d996d7 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +import logging import time from abc import ABC, abstractmethod from typing import Optional @@ -20,6 +21,16 @@ class StatLoggerBase(ABC): + """Interface for logging metrics. + + API users may define custom loggers that implement this interface. + However, note that the `SchedulerStats` and `IterationStats` classes + are not considered stable interfaces and may change in future versions. + """ + + @abstractmethod + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): + ... @abstractmethod def record(self, scheduler_stats: SchedulerStats, @@ -32,7 +43,7 @@ def log(self): # noqa class LoggingStatLogger(StatLoggerBase): - def __init__(self, engine_index: int = 0): + def __init__(self, vllm_config: VllmConfig, engine_index: int = 0): self.engine_index = engine_index self._reset(time.monotonic()) self.last_scheduler_stats = SchedulerStats() @@ -462,3 +473,31 @@ def build_cudagraph_buckets(vllm_config: VllmConfig) -> list[int]: return buckets else: return [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] + + +def setup_default_loggers( + vllm_config: VllmConfig, + log_stats: bool, + engine_num: int, + custom_stat_loggers: Optional[list[type[StatLoggerBase]]] = None, +) -> list[list[StatLoggerBase]]: + """Setup logging and prometheus metrics.""" + if not log_stats: + return [] + + stat_loggers: list[list[StatLoggerBase]] = [] + for i in range(engine_num): + per_engine_stat_loggers: list[StatLoggerBase] = [] + if custom_stat_loggers is not None: + for logger_cls in custom_stat_loggers: + per_engine_stat_loggers.append( + logger_cls(vllm_config, engine_index=i)) + else: + per_engine_stat_loggers.append( + PrometheusStatLogger(vllm_config, engine_index=i)) + if logger.isEnabledFor(logging.INFO): + per_engine_stat_loggers.append( + LoggingStatLogger(vllm_config, engine_index=i)) + stat_loggers.append(per_engine_stat_loggers) + + return stat_loggers From 1aa7283886adb8cc43ed5cae0bf857c39a3c396b Mon Sep 17 00:00:00 2001 From: Mark McLoughlin Date: Wed, 23 Apr 2025 09:42:08 -0400 Subject: [PATCH 2/4] [V1][Metrics] Define StatLoggerFactory type More general than passing sub-class types, requiring sub-class constructors to conform to a protocol. Signed-off-by: Mark McLoughlin Signed-off-by: Zijing Liu --- vllm/v1/engine/async_llm.py | 9 +++++---- vllm/v1/engine/llm_engine.py | 8 ++++---- vllm/v1/metrics/loggers.py | 28 ++++++++++++++++------------ 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index e643a22fefff..a1eb5c8ba185 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -32,7 +32,8 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor -from vllm.v1.metrics.loggers import StatLoggerBase, setup_default_loggers +from vllm.v1.metrics.loggers import (StatLoggerBase, StatLoggerFactory, + setup_default_loggers) from vllm.v1.metrics.stats import IterationStats, SchedulerStats logger = init_logger(__name__) @@ -50,7 +51,7 @@ def __init__( use_cached_outputs: bool = False, log_requests: bool = True, start_engine_loop: bool = True, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> None: """ Create an AsyncLLM. @@ -134,7 +135,7 @@ def from_vllm_config( vllm_config: VllmConfig, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_requests: bool = False, disable_log_stats: bool = False, ) -> "AsyncLLM": @@ -162,7 +163,7 @@ def from_engine_args( engine_args: AsyncEngineArgs, start_engine_loop: bool = True, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> "AsyncLLM": """Create an AsyncLLM from the EngineArgs.""" diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 74811c4d547d..ac2ee065f09f 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -10,7 +10,6 @@ from vllm.config import ParallelConfig, VllmConfig from vllm.distributed import stateless_destroy_torch_distributed_process_group from vllm.engine.arg_utils import EngineArgs -from vllm.engine.metrics_types import StatLoggerBase from vllm.inputs import PromptType from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -28,6 +27,7 @@ from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.metrics.loggers import StatLoggerFactory logger = init_logger(__name__) @@ -43,7 +43,7 @@ def __init__( executor_class: type[Executor], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, use_cached_outputs: bool = False, multiprocess_mode: bool = False, @@ -106,7 +106,7 @@ def from_vllm_config( cls, vllm_config: VllmConfig, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, disable_log_stats: bool = False, ) -> "LLMEngine": return cls(vllm_config=vllm_config, @@ -121,7 +121,7 @@ def from_engine_args( cls, engine_args: EngineArgs, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + stat_loggers: Optional[list[StatLoggerFactory]] = None, enable_multiprocessing: bool = False, ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 471995d996d7..2587a6d6275e 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -3,7 +3,7 @@ import logging import time from abc import ABC, abstractmethod -from typing import Optional +from typing import Callable, Optional import numpy as np import prometheus_client @@ -19,6 +19,8 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5.0 +StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"] + class StatLoggerBase(ABC): """Interface for logging metrics. @@ -479,25 +481,27 @@ def setup_default_loggers( vllm_config: VllmConfig, log_stats: bool, engine_num: int, - custom_stat_loggers: Optional[list[type[StatLoggerBase]]] = None, + custom_stat_loggers: Optional[list[StatLoggerFactory]] = None, ) -> list[list[StatLoggerBase]]: """Setup logging and prometheus metrics.""" if not log_stats: return [] + def logging_factory(vllm_config: VllmConfig, engine_num: int): + return (LoggingStatLogger(vllm_config, engine_num) + if logger.isEnabledFor(logging.INFO) else None) + + factories: list[StatLoggerFactory] = [ + PrometheusStatLogger, logging_factory + ] + if custom_stat_loggers is not None: + factories = custom_stat_loggers + stat_loggers: list[list[StatLoggerBase]] = [] for i in range(engine_num): per_engine_stat_loggers: list[StatLoggerBase] = [] - if custom_stat_loggers is not None: - for logger_cls in custom_stat_loggers: - per_engine_stat_loggers.append( - logger_cls(vllm_config, engine_index=i)) - else: - per_engine_stat_loggers.append( - PrometheusStatLogger(vllm_config, engine_index=i)) - if logger.isEnabledFor(logging.INFO): - per_engine_stat_loggers.append( - LoggingStatLogger(vllm_config, engine_index=i)) + for logger_factory in factories: + per_engine_stat_loggers.append(logger_factory(vllm_config, i)) stat_loggers.append(per_engine_stat_loggers) return stat_loggers From f38f89315081693045a3e11422ef49196e2be3c0 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Wed, 23 Apr 2025 09:05:34 -0700 Subject: [PATCH 3/4] Update vllm/v1/metrics/loggers.py Signed-off-by: Zijing Liu --- vllm/v1/metrics/loggers.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 2587a6d6275e..7de02d43e165 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -487,15 +487,12 @@ def setup_default_loggers( if not log_stats: return [] - def logging_factory(vllm_config: VllmConfig, engine_num: int): - return (LoggingStatLogger(vllm_config, engine_num) - if logger.isEnabledFor(logging.INFO) else None) - - factories: list[StatLoggerFactory] = [ - PrometheusStatLogger, logging_factory - ] if custom_stat_loggers is not None: factories = custom_stat_loggers + else: + factories: list[StatLoggerFactory] = [PrometheusStatLogger] + if logger.isEnabledFor(logging.INFO): + factories.append(LoggingStatLogger) stat_loggers: list[list[StatLoggerBase]] = [] for i in range(engine_num): From 303e7ef5596aa71fbb787be3921b3ca334e8f968 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Thu, 24 Apr 2025 08:29:10 -0700 Subject: [PATCH 4/4] Fix typing Signed-off-by: Nick Hill Signed-off-by: Zijing Liu --- vllm/v1/metrics/loggers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/metrics/loggers.py b/vllm/v1/metrics/loggers.py index 7de02d43e165..22d1d9724c8c 100644 --- a/vllm/v1/metrics/loggers.py +++ b/vllm/v1/metrics/loggers.py @@ -487,10 +487,11 @@ def setup_default_loggers( if not log_stats: return [] + factories: list[StatLoggerFactory] if custom_stat_loggers is not None: factories = custom_stat_loggers else: - factories: list[StatLoggerFactory] = [PrometheusStatLogger] + factories = [PrometheusStatLogger] if logger.isEnabledFor(logging.INFO): factories.append(LoggingStatLogger)