diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index d0fa4e8f64cc..653f491d545d 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -58,7 +58,7 @@ repos:
     entry: tools/mypy.sh 0 "local"
    language: python
     types: [python]
-    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
+    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
     stages: [pre-commit] # Don't run in CI
   - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
     name: Run mypy for Python 3.9
diff --git a/pyproject.toml b/pyproject.toml
index 62a734d795d5..eb55a9ffc5bf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -110,6 +110,7 @@ ignore = [
 ]
 
 [tool.mypy]
+plugins = ['pydantic.mypy']
 ignore_missing_imports = true
 check_untyped_defs = true
 follow_imports = "silent"
diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py
index caf71976a260..7a76ffb740ef 100644
--- a/tests/lora/test_quant_model.py
+++ b/tests/lora/test_quant_model.py
@@ -24,16 +24,16 @@ class ModelWithQuantization:
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
 else:
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
-            quantization="AWQ"),
+            quantization="awq"),
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
 
 
@@ -100,7 +100,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
         "#ff8050",
         "#ff8080",
     ]
-    elif model.quantization == "AWQ":
+    elif model.quantization == "awq":
         expected_no_lora_output = [
             "I'm sorry, I don't understand",
             "I'm sorry, I don't understand",
@@ -109,7 +109,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
            "#f07700: A v",
            "#f00000: A v",
        ]
-    elif model.quantization == "GPTQ":
+    elif model.quantization == "gptq":
         expected_no_lora_output = [
             "I'm sorry, I don't have",
             "I'm sorry, I don't have",
@@ -122,7 +122,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
     def expect_match(output, expected_output):
         # HACK: GPTQ lora outputs are just incredibly unstable.
         # Assert that the outputs changed.
-        if (model.quantization == "GPTQ"
+        if (model.quantization == "gptq"
                 and expected_output is expected_lora_output):
             assert output != expected_no_lora_output
             for i, o in enumerate(output):
@@ -172,7 +172,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                                  model):
     if num_gpus_available < 2:
         pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
-    if model.quantization == "GPTQ":
+    if model.quantization == "gptq":
         pytest.skip("GPTQ lora outputs are just incredibly unstable")
     llm_tp1 = vllm.LLM(
         model=model.model_path,
diff --git a/tests/tracing/test_tracing.py b/tests/tracing/test_tracing.py
index a781b8b563be..caa233ec3ff9 100644
--- a/tests/tracing/test_tracing.py
+++ b/tests/tracing/test_tracing.py
@@ -173,7 +173,7 @@ def test_traces_with_detailed_steps(
     llm = LLM(
         model=model,
         otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
-        collect_detailed_traces="all",
+        collect_detailed_traces=["all"],
     )
     prompts = ["This is a short prompt"]
     outputs = llm.generate(prompts, sampling_params=sampling_params)
diff --git a/vllm/config.py b/vllm/config.py
index 4196684639ee..93b367ca81cc 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -11,8 +11,8 @@
 import warnings
 from collections import Counter
 from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-                         is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+                         replace)
 from functools import cached_property
 from importlib.util import find_spec
 from pathlib import Path
@@ -21,9 +21,12 @@
 import regex as re
 import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+                      model_validator)
+from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable
 
 import vllm.envs as envs
 from vllm import version
@@ -57,10 +60,15 @@
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
 else:
+    PlacementGroup = Any
+    ExecutorBase = Any
     QuantizationConfig = Any
+    BaseModelLoader = Any
+    TensorizerConfig = Any
     ConfigType = type
 
 logger = init_logger(__name__)
@@ -92,6 +100,7 @@
                            PretrainedConfig]]
 
 
+@runtime_checkable
 class SupportsHash(Protocol):
 
     def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
     """Configuration for the model."""
 
@@ -236,7 +245,7 @@ class ModelConfig:
     task, even if the same model can be used for multiple tasks.
     When the model only supports one task, "auto" can be used to select it;
     otherwise, you must specify explicitly which task to use."""
-    tokenizer: str = None  # type: ignore
+    tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified,
     model name or path will be used."""
     tokenizer_mode: TokenizerMode = "auto"
@@ -284,7 +293,7 @@ class ModelConfig:
     """The specific revision to use for the tokenizer on the Hugging Face Hub.
     It can be a branch name, a tag name, or a commit id. If unspecified, will
     use the default version."""
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Model context length (prompt and output). If unspecified, will be
     automatically derived from the model config.
@@ -607,6 +616,22 @@ def __post_init__(self) -> None:
         self._verify_cuda_graph()
         self._verify_bnb_config()
 
+    @field_validator("quantization", mode="before")
+    @classmethod
+    def validate_quantization_before(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.lower()
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
+        if not isinstance(self.tokenizer, str):
+            raise ValueError("tokenizer must be a string after __post_init__.")
+        if not isinstance(self.max_model_len, int):
+            raise ValueError(
+                "max_model_len must be an integer after __post_init__.")
+        return self
+
     @property
     def registry(self):
         return ModelRegistry
@@ -833,8 +858,7 @@ def _verify_quantization(self) -> None:
             "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = cast(QuantizationMethods,
-                                     self.quantization.lower())
+            self.quantization = cast(QuantizationMethods, self.quantization)
 
         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -1397,7 +1421,7 @@ def matryoshka_dimensions(self):
 class CacheConfig:
     """Configuration for the KV cache."""
 
-    block_size: BlockSize = None  # type: ignore
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
     """Size of a contiguous cache block in number of tokens. This is ignored
     on neuron devices and set to `--max-model-len`. On CUDA devices, only
     block sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: dict = field(default_factory=dict)
+    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+        default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
     ignore_patterns: Optional[Union[list[str], str]] = None
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""
 
-    max_num_batched_tokens: int = None  # type: ignore
+    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
     """Maximum number of tokens to be processed in a single iteration.
 
     This config has no static default. If left unspecified by the user, it
     will be set in `EngineArgs.create_engine_config` based on the usage
     context."""
 
-    max_num_seqs: int = None  # type: ignore
+    max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
 
     This config has no static default. If left unspecified by the user, it
     will be set in `EngineArgs.create_engine_config` based on the usage
     context."""
 
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Maximum length of a sequence (including prompt and generated text).
 
     This is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
     """Apply a delay (of delay factor multiplied by previous prompt latency)
     before scheduling next prompt."""
 
-    enable_chunked_prefill: bool = None  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based on the remaining
     max_num_batched_tokens."""
@@ -2202,7 +2227,7 @@ def is_multi_step(self) -> bool:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""
 
@@ -2260,8 +2285,8 @@ def __post_init__(self):
         self.device = torch.device(self.device_type)
 
 
-SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model", "deepseek_mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+                            "mlp_speculator", "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
     """Configuration for speculative decoding."""
 
     # General speculative decoding control
-    num_speculative_tokens: int = field(default=None,
-                                        init=True)  # type: ignore
+    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
     model: Optional[str] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
     """Specifies the tree structure for speculative token generation.
     """
 
     # required configuration params passed from engine
-    target_model_config: ModelConfig = field(default=None,
-                                             init=True)  # type: ignore
+    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the target model."""
-    target_parallel_config: ParallelConfig = field(default=None,
-                                                   init=True)  # type: ignore
+    target_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the target model."""
-    enable_chunked_prefill: bool = field(default=None,
-                                         init=True)  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """Whether vLLM is configured to use chunked prefill or not. Used for
     raising an error since it's not yet compatible with speculative decode."""
-    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    disable_log_stats: SkipValidation[bool] = None  # type: ignore
     """Whether to disable the periodic printing of stage times in speculative
     decoding."""
 
     # params generated in the post-init stage
-    draft_model_config: ModelConfig = field(default=None,
-                                            init=True)  # type: ignore
+    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the draft model initialized internal."""
-    draft_parallel_config: ParallelConfig = field(default=None,
-                                                  init=True)  # type: ignore
+    draft_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""
 
     def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ def __repr__(self) -> str:
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""
 
@@ -2863,7 +2884,7 @@ def verify_lora_support(self):
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class PromptAdapterConfig:
     """Configuration for PromptAdapters."""
 
@@ -3888,17 +3909,11 @@ def __repr__(self) -> str:
             "pass_config",
             "traced_files",
         }
-        include = dict()
-        for k, v in asdict(self).items():
-            if k in exclude:
-                continue
-            f = get_field(CompilationConfig, k)
-            if (d := f.default) is not MISSING and d == v:
-                continue
-            if (df := f.default_factory) is not MISSING and df() == v:
-                continue
-            include[k] = v
-        return json.dumps(include)
+        # The cast to string is necessary because Pydantic is mocked in docs
+        # builds and sphinx-argparse doesn't know the return type of decode()
+        return str(
+            TypeAdapter(CompilationConfig).dump_json(
+                self, exclude=exclude, exclude_unset=True).decode())
 
     __str__ = __repr__
 
@@ -3907,7 +3922,7 @@ def from_cli(cls, cli_value: str) -> "CompilationConfig":
         """Parse the CLI value for the compilation config."""
         if cli_value in ["0", "1", "2", "3"]:
             return cls(level=int(cli_value))
-        return cls(**json.loads(cli_value))
+        return TypeAdapter(CompilationConfig).validate_json(cli_value)
 
     def __post_init__(self) -> None:
         count_none = self.custom_ops.count("none")
@@ -4033,7 +4048,7 @@ def set_splitting_ops_for_v1(self):
 
 
 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
     """Dataclass which contains all vllm-related configuration. This
     simplifies passing around the distinct configurations in the codebase.
@@ -4290,9 +4305,6 @@ def __post_init__(self):
                 "To workaround this limitation, vLLM will set 'ieee' input "
                 "precision for chunked prefill triton kernels.")
 
-        if self.compilation_config is None:
-            self.compilation_config = CompilationConfig()
-
         # async tp is built on top of sequence parallelism
         # and requires it to be enabled.
         if self.compilation_config.pass_config.enable_async_tp:
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 3b90880167dc..5515374bd045 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -14,6 +14,7 @@
 import regex as re
 import torch
+from pydantic import SkipValidation, TypeAdapter, ValidationError
 from typing_extensions import TypeIs, deprecated
 
 import vllm.envs as envs
@@ -38,7 +39,7 @@
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_doc_build, is_in_ray_actor)
+                        GiB_bytes, is_in_ray_actor)
 
 # yapf: enable
@@ -156,7 +157,8 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
         # Get the set of possible types for the field
         type_hints: set[TypeHint] = set()
         if get_origin(field.type) in {Union, Annotated}:
-            type_hints.update(get_args(field.type))
+            predicate = lambda arg: not isinstance(arg, SkipValidation)
+            type_hints.update(filter(predicate, get_args(field.type)))
         else:
             type_hints.add(field.type)
 
@@ -168,10 +170,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
         if field.default is not MISSING:
             default = field.default
         elif field.default_factory is not MISSING:
-            if is_dataclass(field.default_factory) and is_in_doc_build():
-                default = {}
-            else:
-                default = field.default_factory()
+            default = field.default_factory()
 
         # Get the help text for the field
         name = field.name
@@ -189,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
 - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
         if dataclass_cls is not None:
-            dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
-            # Special case for configs with a from_cli method
-            if hasattr(dataclass_cls, "from_cli"):
-                from_cli = dataclass_cls.from_cli
-                dataclass_init = lambda x, f=from_cli: f(x)
-            kwargs[name]["type"] = dataclass_init
+
+            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
+                try:
+                    if hasattr(cls, "from_cli"):
+                        return cls.from_cli(val)
+                    return TypeAdapter(cls).validate_json(val)
+                except ValidationError as e:
+                    raise argparse.ArgumentTypeError(repr(e)) from e
+
+            kwargs[name]["type"] = parse_dataclass
             kwargs[name]["help"] += json_tip
         elif contains_type(type_hints, bool):
             # Creates --no- and -- flags
@@ -225,12 +228,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
-        elif contains_type(type_hints,
-                           dict) and (contains_type(type_hints, str) or any(
-                               is_not_builtin(th) for th in type_hints)):
+        elif (contains_type(type_hints, dict)
+              and (contains_type(type_hints, str)
+                   or any(is_not_builtin(th) for th in type_hints))):
             kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
-            # Dict arguments will always be optional
             kwargs[name]["type"] = parse_type(json.loads)
             kwargs[name]["help"] += json_tip
         elif (contains_type(type_hints, str)
@@ -317,8 +319,7 @@ class EngineArgs:
     rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
     rope_theta: Optional[float] = ModelConfig.rope_theta
     hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
-    hf_overrides: Optional[HfOverrides] = \
-        get_field(ModelConfig, "hf_overrides")
+    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
     tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
     quantization: Optional[QuantizationMethods] = ModelConfig.quantization
     enforce_eager: bool = ModelConfig.enforce_eager
@@ -398,7 +399,8 @@ class EngineArgs:
         get_field(ModelConfig, "override_neuron_config")
     override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
         ModelConfig.override_pooler_config
-    compilation_config: Optional[CompilationConfig] = None
+    compilation_config: CompilationConfig = \
+        get_field(VllmConfig, "compilation_config")
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -413,7 +415,8 @@ class EngineArgs:
 
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
 
-    additional_config: Optional[Dict[str, Any]] = None
+    additional_config: dict[str, Any] = \
+        get_field(VllmConfig, "additional_config")
     enable_reasoning: Optional[bool] = None  # DEPRECATED
     reasoning_parser: str = DecodingConfig.reasoning_backend
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index f818e1737975..a3e26c090caf 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -204,6 +204,9 @@ def __init__(
         if isinstance(worker_cls, type):
             kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)
 
+        if hf_overrides is None:
+            hf_overrides = {}
+
         if compilation_config is not None:
             if isinstance(compilation_config, int):
                 compilation_config_instance = CompilationConfig(
@@ -215,7 +218,7 @@ def __init__(
             else:
                 compilation_config_instance = compilation_config
         else:
-            compilation_config_instance = None
+            compilation_config_instance = CompilationConfig()
 
         engine_args = EngineArgs(
             model=model,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 393cf381b16b..a7f85e9eef39 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -175,11 +175,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"
 
 
+# extra="forbid" is a workaround to have kwargs as a field,
+# see https://github.com/pydantic/pydantic/issues/3125
 class LogitsProcessorConstructor(BaseModel):
     qualname: str
     args: Optional[list[Any]] = None
     kwargs: Optional[dict[str, Any]] = None
 
+    model_config = ConfigDict(extra="forbid")
+
 
 LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
@@ -234,7 +238,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[AnyResponseFormat] = None
     seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+    stop: Optional[Union[str, list[str]]] = []
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = None
@@ -258,7 +262,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_p: Optional[float] = None
     repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+    stop_token_ids: Optional[list[int]] = []
     include_stop_str_in_output: bool = False
     ignore_eos: bool = False
     min_tokens: int = 0
@@ -756,7 +760,7 @@ class CompletionRequest(OpenAIBaseModel):
     n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+    stop: Optional[Union[str, list[str]]] = []
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
@@ -770,7 +774,7 @@
     min_p: Optional[float] = None
     repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+    stop_token_ids: Optional[list[int]] = []
     include_stop_str_in_output: bool = False
     ignore_eos: bool = False
     min_tokens: int = 0
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 93de9f3a5c05..725e9247c124 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -134,11 +134,9 @@ class RequestProcessingMixin(BaseModel):
     Mixin for request processing,
     handling prompt preparation and engine input.
     """
-    request_prompts: Optional[Sequence[RequestPrompt]] = \
-        Field(default_factory=list)
+    request_prompts: Optional[Sequence[RequestPrompt]] = []
     engine_prompts: Optional[Union[list[EngineTokensPrompt],
-                                   list[EngineEmbedsPrompt]]] = Field(
-                                       default_factory=list)
+                                   list[EngineEmbedsPrompt]]] = []
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -528,12 +526,14 @@ def _validate_input(
         if isinstance(request,
                       (EmbeddingChatRequest, EmbeddingCompletionRequest,
                        ScoreRequest, RerankRequest, ClassificationRequest)):
-            operation = {
-                ScoreRequest: "score",
-                ClassificationRequest: "classification"
-            }.get(type(request), "embedding generation")
             if token_num > self.max_model_len:
+                operations: dict[type[AnyRequest], str] = {
+                    ScoreRequest: "score",
+                    ClassificationRequest: "classification"
+                }
+                operation = operations.get(type(request),
+                                           "embedding generation")
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
diff --git a/vllm/model_executor/guided_decoding/guided_fields.py b/vllm/model_executor/guided_decoding/guided_fields.py
index 085f37a5d516..316860718b77 100644
--- a/vllm/model_executor/guided_decoding/guided_fields.py
+++ b/vllm/model_executor/guided_decoding/guided_fields.py
@@ -3,12 +3,10 @@
 from dataclasses import dataclass
 from typing import Optional, TypedDict, Union
 
-from pydantic import BaseModel
-
 
 # These classes are deprecated, see SamplingParams
 class LLMGuidedOptions(TypedDict, total=False):
-    guided_json: Union[dict, BaseModel, str]
+    guided_json: Union[dict, str]
     guided_regex: str
     guided_choice: list[str]
     guided_grammar: str
@@ -20,7 +18,7 @@ class LLMGuidedOptions(TypedDict, total=False):
 @dataclass
 class GuidedDecodingRequest:
     """One of the fields will be used to retrieve the logit processor."""
-    guided_json: Optional[Union[dict, BaseModel, str]] = None
+    guided_json: Optional[Union[dict, str]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[list[str]] = None
     guided_grammar: Optional[str] = None
diff --git a/vllm/utils.py b/vllm/utils.py
index 86873ff75817..f165c45b2bbf 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1879,14 +1879,6 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
     return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)
 
 
-def is_in_doc_build() -> bool:
-    try:
-        from sphinx.ext.autodoc.mock import _MockModule
-        return isinstance(zmq, _MockModule)
-    except ModuleNotFoundError:
-        return False
-
-
 def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
     """
     Import a Python file according to its file path.
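
Reviewer note: the recurring pattern in this diff is Pydantic dataclasses with lazily-filled fields (`SkipValidation[...]`), a "before" validator that normalizes input (e.g. "GPTQ" -> "gptq"), and `TypeAdapter` for JSON parsing on the CLI path. The following is a minimal, self-contained sketch of that pattern, not vLLM code; `ExampleConfig` and its fields are invented for illustration and only assume Pydantic v2 behaviour.

# Sketch of the Pydantic patterns adopted above (hypothetical ExampleConfig,
# not part of vLLM).
from typing import Any, Optional

from pydantic import (ConfigDict, SkipValidation, TypeAdapter, ValidationError,
                      field_validator)
from pydantic.dataclasses import dataclass


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ExampleConfig:
    # SkipValidation: the None placeholder is filled in later (as the diff
    # does in __post_init__), so Pydantic must not reject it at init time.
    tokenizer: SkipValidation[str] = None  # type: ignore
    quantization: Optional[str] = None

    @field_validator("quantization", mode="before")
    @classmethod
    def _lowercase_quantization(cls, value: Any) -> Any:
        # Mirrors the normalization of "GPTQ" -> "gptq" in ModelConfig.
        return value.lower() if isinstance(value, str) else value


# TypeAdapter provides JSON parsing and validation for the dataclass, which is
# what the new CLI plumbing in arg_utils.py uses instead of json.loads().
adapter = TypeAdapter(ExampleConfig)
cfg = adapter.validate_json('{"quantization": "GPTQ"}')
assert cfg.quantization == "gptq"

try:
    adapter.validate_json('{"quantization": 123}')  # wrong type is rejected
except ValidationError as e:
    print(f"rejected bad input: {e.error_count()} error(s)")

Because the bad value now fails inside Pydantic rather than deep in engine setup, the argparse wrapper in the diff can catch `ValidationError` and surface it as an `ArgumentTypeError` at the command line.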