From 82fd70e0e076ceedba4455b2604670d3aedd2424 Mon Sep 17 00:00:00 2001
From: David Xia
Date: Tue, 13 May 2025 00:11:05 +0000
Subject: [PATCH] [Frontend] speed up import time of vllm.config

by changing submodules to lazily import expensive modules like
`vllm.model_executor.layers.quantization`, or by importing them only for
type checking when they are not needed at runtime.

Contributes to #14924.

Signed-off-by: David Xia
---
 vllm/config.py | 44 +++++++++++++++++++++++++++-----------------
 1 file changed, 27 insertions(+), 17 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 4333dcd3b8af..5cf7fe84dd3e 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -27,19 +27,13 @@
 from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.distributed import ProcessGroup, ReduceOp
-from transformers import PretrainedConfig
 from typing_extensions import Self, deprecated, runtime_checkable
 
 import vllm.envs as envs
 from vllm import version
 from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
-                                                     QuantizationMethods,
-                                                     get_quantization_config)
-from vllm.model_executor.models import ModelRegistry
 from vllm.platforms import current_platform
-from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (
     ConfigFormat, get_config, get_hf_image_processor_config,
     get_hf_text_config, get_pooling_config,
@@ -48,32 +42,49 @@
     try_get_tokenizer_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
+# yapf conflicts with isort for this block
+# yapf: disable
 from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                         MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                         POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
-                        LayerBlockType, common_broadcastable_dtype,
+                        LayerBlockType, LazyLoader, common_broadcastable_dtype,
                         cuda_device_count_stateless, get_cpu_memory,
                         get_open_port, is_torch_equal_or_newer, random_uuid,
                         resolve_obj_by_qualname)
+# yapf: enable
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
+    from transformers.configuration_utils import PretrainedConfig
 
+    import vllm.model_executor.layers.quantization as me_quant
+    import vllm.model_executor.models as me_models
     from vllm.executor.executor_base import ExecutorBase
+    from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader import BaseModelLoader
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
+    HfOverrides = Union[dict, Callable[[type], type]]
 else:
     PlacementGroup = Any
+    PretrainedConfig = Any
     ExecutorBase = Any
     QuantizationConfig = Any
+    QuantizationMethods = Any
     BaseModelLoader = Any
     TensorizerConfig = Any
     ConfigType = type
+    HfOverrides = Union[dict[str, Any], Callable[[type], type]]
+
+    me_quant = LazyLoader("model_executor", globals(),
+                          "vllm.model_executor.layers.quantization")
+    me_models = LazyLoader("model_executor", globals(),
+                           "vllm.model_executor.models")
 
 logger = init_logger(__name__)
@@ -100,9 +111,6 @@
     for task in tasks
 }
 
-HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
-                                             PretrainedConfig]]
-
 
 @runtime_checkable
 class SupportsHash(Protocol):
@@ -648,7 +656,7 @@ def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
 
     @property
     def registry(self):
-        return ModelRegistry
+        return me_models.ModelRegistry
 
     @property
     def architectures(self) -> list[str]:
@@ -859,14 +867,15 @@ def _parse_quant_hf_config(self):
         return quant_cfg
 
     def _verify_quantization(self) -> None:
-        supported_quantization = QUANTIZATION_METHODS
+        supported_quantization = me_quant.QUANTIZATION_METHODS
         optimized_quantization_methods = [
             "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
             "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
             "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = cast(QuantizationMethods, self.quantization)
+            self.quantization = cast(me_quant.QuantizationMethods,
+                                     self.quantization)
 
         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -900,14 +909,14 @@ def _verify_quantization(self) -> None:
 
             # Detect which checkpoint is it
             for name in quantization_methods:
-                method = get_quantization_config(name)
+                method = me_quant.get_quantization_config(name)
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization)
                 if quantization_override is not None:
                     # Raise error if the override is not custom (custom would
                     # be in QUANTIZATION_METHODS but not QuantizationMethods)
                     # and hasn't been added to the overrides list.
-                    if (name in get_args(QuantizationMethods)
+                    if (name in get_args(me_quant.QuantizationMethods)
                             and name not in overrides):
                         raise ValueError(
                             f"Quantization method {name} is an override but "
@@ -1417,7 +1426,7 @@ def runner_type(self) -> RunnerType:
     @property
     def is_v1_compatible(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
-        return ModelRegistry.is_v1_compatible(architectures)
+        return me_models.ModelRegistry.is_v1_compatible(architectures)
 
     @property
     def is_matryoshka(self) -> bool:
@@ -2376,7 +2385,7 @@ class SpeculativeConfig:
     according to the log probability settings in SamplingParams."""
 
     # Draft model configuration
-    quantization: Optional[QuantizationMethods] = None
+    quantization: Optional[me_quant.QuantizationMethods] = None
     """Quantization method that was used to quantize the draft model weights.
     If `None`, we assume the model weights are not quantized. Note that it
     only takes effect when using the draft model-based speculative method."""
@@ -3624,6 +3633,7 @@ def __post_init__(self):
                 and "," in self.collect_detailed_traces[0]):
             self._parse_collect_detailed_traces()
 
+        from vllm.tracing import is_otel_available, otel_import_error_traceback
         if not is_otel_available() and self.otlp_traces_endpoint is not None:
             raise ValueError(
                 "OpenTelemetry is not available. Unable to configure "
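-- 
For reviewers unfamiliar with the pattern: the diff defers two imports via
`vllm.utils.LazyLoader`, whose internals are not part of this patch. Below is
a minimal, self-contained sketch of the same deferred-import idea; the
`_LazyModule` class is a hypothetical stand-in, not vLLM's implementation.

    import importlib
    import types


    class _LazyModule(types.ModuleType):
        """Module proxy that defers the real import to first attribute use."""

        def __getattr__(self, attr: str):
            # Reached only while the real module is still unloaded: import
            # it, cache its namespace on the proxy so later lookups bypass
            # __getattr__, then delegate this lookup.
            module = importlib.import_module(self.__name__)
            self.__dict__.update(module.__dict__)
            return getattr(module, attr)


    # Hypothetical stand-in for the patched
    #     me_quant = LazyLoader("model_executor", globals(),
    #                           "vllm.model_executor.layers.quantization")
    # Creating the proxy is effectively free; the expensive import runs on
    # first use, e.g. when _verify_quantization() reads
    # me_quant.QUANTIZATION_METHODS.
    me_quant = _LazyModule("vllm.model_executor.layers.quantization")

The saving can be checked with CPython's built-in import profiler, e.g.
`python -X importtime -c "import vllm.config"`, before and after the patch.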