diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json
index 5b8a9003dd..0b4621bd45 100644
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -1116,6 +1116,17 @@
                 "model_file_name_template": "qwen14b-ggml-{quantization}.bin",
                 "model_revision": "11efca556af372b6f3c730322a4962e9900a2990"
             },
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": "1_8",
+                "quantizations": [
+                    "4-bit",
+                    "8-bit",
+                    "none"
+                ],
+                "model_id": "Qwen/Qwen-1_8B-Chat",
+                "model_revision": "c3db8007171847931da7efa4b2ed4309afcce021"
+            },
             {
                 "model_format": "pytorch",
                 "model_size_in_billions": 7,
diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index 4ec72f4a74..9fcff01180 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -19,7 +19,7 @@
 from threading import Lock
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
-from pydantic import BaseModel, Field, Protocol, ValidationError
+from pydantic import BaseModel, Field, Protocol, ValidationError, validator
 from pydantic.error_wrappers import ErrorWrapper
 from pydantic.parse import load_str_bytes
 from pydantic.types import StrBytes
@@ -45,7 +45,8 @@
 
 class GgmlLLMSpecV1(BaseModel):
     model_format: Literal["ggmlv3", "ggufv2"]
-    model_size_in_billions: int
+    # Note: the order matters here; `str` must come before `int`
+    model_size_in_billions: Union[str, int]
     quantizations: List[str]
     model_id: str
     model_file_name_template: str
@@ -53,16 +54,39 @@
     model_uri: Optional[str]
     model_revision: Optional[str]
 
+    @validator("model_size_in_billions", pre=False)
+    def validate_model_size_with_radix(cls, v: object) -> object:
+        if isinstance(v, str):
+            if (
+                "_" in v
+            ):  # e.g. "1_8" is kept as-is, whereas int("1_8") would yield 18
+                return v
+            else:
+                return int(v)
+        return v
+
 
 class PytorchLLMSpecV1(BaseModel):
     model_format: Literal["pytorch", "gptq"]
-    model_size_in_billions: int
+    # Note: the order matters here; `str` must come before `int`
+    model_size_in_billions: Union[str, int]
     quantizations: List[str]
     model_id: str
     model_hub: str = "huggingface"
     model_uri: Optional[str]
     model_revision: Optional[str]
 
+    @validator("model_size_in_billions", pre=False)
+    def validate_model_size_with_radix(cls, v: object) -> object:
+        if isinstance(v, str):
+            if (
+                "_" in v
+            ):  # e.g. "1_8" is kept as-is, whereas int("1_8") would yield 18
+                return v
+            else:
+                return int(v)
+        return v
+
 
 class PromptStyleV1(BaseModel):
     style_name: str
@@ -152,7 +176,7 @@ def download_from_self_hosted_storage() -> bool:
 def get_legacy_cache_path(
     model_name: str,
     model_format: str,
-    model_size_in_billions: Optional[int] = None,
+    model_size_in_billions: Optional[Union[str, int]] = None,
     quantization: Optional[str] = None,
 ) -> str:
     full_name = f"{model_name}-{model_format}-{model_size_in_billions}b-{quantization}"
diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json
index fd0d3f59ac..9aaab67039 100644
--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -1366,6 +1366,18 @@
                 "model_file_name_template": "qwen14b-ggml-{quantization}.bin",
                 "model_revision": "v0.0.2"
             },
+            {
+                "model_format": "pytorch",
+                "model_size_in_billions": "1_8",
+                "quantizations": [
+                    "4-bit",
+                    "8-bit",
+                    "none"
+                ],
+                "model_hub": "modelscope",
+                "model_id": "qwen/Qwen-1_8B-Chat",
+                "model_revision": "v1.0.0"
+            },
             {
                 "model_format": "pytorch",
                 "model_size_in_billions": 7,
diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js
index bd2bfe7d4a..7a86277635 100644
--- a/xinference/web/ui/src/scenes/launch_model/modelCard.js
+++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js
@@ -91,7 +91,8 @@ const ModelCard = ({ url, modelData, gpuAvailable, is_custom = false }) => {
       .filter(
         (spec) =>
           spec.model_format === modelFormat &&
-          spec.model_size_in_billions === parseFloat(modelSize)
+          spec.model_size_in_billions ===
+            (modelSize.includes('_') ? modelSize : parseFloat(modelSize))
       )
       .flatMap((spec) => spec.quantizations)
     ),