Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,11 @@ steps:
- pytest -v -s plugins_tests/test_io_processor_plugins.py
- pip uninstall prithvi_io_processor_plugin -y
# end io_processor plugins test
# begin stat_logger plugins test
- pip install -e ./plugins/vllm_add_dummy_stat_logger
- pytest -v -s plugins_tests/test_stats_logger_plugins.py
- pip uninstall dummy_stat_logger -y
# end stat_logger plugins test
# other tests continue here:
- pytest -v -s plugins_tests/test_scheduler_plugins.py
- pip install -e ./plugins/vllm_add_dummy_model
Expand Down
4 changes: 3 additions & 1 deletion docs/design/plugin_system.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Every plugin has three parts:

1. **Plugin group**: The name of the entry point group. vLLM uses the entry point group `vllm.general_plugins` to register general plugins. This is the key of `entry_points` in the `setup.py` file. Always use `vllm.general_plugins` for vLLM's general plugins.
2. **Plugin name**: The name of the plugin. This is the value in the dictionary of the `entry_points` dictionary. In the example above, the plugin name is `register_dummy_model`. Plugins can be filtered by their names using the `VLLM_PLUGINS` environment variable. To load only a specific plugin, set `VLLM_PLUGINS` to the plugin name.
3. **Plugin value**: The fully qualified name of the function to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.
3. **Plugin value**: The fully qualified name of the function or module to register in the plugin system. In the example above, the plugin value is `vllm_add_dummy_model:register`, which refers to a function named `register` in the `vllm_add_dummy_model` module.

## Types of supported plugins

Expand All @@ -51,6 +51,8 @@ Every plugin has three parts:

- **IO Processor plugins** (with group name `vllm.io_processor_plugins`): The primary use case for these plugins is to register custom pre/post processing of the model prompt and model output for pooling models. The plugin function returns the IOProcessor's class fully qualified name.

- **Stat logger plugins** (with group name `vllm.stat_logger_plugins`): The primary use case for these plugins is to register custom, out-of-the-tree loggers into vLLM. The entry point should be a class that subclasses StatLoggerBase.

## Guidelines for Writing Plugins

- **Being re-entrant**: The function specified in the entry point should be re-entrant, meaning it can be called multiple times without causing issues. This is necessary because the function might be called multiple times in some processes.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.v1.metrics.loggers import StatLoggerBase


class DummyStatLogger(StatLoggerBase):
"""
A dummy stat logger for testing purposes.
Implements the minimal interface expected by StatLoggerManager.
"""

def __init__(self, vllm_config, engine_idx=0):
self.vllm_config = vllm_config
self.engine_idx = engine_idx
self.recorded = []
self.logged = False
self.engine_initialized = False

def record(self, scheduler_stats, iteration_stats, mm_cache_stats, engine_idx):
self.recorded.append(
(scheduler_stats, iteration_stats, mm_cache_stats, engine_idx)
)

def log(self):
self.logged = True

def log_engine_initialized(self):
self.engine_initialized = True
15 changes: 15 additions & 0 deletions tests/plugins/vllm_add_dummy_stat_logger/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from setuptools import setup

setup(
name="dummy_stat_logger",
version="0.1",
packages=["dummy_stat_logger"],
entry_points={
"vllm.stat_logger_plugins": [
"dummy_stat_logger = dummy_stat_logger.dummy_stat_logger:DummyStatLogger" # noqa
]
},
)
76 changes: 76 additions & 0 deletions tests/plugins_tests/test_stats_logger_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
from dummy_stat_logger.dummy_stat_logger import DummyStatLogger

from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import load_stat_logger_plugin_factories


def test_stat_logger_plugin_is_discovered(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "dummy_stat_logger")

factories = load_stat_logger_plugin_factories()
assert len(factories) == 1, f"Expected 1 factory, got {len(factories)}"
assert factories[0] is DummyStatLogger, (
f"Expected DummyStatLogger class, got {factories[0]}"
)

# instantiate and confirm the right type
vllm_config = VllmConfig()
instance = factories[0](vllm_config)
assert isinstance(instance, DummyStatLogger)


def test_no_plugins_loaded_if_env_empty(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "")

factories = load_stat_logger_plugin_factories()
assert factories == []


def test_invalid_stat_logger_plugin_raises(monkeypatch: pytest.MonkeyPatch):
def fake_plugin_loader(group: str):
assert group == "vllm.stat_logger_plugins"
return {"bad": object()}

with monkeypatch.context() as m:
m.setattr(
"vllm.v1.metrics.loggers.load_plugins_by_group",
fake_plugin_loader,
)
with pytest.raises(
TypeError,
match="Stat logger plugin 'bad' must be a subclass of StatLoggerBase",
):
load_stat_logger_plugin_factories()


@pytest.mark.asyncio
async def test_stat_logger_plugin_integration_with_engine(
monkeypatch: pytest.MonkeyPatch,
):
with monkeypatch.context() as m:
m.setenv("VLLM_PLUGINS", "dummy_stat_logger")

engine_args = AsyncEngineArgs(
model="facebook/opt-125m",
enforce_eager=True, # reduce test time
disable_log_stats=True, # disable default loggers
)

engine = AsyncLLM.from_engine_args(engine_args=engine_args)

assert len(engine.logger_manager.stat_loggers) == 2
assert len(engine.logger_manager.stat_loggers[0].per_engine_stat_loggers) == 1
assert isinstance(
engine.logger_manager.stat_loggers[0].per_engine_stat_loggers[0],
DummyStatLogger,
)

engine.shutdown()
26 changes: 3 additions & 23 deletions tests/v1/metrics/test_engine_logger_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,13 @@

import pytest

from tests.plugins.vllm_add_dummy_stat_logger.dummy_stat_logger.dummy_stat_logger import ( # noqa E501
DummyStatLogger,
)
from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM
from vllm.v1.metrics.ray_wrappers import RayPrometheusStatLogger


class DummyStatLogger:
"""
A dummy stat logger for testing purposes.
Implements the minimal interface expected by StatLoggerManager.
"""

def __init__(self, vllm_config, engine_idx):
self.vllm_config = vllm_config
self.engine_idx = engine_idx
self.recorded = []
self.logged = False
self.engine_initialized = False

def record(self, scheduler_stats, iteration_stats, engine_idx):
self.recorded.append((scheduler_stats, iteration_stats, engine_idx))

def log(self):
self.logged = True

def log_engine_initialized(self):
self.engine_initialized = True


@pytest.fixture
def log_stats_enabled_engine_args():
"""
Expand Down
21 changes: 15 additions & 6 deletions vllm/v1/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.engine.processor import Processor
from vllm.v1.executor.abstract import Executor
from vllm.v1.metrics.loggers import StatLoggerFactory, StatLoggerManager
from vllm.v1.metrics.loggers import (
StatLoggerFactory,
StatLoggerManager,
load_stat_logger_plugin_factories,
)
from vllm.v1.metrics.prometheus import shutdown_prometheus
from vllm.v1.metrics.stats import IterationStats

Expand Down Expand Up @@ -100,11 +104,16 @@ def __init__(
self.observability_config = vllm_config.observability_config
self.log_requests = log_requests

self.log_stats = log_stats or (stat_loggers is not None)
if not log_stats and stat_loggers is not None:
custom_stat_loggers = list(stat_loggers or [])
custom_stat_loggers.extend(load_stat_logger_plugin_factories())

has_custom_loggers = bool(custom_stat_loggers)
self.log_stats = log_stats or has_custom_loggers
if not log_stats and has_custom_loggers:
logger.info(
"AsyncLLM created with log_stats=False and non-empty custom "
"logger list; enabling logging without default stat loggers"
"AsyncLLM created with log_stats=False, "
"but custom stat loggers were found; "
"enabling logging without default stat loggers."
)

if self.model_config.skip_tokenizer_init:
Expand Down Expand Up @@ -144,7 +153,7 @@ def __init__(
self.logger_manager = StatLoggerManager(
vllm_config=vllm_config,
engine_idxs=self.engine_core.engine_ranks_managed,
custom_stat_loggers=stat_loggers,
custom_stat_loggers=custom_stat_loggers,
enable_default_loggers=log_stats,
client_count=client_count,
aggregate_engine_logging=aggregate_engine_logging,
Expand Down
18 changes: 18 additions & 0 deletions vllm/v1/metrics/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.config import SupportsMetricsInfo, VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import KVConnectorLogging
from vllm.logger import init_logger
from vllm.plugins import load_plugins_by_group
from vllm.v1.engine import FinishReason
from vllm.v1.metrics.prometheus import unregister_vllm_metrics
from vllm.v1.metrics.stats import (
Expand Down Expand Up @@ -56,6 +57,23 @@ def log(self): # noqa
pass


def load_stat_logger_plugin_factories() -> list[StatLoggerFactory]:
factories: list[StatLoggerFactory] = []

for name, plugin_class in load_plugins_by_group("vllm.stat_logger_plugins").items():
if not isinstance(plugin_class, type) or not issubclass(
plugin_class, StatLoggerBase
):
raise TypeError(
f"Stat logger plugin {name!r} must be a subclass of "
f"StatLoggerBase (got {plugin_class!r})."
)

factories.append(plugin_class)

return factories


class AggregateStatLoggerBase(StatLoggerBase):
"""Abstract base class for loggers that
aggregate across multiple DP engines."""
Expand Down