From 8298842fd7e29d678e0f43927bb6afba951ed5a8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 4 Oct 2024 18:01:37 +0800 Subject: [PATCH] [Misc] Move registry to its own file (#9064) Signed-off-by: Alvant --- docs/source/models/adding_model.rst | 2 +- tests/models/test_registry.py | 4 +- vllm/lora/models.py | 3 +- vllm/model_executor/model_loader/loader.py | 5 +- vllm/model_executor/models/__init__.py | 333 +-------------------- vllm/model_executor/models/jamba.py | 6 +- vllm/model_executor/models/registry.py | 320 ++++++++++++++++++++ vllm/worker/model_runner.py | 3 +- 8 files changed, 341 insertions(+), 335 deletions(-) create mode 100644 vllm/model_executor/models/registry.py diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 5cffb58cafd96..1f220b723cacd 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a 5. Register your model ---------------------- -Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py `_. +Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py `_. 6. Out-of-Tree Model Integration -------------------------------------------- diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index ee5c9e8ccb196..299aeacb9f337 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -3,13 +3,13 @@ import pytest import torch.cuda -from vllm.model_executor.models import _MODELS, ModelRegistry +from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from ..utils import fork_new_process_for_each_test -@pytest.mark.parametrize("model_arch", _MODELS) +@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully ModelRegistry.resolve_model_cls(model_arch) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 1f80c716bc481..91e9f55e82433 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,8 +24,7 @@ from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.model_executor.models.interfaces import (SupportsLoRA, - supports_multimodal) +from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.utils import is_pin_memory_available diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 8fed5267a9eb5..8d4163ec88490 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -41,9 +41,8 @@ get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.interfaces import (has_inner_state, - supports_lora, - supports_multimodal) +from vllm.model_executor.models import (has_inner_state, supports_lora, + supports_multimodal) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 2f9cb2b760a82..51054a147a06f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,325 +1,16 @@ -import importlib -import string -import subprocess -import sys -import uuid -from functools import lru_cache, partial -from typing import Callable, Dict, List, Optional, Tuple, Type, Union - -import torch.nn as nn - -from vllm.logger import init_logger -from vllm.utils import is_hip - -from .interfaces import supports_multimodal, supports_pp - -logger = init_logger(__name__) - -_GENERATION_MODELS = { - "AquilaModel": ("llama", "LlamaForCausalLM"), - "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), - "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b - "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b - "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), - "CohereForCausalLM": ("commandr", "CohereForCausalLM"), - "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), - "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), - "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), - "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), - "FalconForCausalLM": ("falcon", "FalconForCausalLM"), - "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), - "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), - "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), - "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), - "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), - "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), - "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), - "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), - "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), - "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), - "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), - "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - # For decapoda-research/llama-* - "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), - "MistralForCausalLM": ("llama", "LlamaForCausalLM"), - "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), - # transformers's mpt class has lower case - "MptForCausalLM": ("mpt", "MPTForCausalLM"), - "MPTForCausalLM": ("mpt", "MPTForCausalLM"), - "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), - "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), - "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), - "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), - "OPTForCausalLM": ("opt", "OPTForCausalLM"), - "OrionForCausalLM": ("orion", "OrionForCausalLM"), - "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), - "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), - "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), - "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), - "Qwen2VLForConditionalGeneration": - ("qwen2_vl", "Qwen2VLForConditionalGeneration"), - "RWForCausalLM": ("falcon", "FalconForCausalLM"), - "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), - "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), - "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), - "SolarForCausalLM": ("solar", "SolarForCausalLM"), - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - # NOTE: The below models are for speculative decoding only - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), -} - -_EMBEDDING_MODELS = { - "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), - "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), -} - -_MULTIMODAL_MODELS = { - "Blip2ForConditionalGeneration": - ("blip2", "Blip2ForConditionalGeneration"), - "ChameleonForConditionalGeneration": - ("chameleon", "ChameleonForConditionalGeneration"), - "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), - "InternVLChatModel": ("internvl", "InternVLChatModel"), - "LlavaForConditionalGeneration": ("llava", - "LlavaForConditionalGeneration"), - "LlavaNextForConditionalGeneration": ("llava_next", - "LlavaNextForConditionalGeneration"), - "LlavaNextVideoForConditionalGeneration": - ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), - "LlavaOnevisionForConditionalGeneration": - ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), - "MiniCPMV": ("minicpmv", "MiniCPMV"), - "PaliGemmaForConditionalGeneration": ("paligemma", - "PaliGemmaForConditionalGeneration"), - "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "PixtralForConditionalGeneration": ("pixtral", - "PixtralForConditionalGeneration"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), - "Qwen2VLForConditionalGeneration": ("qwen2_vl", - "Qwen2VLForConditionalGeneration"), - "UltravoxModel": ("ultravox", "UltravoxModel"), - "MllamaForConditionalGeneration": ("mllama", - "MllamaForConditionalGeneration"), -} -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), -} - -_MODELS = { - **_GENERATION_MODELS, - **_EMBEDDING_MODELS, - **_MULTIMODAL_MODELS, - **_CONDITIONAL_GENERATION_MODELS, -} - -# Architecture -> type. -# out of tree models -_OOT_MODELS: Dict[str, Type[nn.Module]] = {} - -# Models not supported by ROCm. -_ROCM_UNSUPPORTED_MODELS: List[str] = [] - -# Models partially supported by ROCm. -# Architecture -> Reason. -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") -} - - -class ModelRegistry: - - @staticmethod - def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: - module_relname, cls_name = _MODELS[model_arch] - return f"vllm.model_executor.models.{module_relname}", cls_name - - @staticmethod - @lru_cache(maxsize=128) - def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch not in _MODELS: - return None - - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) - module = importlib.import_module(module_name) - return getattr(module, cls_name, None) - - @staticmethod - def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch in _OOT_MODELS: - return _OOT_MODELS[model_arch] - - if is_hip(): - if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError( - f"Model architecture {model_arch} is not supported by " - "ROCm for now.") - if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: - logger.warning( - "Model architecture %s is partially supported by ROCm: %s", - model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - - return None - - @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: - model = ModelRegistry._try_get_model_stateless(model_arch) - if model is not None: - return model - - return ModelRegistry._try_get_model_stateful(model_arch) - - @staticmethod - def resolve_model_cls( - architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - for arch in architectures: - model_cls = ModelRegistry._try_load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - @staticmethod - def get_supported_archs() -> List[str]: - return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) - - @staticmethod - def register_model(model_arch: str, model_cls: Type[nn.Module]): - if model_arch in _MODELS: - logger.warning( - "Model architecture %s is already registered, and will be " - "overwritten by the new model class %s.", model_arch, - model_cls.__name__) - - _OOT_MODELS[model_arch] = model_cls - - @staticmethod - @lru_cache(maxsize=128) - def _check_stateless( - func: Callable[[Type[nn.Module]], bool], - model_arch: str, - *, - default: Optional[bool] = None, - ) -> bool: - """ - Run a boolean function against a model and return the result. - - If the model is not found, returns the provided default value. - - If the model is not already imported, the function is run inside a - subprocess to avoid initializing CUDA for the main program. - """ - model = ModelRegistry._try_get_model_stateless(model_arch) - if model is not None: - return func(model) - - if model_arch not in _MODELS and default is not None: - return default - - module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) - - valid_name_characters = string.ascii_letters + string.digits + "._" - if any(s not in valid_name_characters for s in module_name): - raise ValueError(f"Unsafe module name detected for {model_arch}") - if any(s not in valid_name_characters for s in cls_name): - raise ValueError(f"Unsafe class name detected for {model_arch}") - if any(s not in valid_name_characters for s in func.__module__): - raise ValueError(f"Unsafe module name detected for {func}") - if any(s not in valid_name_characters for s in func.__name__): - raise ValueError(f"Unsafe class name detected for {func}") - - err_id = uuid.uuid4() - - stmts = ";".join([ - f"from {module_name} import {cls_name}", - f"from {func.__module__} import {func.__name__}", - f"assert {func.__name__}({cls_name}), '{err_id}'", - ]) - - result = subprocess.run([sys.executable, "-c", stmts], - capture_output=True) - - if result.returncode != 0: - err_lines = [line.decode() for line in result.stderr.splitlines()] - if err_lines and err_lines[-1] != f"AssertionError: {err_id}": - err_str = "\n".join(err_lines) - raise RuntimeError( - "An unexpected error occurred while importing the model in " - f"another process. Error log:\n{err_str}") - - return result.returncode == 0 - - @staticmethod - def is_embedding_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - return any(arch in _EMBEDDING_MODELS for arch in architectures) - - @staticmethod - def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - is_mm = partial(ModelRegistry._check_stateless, - supports_multimodal, - default=False) - - return any(is_mm(arch) for arch in architectures) - - @staticmethod - def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: - if isinstance(architectures, str): - architectures = [architectures] - if not architectures: - logger.warning("No model architectures are specified") - - is_pp = partial(ModelRegistry._check_stateless, - supports_pp, - default=False) - - return any(is_pp(arch) for arch in architectures) - +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, + SupportsPP, has_inner_state, supports_lora, + supports_multimodal, supports_pp) +from .registry import ModelRegistry __all__ = [ "ModelRegistry", + "HasInnerState", + "has_inner_state", + "SupportsLoRA", + "supports_lora", + "SupportsMultiModal", + "supports_multimodal", + "SupportsPP", + "supports_pp", ] diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 330a2b6e3fd7f..06ec324b3e108 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -25,20 +25,18 @@ causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) -from .interfaces import SupportsLoRA +from .interfaces import HasInnerState, SupportsLoRA KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py new file mode 100644 index 0000000000000..aa5736e7cd517 --- /dev/null +++ b/vllm/model_executor/models/registry.py @@ -0,0 +1,320 @@ +import importlib +import string +import subprocess +import sys +import uuid +from functools import lru_cache, partial +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +from .interfaces import supports_multimodal, supports_pp + +logger = init_logger(__name__) + +_GENERATION_MODELS = { + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "CohereForCausalLM": ("commandr", "CohereForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), + "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "Qwen2VLForConditionalGeneration": + ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), + "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + # NOTE: The below models are for speculative decoding only + "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), +} + +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), +} + +_MULTIMODAL_MODELS = { + "Blip2ForConditionalGeneration": + ("blip2", "Blip2ForConditionalGeneration"), + "ChameleonForConditionalGeneration": + ("chameleon", "ChameleonForConditionalGeneration"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), + "LlavaForConditionalGeneration": ("llava", + "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": ("llava_next", + "LlavaNextForConditionalGeneration"), + "LlavaNextVideoForConditionalGeneration": + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), + "LlavaOnevisionForConditionalGeneration": + ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + "MiniCPMV": ("minicpmv", "MiniCPMV"), + "PaliGemmaForConditionalGeneration": ("paligemma", + "PaliGemmaForConditionalGeneration"), + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "PixtralForConditionalGeneration": ("pixtral", + "PixtralForConditionalGeneration"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2VLForConditionalGeneration": ("qwen2_vl", + "Qwen2VLForConditionalGeneration"), + "UltravoxModel": ("ultravox", "UltravoxModel"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), +} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_MULTIMODAL_MODELS, + **_CONDITIONAL_GENERATION_MODELS, +} + +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + + +class ModelRegistry: + + @staticmethod + def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name + + @staticmethod + @lru_cache(maxsize=128) + def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + module = importlib.import_module(module_name) + return getattr(module, cls_name, None) + + @staticmethod + def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + "Model architecture %s is partially supported by ROCm: %s", + model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + + return None + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return model + + return ModelRegistry._try_get_model_stateful(model_arch) + + @staticmethod + def resolve_model_cls( + architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + for arch in architectures: + model_cls = ModelRegistry._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + + _OOT_MODELS[model_arch] = model_cls + + @staticmethod + @lru_cache(maxsize=128) + def _check_stateless( + func: Callable[[Type[nn.Module]], bool], + model_arch: str, + *, + default: Optional[bool] = None, + ) -> bool: + """ + Run a boolean function against a model and return the result. + + If the model is not found, returns the provided default value. + + If the model is not already imported, the function is run inside a + subprocess to avoid initializing CUDA for the main program. + """ + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return func(model) + + if model_arch not in _MODELS and default is not None: + return default + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") + + err_id = uuid.uuid4() + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name}), '{err_id}'", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + if result.returncode != 0: + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") + + return result.returncode == 0 + + @staticmethod + def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _EMBEDDING_MODELS for arch in architectures) + + @staticmethod + def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_mm = partial(ModelRegistry._check_stateless, + supports_multimodal, + default=False) + + return any(is_mm(arch) for arch in architectures) + + @staticmethod + def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_pp = partial(ModelRegistry._check_stateless, + supports_pp, + default=False) + + return any(is_pp(arch) for arch in architectures) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 51f65cbfcf862..9784438841980 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,8 +35,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models.interfaces import (supports_lora, - supports_multimodal) +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry)