ENH: Custom model uses vLLM #861

Merged 7 commits on Jan 5, 2024
14 changes: 14 additions & 0 deletions xinference/api/restful_api.py
@@ -160,6 +160,9 @@ def serve(self, logging_conf: Optional[dict] = None):
self._router.add_api_route(
"/v1/models/prompts", self._get_builtin_prompts, methods=["GET"]
)
self._router.add_api_route(
"/v1/models/families", self._get_builtin_families, methods=["GET"]
)
self._router.add_api_route(
"/v1/cluster/devices", self._get_devices_count, methods=["GET"]
)
@@ -312,6 +315,17 @@ async def _get_builtin_prompts(self) -> JSONResponse:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def _get_builtin_families(self) -> JSONResponse:
"""
For internal usage
"""
try:
data = await (await self._get_supervisor_ref()).get_builtin_families()
return JSONResponse(content=data)
except Exception as e:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def _get_devices_count(self) -> JSONResponse:
"""
For internal usage
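The new route can be exercised over plain HTTP once a server is up; a minimal sketch, assuming a local supervisor at the default 127.0.0.1:9997 (adjust the host and port to your deployment):

import requests

# Host and port are assumptions; the route path comes from the diff above.
resp = requests.get("http://127.0.0.1:9997/v1/models/families")
resp.raise_for_status()
families = resp.json()
# The handler returns {"chat": [...], "generate": [...]} per the
# supervisor's get_builtin_families() below.
print(families["chat"])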
4 changes: 4 additions & 0 deletions xinference/client/tests/test_client.py
@@ -187,6 +187,7 @@ def test_client_custom_model(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
@@ -402,6 +403,7 @@ def test_RESTful_client_custom_model(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
@@ -458,6 +460,7 @@ def test_RESTful_client_custom_model(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
@@ -486,6 +489,7 @@ def test_RESTful_client_custom_model(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
12 changes: 12 additions & 0 deletions xinference/core/supervisor.py
@@ -114,6 +114,18 @@ async def get_builtin_prompts() -> Dict[str, Any]:
data[k] = v.dict()
return data

@staticmethod
async def get_builtin_families() -> Dict[str, List[str]]:
from ..model.llm.llm_family import (
BUILTIN_LLM_MODEL_CHAT_FAMILIES,
BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
)

return {
"chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
"generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
}

async def get_devices_count(self) -> int:
from ..utils import cuda_count

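For illustration, the two sets behind this payload are module-level and are filled as a side effect of package initialization (see the _install() changes below); a sketch, assuming importing the package triggers _install():

import xinference.model.llm  # package init runs _install(), populating the sets
from xinference.model.llm import (
    BUILTIN_LLM_MODEL_CHAT_FAMILIES,
    BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
)

# Builtin families whose ability list contains "chat" land in the first
# set; all others land in the generate set.
print(sorted(BUILTIN_LLM_MODEL_CHAT_FAMILIES))
print(sorted(BUILTIN_LLM_MODEL_GENERATE_FAMILIES))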
1 change: 1 addition & 0 deletions xinference/core/tests/test_restful_api.py
@@ -244,6 +244,7 @@ async def test_restful_api(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
1 change: 1 addition & 0 deletions xinference/deploy/test/test_cmdline.py
@@ -159,6 +159,7 @@ def test_cmdline_of_custom_model(setup):
"embed",
"chat"
],
"model_family": "other",
"model_specs": [
{
"model_format": "pytorch",
15 changes: 14 additions & 1 deletion xinference/model/llm/__init__.py
@@ -19,9 +19,12 @@
from .core import LLM
from .llm_family import (
BUILTIN_LLM_FAMILIES,
BUILTIN_LLM_MODEL_CHAT_FAMILIES,
BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
BUILTIN_LLM_PROMPT_STYLE,
BUILTIN_MODELSCOPE_LLM_FAMILIES,
LLM_CLASSES,
CustomLLMFamilyV1,
GgmlLLMSpecV1,
LLMFamilyV1,
LLMSpecV1,
@@ -94,6 +97,11 @@ def _install():
# note that the key is the model name,
# since there are multiple representations of the same prompt style name in json.
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
# register model family
if "chat" in model_spec.model_ability:
BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
else:
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)

modelscope_json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
@@ -110,6 +118,11 @@ def _install():
and model_spec.model_name not in BUILTIN_LLM_PROMPT_STYLE
):
BUILTIN_LLM_PROMPT_STYLE[model_spec.model_name] = model_spec.prompt_style
# register model family
if "chat" in model_spec.model_ability:
BUILTIN_LLM_MODEL_CHAT_FAMILIES.add(model_spec.model_name)
else:
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)

from ...constants import XINFERENCE_MODEL_DIR

@@ -119,5 +132,5 @@ def _install():
with codecs.open(
os.path.join(user_defined_llm_dir, f), encoding="utf-8"
) as fd:
- user_defined_llm_family = LLMFamilyV1.parse_obj(json.load(fd))
+ user_defined_llm_family = CustomLLMFamilyV1.parse_obj(json.load(fd))
register_llm(user_defined_llm_family, persist=False)
40 changes: 38 additions & 2 deletions xinference/model/llm/llm_family.py
@@ -17,7 +17,7 @@
import platform
import shutil
from threading import Lock
- from typing import Any, Dict, List, Optional, Tuple, Type, Union
+ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

from pydantic import BaseModel, Field, Protocol, ValidationError, validator
from pydantic.error_wrappers import ErrorWrapper
@@ -41,6 +41,8 @@

DEFAULT_CONTEXT_LENGTH = 2048
BUILTIN_LLM_PROMPT_STYLE: Dict[str, "PromptStyleV1"] = {}
BUILTIN_LLM_MODEL_CHAT_FAMILIES: Set[str] = set()
BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set()


class GgmlLLMSpecV1(BaseModel):
@@ -105,6 +107,8 @@ class LLMFamilyV1(BaseModel):
model_lang: List[str]
model_ability: List[Literal["embed", "generate", "chat"]]
model_description: Optional[str]
# not a required str here, for backward compatibility with legacy registrations
model_family: Optional[str]
model_specs: List["LLMSpecV1"]
prompt_style: Optional["PromptStyleV1"]

@@ -134,7 +138,39 @@ def parse_raw(
)
except (ValueError, TypeError, UnicodeDecodeError) as e:
raise ValidationError([ErrorWrapper(e, loc=ROOT_KEY)], cls)
- llm_spec = cls.parse_obj(obj)
+ llm_spec: CustomLLMFamilyV1 = cls.parse_obj(obj)

# check model_family
if llm_spec.model_family is None:
raise ValueError(
f"You must specify `model_family` when registering custom LLM models."
)
assert isinstance(llm_spec.model_family, str)
if (
llm_spec.model_family != "other"
and "chat" in llm_spec.model_ability
and llm_spec.model_family not in BUILTIN_LLM_MODEL_CHAT_FAMILIES
):
raise ValueError(
f"`model_family` for chat model must be `other` or one of the following values: \n"
f"{', '.join(list(BUILTIN_LLM_MODEL_CHAT_FAMILIES))}"
)
if (
llm_spec.model_family != "other"
and "chat" not in llm_spec.model_ability
and llm_spec.model_family not in BUILTIN_LLM_MODEL_GENERATE_FAMILIES
):
raise ValueError(
f"`model_family` for generate model must be `other` or one of the following values: \n"
f"{', '.join(list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES))}"
)
# set prompt style when it is the builtin model family
if (
llm_spec.prompt_style is None
and llm_spec.model_family != "other"
and "chat" in llm_spec.model_ability
):
llm_spec.prompt_style = llm_spec.model_family

# handle prompt style when user choose existing style
if llm_spec.prompt_style is not None and isinstance(llm_spec.prompt_style, str):
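Under these rules, a custom chat model can reuse a builtin family and inherit its prompt style; a minimal sketch, assuming the builtin families have already been registered and using an illustrative name, size, and URI:

import json

from xinference.model.llm.llm_family import CustomLLMFamilyV1

custom = {
    "version": 1,
    "model_type": "LLM",
    "model_name": "my-chat-model",  # illustrative
    "model_lang": ["en"],
    "model_ability": ["chat"],
    "model_family": "chatglm3",  # a builtin chat family
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "quantizations": ["none"],
            "model_uri": "file:///path/to/model",  # placeholder
        }
    ],
}
spec = CustomLLMFamilyV1.parse_raw(json.dumps(custom).encode("utf-8"))
# prompt_style was omitted, so parse_raw fills it in from the family name.
assert spec.prompt_style is not None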
33 changes: 31 additions & 2 deletions xinference/model/llm/tests/test_llm_family.py
@@ -147,7 +147,7 @@ def test_serialize_llm_family_v1():
prompt_style=prompt_style,
)

expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_specs": [{"model_format": "ggmlv3", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}"""
expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggmlv3", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}"""
assert json.loads(llm_family.json()) == json.loads(expected)

llm_family_context_length = LLMFamilyV1(
@@ -974,19 +974,48 @@ def test_parse_prompt_style():
model_lang=["en"],
model_ability=["chat", "generate"],
model_specs=[hf_spec, ms_spec],
model_family="chatglm3",
prompt_style="chatglm3",
)
model_spec = CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf8"))
assert model_spec.model_name == llm_family.model_name

- # error
+ # error: missing model_family
llm_family = CustomLLMFamilyV1(
version=1,
model_type="LLM",
model_name="test_LLM",
model_lang=["en"],
model_ability=["chat", "generate"],
model_specs=[hf_spec, ms_spec],
prompt_style="chatglm3",
)
with pytest.raises(ValueError):
CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf8"))

# wrong model_family
llm_family = CustomLLMFamilyV1(
version=1,
model_type="LLM",
model_name="test_LLM",
model_lang=["en"],
model_ability=["chat", "generate"],
model_family="xyzz",
model_specs=[hf_spec, ms_spec],
prompt_style="chatglm3",
)
with pytest.raises(ValueError):
CustomLLMFamilyV1.parse_raw(bytes(llm_family.json(), "utf8"))

# error: wrong prompt style
llm_family = CustomLLMFamilyV1(
version=1,
model_type="LLM",
model_name="test_LLM",
model_lang=["en"],
model_ability=["chat", "generate"],
model_specs=[hf_spec, ms_spec],
model_family="chatglm3",
prompt_style="test_xyz",
)
with pytest.raises(ValueError):
17 changes: 13 additions & 4 deletions xinference/model/llm/vllm/core.py
@@ -37,6 +37,7 @@
CompletionUsage,
)
from .. import LLM, LLMFamilyV1, LLMSpecV1
from ..llm_family import CustomLLMFamilyV1
from ..utils import ChatModelMixin

logger = logging.getLogger(__name__)
@@ -197,8 +198,12 @@ def match(
# Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
if "4" not in quantization:
return False
- if llm_family.model_name not in VLLM_SUPPORTED_MODELS:
-     return False
+ if isinstance(llm_family, CustomLLMFamilyV1):
+     if llm_family.model_family not in VLLM_SUPPORTED_MODELS:
+         return False
+ else:
+     if llm_family.model_name not in VLLM_SUPPORTED_MODELS:
+         return False
if "generate" not in llm_family.model_ability:
return False
return VLLM_INSTALLED
@@ -329,8 +334,12 @@ def match(
# Currently, only 4-bit weight quantization is supported for GPTQ, but got 8 bits.
if "4" not in quantization:
return False
- if llm_family.model_name not in VLLM_SUPPORTED_CHAT_MODELS:
-     return False
+ if isinstance(llm_family, CustomLLMFamilyV1):
+     if llm_family.model_family not in VLLM_SUPPORTED_CHAT_MODELS:
+         return False
+ else:
+     if llm_family.model_name not in VLLM_SUPPORTED_CHAT_MODELS:
+         return False
if "chat" not in llm_family.model_ability:
return False
return VLLM_INSTALLED
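With both match() branches in place, a custom registration that points model_family at a vLLM-supported family is served by the vLLM backend even though its model_name is unknown to the backend; an end-to-end sketch, assuming a running endpoint and that "llama-2-chat" appears in VLLM_SUPPORTED_CHAT_MODELS (name, URI, and size are placeholders):

import json

from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")  # assumed endpoint

model = {
    "version": 1,
    "model_type": "LLM",
    "model_name": "my-llama-finetune",  # custom name, not a builtin
    "model_lang": ["en"],
    "model_ability": ["chat"],
    "model_family": "llama-2-chat",  # assumed vLLM-supported family
    "model_specs": [
        {
            "model_format": "pytorch",
            "model_size_in_billions": 7,
            "quantizations": ["none"],
            "model_uri": "file:///path/to/weights",  # placeholder
        }
    ],
}
client.register_model(model_type="LLM", model=json.dumps(model), persist=False)
# match() now checks model_family rather than model_name for custom models,
# so vLLM is selected when it is installed and the ability list fits.
uid = client.launch_model(model_name="my-llama-finetune", model_format="pytorch")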