Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,12 @@
# pydantic needs the TypedDict from typing_extensions
from typing_extensions import Required, TypedDict

from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.models import SupportsMultiModal
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalUUIDDict
from vllm.multimodal.utils import MediaConnector
from vllm.multimodal.utils import MEDIA_CONNECTOR_REGISTRY, MediaConnector
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
Expand Down Expand Up @@ -806,7 +807,9 @@ def __init__(self, tracker: MultiModalItemTracker) -> None:
self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector = MediaConnector(

self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
Expand Down Expand Up @@ -891,7 +894,8 @@ def __init__(self, tracker: AsyncMultiModalItemTracker) -> None:
self._tracker = tracker
multimodal_config = self._tracker.model_config.multimodal_config
media_io_kwargs = getattr(multimodal_config, "media_io_kwargs", None)
self._connector = MediaConnector(
self._connector: MediaConnector = MEDIA_CONNECTOR_REGISTRY.load(
envs.VLLM_MEDIA_CONNECTOR,
media_io_kwargs=media_io_kwargs,
allowed_local_media_path=tracker.allowed_local_media_path,
allowed_media_domains=tracker.allowed_media_domains,
Expand Down
9 changes: 9 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8
VLLM_MAX_AUDIO_CLIP_FILESIZE_MB: int = 25
VLLM_VIDEO_LOADER_BACKEND: str = "opencv"
VLLM_MEDIA_CONNECTOR: str = "http"
VLLM_MM_INPUT_CACHE_GIB: int = 4
VLLM_TARGET_DEVICE: str = "cuda"
VLLM_MAIN_CUDA_VERSION: str = "12.8"
Expand Down Expand Up @@ -738,6 +739,14 @@ def get_vllm_port() -> int | None:
"VLLM_VIDEO_LOADER_BACKEND": lambda: os.getenv(
"VLLM_VIDEO_LOADER_BACKEND", "opencv"
),
# Media connector implementation.
# - "http": Default connector that supports fetching media via HTTP.
#
# Custom implementations can be registered
# via `@MEDIA_CONNECTOR_REGISTRY.register("my_custom_media_connector")` and
# imported at runtime.
# If a non-existing backend is used, an AssertionError will be thrown.
"VLLM_MEDIA_CONNECTOR": lambda: os.getenv("VLLM_MEDIA_CONNECTOR", "http"),
# [DEPRECATED] Cache size (in GiB per process) for multimodal input cache
# Default is 4 GiB per API process + 4 GiB per engine core process
"VLLM_MM_INPUT_CACHE_GIB": lambda: int(os.getenv("VLLM_MM_INPUT_CACHE_GIB", "4")),
Expand Down
4 changes: 4 additions & 0 deletions vllm/multimodal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from vllm.connections import HTTPConnection, global_http_connection
from vllm.logger import init_logger
from vllm.utils.jsontree import json_map_leaves
from vllm.utils.registry import ExtensionManager

from .audio import AudioMediaIO
from .base import MediaIO
Expand All @@ -46,7 +47,10 @@

_M = TypeVar("_M")

MEDIA_CONNECTOR_REGISTRY = ExtensionManager()


@MEDIA_CONNECTOR_REGISTRY.register("http")
class MediaConnector:
def __init__(
self,
Expand Down
21 changes: 2 additions & 19 deletions vllm/multimodal/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from vllm import envs
from vllm.logger import init_logger
from vllm.utils.registry import ExtensionManager

from .base import MediaIO
from .image import ImageMediaIO
Expand Down Expand Up @@ -63,25 +64,7 @@ def load_bytes(
raise NotImplementedError


class VideoLoaderRegistry:
def __init__(self) -> None:
self.name2class: dict[str, type] = {}

def register(self, name: str):
def wrap(cls_to_register):
self.name2class[name] = cls_to_register
return cls_to_register

return wrap

@staticmethod
def load(cls_name: str) -> VideoLoader:
cls = VIDEO_LOADER_REGISTRY.name2class.get(cls_name)
assert cls is not None, f"VideoLoader class {cls_name} not found"
return cls()


VIDEO_LOADER_REGISTRY = VideoLoaderRegistry()
VIDEO_LOADER_REGISTRY = ExtensionManager()


@VIDEO_LOADER_REGISTRY.register("opencv")
Expand Down
49 changes: 49 additions & 0 deletions vllm/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any


class ExtensionManager:
"""
A registry for managing pluggable extension classes.

This class provides a simple mechanism to register and instantiate
extension classes by name. It is commonly used to implement plugin
systems where different implementations can be swapped at runtime.

Examples:
Basic usage with a registry instance:

>>> FOO_REGISTRY = ExtensionManager()
>>> @FOO_REGISTRY.register("my_foo_impl")
... class MyFooImpl(Foo):
... def __init__(self, value):
... self.value = value
>>> foo_impl = FOO_REGISTRY.load("my_foo_impl", value=123)

"""

def __init__(self) -> None:
"""
Initialize an empty extension registry.
"""
self.name2class: dict[str, type] = {}

def register(self, name: str):
"""
Decorator to register a class with the given name.
"""

def wrap(cls_to_register):
self.name2class[name] = cls_to_register
return cls_to_register

return wrap

def load(self, cls_name: str, *args, **kwargs) -> Any:
"""
Instantiate and return a registered extension class by name.
"""
cls = self.name2class.get(cls_name)
assert cls is not None, f"Extension class {cls_name} not found"
return cls(*args, **kwargs)