diff --git a/tests/transformers_utils/__init__.py b/tests/transformers_utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/transformers_utils/test_config_parser_registry.py b/tests/transformers_utils/test_config_parser_registry.py new file mode 100644 index 000000000000..13c654e05d2a --- /dev/null +++ b/tests/transformers_utils/test_config_parser_registry.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from pathlib import Path +from typing import Optional, Union + +import pytest +from transformers import PretrainedConfig + +from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) +from vllm.transformers_utils.config_parser_base import ConfigParserBase + + +@register_config_parser("custom_config_parser") +class CustomConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError + + +def test_register_config_parser(): + assert isinstance(get_config_parser("custom_config_parser"), + CustomConfigParser) + + +def test_invalid_config_parser(): + with pytest.raises(ValueError): + + @register_config_parser("invalid_config_parser") + class InvalidConfigParser: + pass diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 063af69f41da..3ccd6fd2f342 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -420,7 +420,7 @@ class ModelConfig: `--media-io-kwargs '{"video": {"num_frames": 40} }'` """ use_async_output_proc: bool = True """Whether to use async output processor.""" - config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value + config_format: Union[str, ConfigFormat] = "auto" """The format of the model config to load:\n - "auto" will try to load the config in hf format if available else it will try to load in mistral format.\n @@ -625,9 +625,6 @@ def __post_init__(self) -> None: raise ValueError( "Sleep mode is not supported on current platform.") - if isinstance(self.config_format, str): - self.config_format = ConfigFormat(self.config_format) - hf_config = get_config(self.hf_config_path or self.model, self.trust_remote_code, self.revision, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fdd25a2f9ce2..b373c5f6f48c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -22,9 +22,9 @@ import vllm.envs as envs from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigFormat, ConfigType, ConvertOption, - DecodingConfig, DetailedTraceModules, Device, - DeviceConfig, DistributedExecutorBackend, EPLBConfig, + ConfigType, ConvertOption, DecodingConfig, + DetailedTraceModules, Device, DeviceConfig, + DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, @@ -547,7 +547,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Disable async output processing. This may result in " "lower performance.") model_group.add_argument("--config-format", - choices=[f.value for f in ConfigFormat], **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. TODO: Handle this in get_kwargs diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 95e4ed1ccf07..d6ebcdf80525 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,13 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import enum import json import os import time from functools import cache, partial from pathlib import Path -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Any, Callable, Literal, Optional, TypeVar, Union import huggingface_hub from huggingface_hub import get_safetensors_metadata, hf_hub_download @@ -27,6 +26,7 @@ from vllm import envs from vllm.logger import init_logger +from vllm.transformers_utils.config_parser_base import ConfigParserBase from vllm.transformers_utils.utils import check_gguf_file if envs.VLLM_USE_MODELSCOPE: @@ -100,10 +100,163 @@ def __getitem__(self, key): } -class ConfigFormat(str, enum.Enum): - AUTO = "auto" - HF = "hf" - MISTRAL = "mistral" +class HFConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE + config_dict, _ = PretrainedConfig.get_config_dict( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type is None: + model_type = "speculators" if config_dict.get( + "speculators_config") is not None else model_type + + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained( + model, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + else: + try: + kwargs = _maybe_update_auto_config_kwargs( + kwargs, model_type=model_type) + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + token=_get_hf_token(), + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + config = _maybe_remap_hf_config_attrs(config) + return config_dict, config + + +class MistralConfigParser(ConfigParserBase): + + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + # This function loads a params.json config which + # should be used when loading models in mistral format + config_dict = _download_mistral_config_file(model, revision) + if (max_position_embeddings := + config_dict.get("max_position_embeddings")) is None: + max_position_embeddings = _maybe_retrieve_max_pos_from_hf( + model, revision, **kwargs) + config_dict["max_position_embeddings"] = max_position_embeddings + + from vllm.transformers_utils.configs.mistral import adapt_config_dict + + config = adapt_config_dict(config_dict) + + # Mistral configs may define sliding_window as list[int]. Convert it + # to int and add the layer_types list[str] to make it HF compatible + if ((sliding_window := getattr(config, "sliding_window", None)) + and isinstance(sliding_window, list)): + pattern_repeats = config.num_hidden_layers // len(sliding_window) + layer_types = sliding_window * pattern_repeats + config.layer_types = [ + "full_attention" if layer_type is None else "sliding_attention" + for layer_type in layer_types + ] + config.sliding_window = next(filter(None, sliding_window), None) + + return config_dict, config + + +_CONFIG_FORMAT_TO_CONFIG_PARSER: dict[str, type[ConfigParserBase]] = { + "hf": HFConfigParser, + "mistral": MistralConfigParser, +} + +ConfigFormat = Literal[ + "auto", + "hf", + "mistral", +] + + +def get_config_parser(config_format: str) -> ConfigParserBase: + """Get the config parser for a given config format.""" + if config_format not in _CONFIG_FORMAT_TO_CONFIG_PARSER: + raise ValueError(f"Unknown config format `{config_format}`.") + return _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format]() + + +def register_config_parser(config_format: str): + + """Register a customized vllm config parser. + When a config format is not supported by vllm, you can register a customized + config parser to support it. + Args: + config_format (str): The config parser format name. + Examples: + + >>> from vllm.transformers_utils.config import (get_config_parser, + register_config_parser) + >>> from vllm.transformers_utils.config_parser_base import ConfigParserBase + >>> + >>> @register_config_parser("custom_config_parser") + ... class CustomConfigParser(ConfigParserBase): + ... def parse(self, + ... model: Union[str, Path], + ... trust_remote_code: bool, + ... revision: Optional[str] = None, + ... code_revision: Optional[str] = None, + ... **kwargs) -> tuple[dict, PretrainedConfig]: + ... raise NotImplementedError + >>> + >>> type(get_config_parser("custom_config_parser")) + + """ # noqa: E501 + + def _wrapper(config_parser_cls): + if config_format in _CONFIG_FORMAT_TO_CONFIG_PARSER: + logger.warning( + "Config format `%s` is already registered, and will be " + "overwritten by the new parser class `%s`.", config_format, + config_parser_cls) + if not issubclass(config_parser_cls, ConfigParserBase): + raise ValueError("The config parser must be a subclass of " + "`ConfigParserBase`.") + _CONFIG_FORMAT_TO_CONFIG_PARSER[config_format] = config_parser_cls + logger.info("Registered config parser `%s` with config format `%s`", + config_parser_cls, config_format) + return config_parser_cls + + return _wrapper _R = TypeVar("_R") @@ -350,7 +503,7 @@ def get_config( trust_remote_code: bool, revision: Optional[str] = None, code_revision: Optional[str] = None, - config_format: ConfigFormat = ConfigFormat.AUTO, + config_format: Union[str, ConfigFormat] = "auto", hf_overrides_kw: Optional[dict[str, Any]] = None, hf_overrides_fn: Optional[Callable[[PretrainedConfig], PretrainedConfig]] = None, @@ -363,20 +516,22 @@ def get_config( kwargs["gguf_file"] = Path(model).name model = Path(model).parent - if config_format == ConfigFormat.AUTO: + if config_format == "auto": try: if is_gguf or file_or_path_exists( model, HF_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.HF + config_format = "hf" elif file_or_path_exists(model, MISTRAL_CONFIG_NAME, revision=revision): - config_format = ConfigFormat.MISTRAL + config_format = "mistral" else: raise ValueError( "Could not detect config format for no config file found. " - "Ensure your model has either config.json (HF format) " - "or params.json (Mistral format).") + "With config_format 'auto', ensure your model has either" + "config.json (HF format) or params.json (Mistral format)." + "Otherwise please specify your_custom_config_format" + "in engine args for customized config parser") except Exception as e: error_message = ( @@ -395,92 +550,14 @@ def get_config( raise ValueError(error_message) from e - if config_format == ConfigFormat.HF: - kwargs["local_files_only"] = huggingface_hub.constants.HF_HUB_OFFLINE - config_dict, _ = PretrainedConfig.get_config_dict( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - # Use custom model class if it's in our registry - model_type = config_dict.get("model_type") - if model_type is None: - model_type = "speculators" if config_dict.get( - "speculators_config") is not None else model_type - - if model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[model_type] - config = config_class.from_pretrained( - model, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - else: - try: - kwargs = _maybe_update_auto_config_kwargs( - kwargs, model_type=model_type) - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - token=_get_hf_token(), - **kwargs, - ) - except ValueError as e: - if (not trust_remote_code - and "requires you to execute the configuration file" - in str(e)): - err_msg = ( - "Failed to load the model config. If the model " - "is a custom model not yet available in the " - "HuggingFace transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - config = _maybe_remap_hf_config_attrs(config) - - elif config_format == ConfigFormat.MISTRAL: - # This function loads a params.json config which - # should be used when loading models in mistral format - config_dict = _download_mistral_config_file(model, revision) - if (max_position_embeddings := - config_dict.get("max_position_embeddings")) is None: - max_position_embeddings = _maybe_retrieve_max_pos_from_hf( - model, revision, **kwargs) - config_dict["max_position_embeddings"] = max_position_embeddings - - from vllm.transformers_utils.configs.mistral import adapt_config_dict - - config = adapt_config_dict(config_dict) - - # Mistral configs may define sliding_window as list[int]. Convert it - # to int and add the layer_types list[str] to make it HF compatible - if ((sliding_window := getattr(config, "sliding_window", None)) - and isinstance(sliding_window, list)): - pattern_repeats = config.num_hidden_layers // len(sliding_window) - layer_types = sliding_window * pattern_repeats - config.layer_types = [ - "full_attention" if layer_type is None else "sliding_attention" - for layer_type in layer_types - ] - config.sliding_window = next(filter(None, sliding_window), None) - else: - supported_formats = [ - fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO - ] - raise ValueError( - f"Unsupported config format: {config_format}. " - f"Supported formats are: {', '.join(supported_formats)}. " - f"Ensure your model uses one of these configuration formats " - f"or specify the correct format explicitly.") - + config_parser = get_config_parser(config_format) + config_dict, config = config_parser.parse( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) # Special architecture mapping check for GGUF models if is_gguf: if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: @@ -914,7 +991,7 @@ def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int: hf_config = get_config(model=model, trust_remote_code=trust_remote_code_val, revision=revision, - config_format=ConfigFormat.HF) + config_format="hf") if hf_value := hf_config.get_text_config().max_position_embeddings: max_position_embeddings = hf_value except Exception as e: diff --git a/vllm/transformers_utils/config_parser_base.py b/vllm/transformers_utils/config_parser_base.py new file mode 100644 index 000000000000..c27177f74d4b --- /dev/null +++ b/vllm/transformers_utils/config_parser_base.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Optional, Union + +from transformers import PretrainedConfig + + +class ConfigParserBase(ABC): + + @abstractmethod + def parse(self, + model: Union[str, Path], + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + **kwargs) -> tuple[dict, PretrainedConfig]: + raise NotImplementedError