diff --git a/.github/workflows/python.yaml b/.github/workflows/python.yaml index f0648629d0..cea6a8d6a5 100644 --- a/.github/workflows/python.yaml +++ b/.github/workflows/python.yaml @@ -93,6 +93,7 @@ jobs: pip install transformers_stream_generator pip install bitsandbytes pip install ctransformers + pip install sentence-transformers pip install -e ".[dev]" working-directory: . diff --git a/examples/LangChain_QA.ipynb b/examples/LangChain_QA.ipynb index 82cb06f118..84cd991781 100644 --- a/examples/LangChain_QA.ipynb +++ b/examples/LangChain_QA.ipynb @@ -404,4 +404,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/setup.cfg b/setup.cfg index 4d9355a849..20f79bf937 100644 --- a/setup.cfg +++ b/setup.cfg @@ -73,6 +73,7 @@ all = protobuf einops tiktoken + sentence-transformers ggml = llama-cpp-python==0.1.77 ctransformers @@ -86,6 +87,8 @@ pytorch = protobuf einops tiktoken +embedding = + sentence-transformers doc = ipython>=6.5.0 sphinx>=3.0.0,<5.0.0 diff --git a/xinference/client.py b/xinference/client.py index 049f3c9cef..aece990ebc 100644 --- a/xinference/client.py +++ b/xinference/client.py @@ -49,7 +49,29 @@ def __init__(self, model_ref: xo.ActorRefType["ModelActor"], isolation: Isolatio self._isolation = isolation -class GenerateModelHandle(ModelHandle): +class EmbeddingModelHandle(ModelHandle): + def create_embedding(self, input: Union[str, List[str]]) -> "Embedding": + """ + Creates an embedding vector representing the input text. + + Parameters + ---------- + input: Union[str, List[str]] + Input text to embed, encoded as a string or array of tokens. + To embed multiple inputs in a single request, pass an array of strings or array of token arrays. + + Returns + ------- + Embedding + The resulted Embedding vector that can be easily consumed by machine learning models and algorithms. + + """ + + coro = self._model_ref.create_embedding(input) + return self._isolation.call(coro) + + +class GenerateModelHandle(EmbeddingModelHandle): def generate( self, prompt: str, @@ -81,26 +103,6 @@ def generate( coro = self._model_ref.generate(prompt, generate_config) return self._isolation.call(coro) - def create_embedding(self, input: Union[str, List[str]]) -> "Embedding": - """ - Creates an embedding vector representing the input text. - - Parameters - ---------- - input: Union[str, List[str]] - Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array of strings or array of token arrays. - - Returns - ------- - Embedding - The resulted Embedding vector that can be easily consumed by machine learning models and algorithms. - - """ - - coro = self._model_ref.create_embedding(input) - return self._isolation.call(coro) - class ChatModelHandle(GenerateModelHandle): def chat( @@ -147,7 +149,7 @@ def chat( return self._isolation.call(coro) -class ChatglmCppChatModelHandle(ModelHandle): +class ChatglmCppChatModelHandle(EmbeddingModelHandle): def chat( self, prompt: str, @@ -241,7 +243,41 @@ def __init__(self, model_uid: str, base_url: str): self._base_url = base_url -class RESTfulGenerateModelHandle(RESTfulModelHandle): +class RESTfulEmbeddingModelHandle(RESTfulModelHandle): + def create_embedding(self, input: Union[str, List[str]]) -> "Embedding": + """ + Create an Embedding from user input via RESTful APIs. + + Parameters + ---------- + input: Union[str, List[str]] + Input text to embed, encoded as a string or array of tokens. 
+ To embed multiple inputs in a single request, pass an array of strings or array of token arrays. + + Returns + ------- + Embedding + The resulted Embedding vector that can be easily consumed by machine learning models and algorithms. + + Raises + ------ + RuntimeError + Report the failure of embeddings and provide the error message. + + """ + url = f"{self._base_url}/v1/embeddings" + request_body = {"model": self._model_uid, "input": input} + response = requests.post(url, json=request_body) + if response.status_code != 200: + raise RuntimeError( + f"Failed to create the embeddings, detail: {response.json()['detail']}" + ) + + response_data = response.json() + return response_data + + +class RESTfulGenerateModelHandle(RESTfulEmbeddingModelHandle): def generate( self, prompt: str, @@ -296,38 +332,6 @@ def generate( response_data = response.json() return response_data - def create_embedding(self, input: Union[str, List[str]]) -> "Embedding": - """ - Create an Embedding from user input via RESTful APIs. - - Parameters - ---------- - input: Union[str, List[str]] - Input text to embed, encoded as a string or array of tokens. - To embed multiple inputs in a single request, pass an array of strings or array of token arrays. - - Returns - ------- - Embedding - The resulted Embedding vector that can be easily consumed by machine learning models and algorithms. - - Raises - ------ - RuntimeError - Report the failure of embeddings and provide the error message. - - """ - url = f"{self._base_url}/v1/embeddings" - request_body = {"model": self._model_uid, "input": input} - response = requests.post(url, json=request_body) - if response.status_code != 200: - raise RuntimeError( - f"Failed to create the embeddings, detail: {response.json()['detail']}" - ) - - response_data = response.json() - return response_data - class RESTfulChatModelHandle(RESTfulGenerateModelHandle): def chat( @@ -407,7 +411,7 @@ def chat( return response_data -class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle): +class RESTfulChatglmCppChatModelHandle(RESTfulEmbeddingModelHandle): def chat( self, prompt: str, @@ -556,6 +560,7 @@ def get_model_registration( def launch_model( self, model_name: str, + model_type: str = "LLM", model_size_in_billions: Optional[int] = None, model_format: Optional[str] = None, quantization: Optional[str] = None, @@ -568,6 +573,8 @@ def launch_model( ---------- model_name: str The name of model. + model_type: str + Type of model. model_size_in_billions: Optional[int] The size (in billions) of the model. 
model_format: Optional[str] @@ -589,6 +596,7 @@ def launch_model( coro = self._supervisor_ref.launch_builtin_model( model_uid=model_uid, model_name=model_name, + model_type=model_type, model_size_in_billions=model_size_in_billions, model_format=model_format, quantization=quantization, @@ -648,15 +656,19 @@ def get_model(self, model_uid: str) -> "ModelHandle": self._supervisor_ref.describe_model(model_uid) ) model_ref = self._isolation.call(self._supervisor_ref.get_model(model_uid)) - - if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]: - return ChatglmCppChatModelHandle(model_ref, self._isolation) - elif "chat" in desc["model_ability"]: - return ChatModelHandle(model_ref, self._isolation) - elif "generate" in desc["model_ability"]: - return GenerateModelHandle(model_ref, self._isolation) + if desc["model_type"] == "LLM": + if desc["model_format"] == "ggmlv3" and "chatglm" in desc["model_name"]: + return ChatglmCppChatModelHandle(model_ref, self._isolation) + elif "chat" in desc["model_ability"]: + return ChatModelHandle(model_ref, self._isolation) + elif "generate" in desc["model_ability"]: + return GenerateModelHandle(model_ref, self._isolation) + else: + raise ValueError(f"Unrecognized model ability: {desc['model_ability']}") + elif desc["model_type"] == "embedding": + return EmbeddingModelHandle(model_ref, self._isolation) else: - raise ValueError(f"Unrecognized model ability: {desc['model_ability']}") + raise ValueError(f"Unknown model type:{desc['model_type']}") class RESTfulClient: @@ -693,6 +705,7 @@ def list_models(self) -> Dict[str, Dict[str, Any]]: def launch_model( self, model_name: str, + model_type: str = "LLM", model_size_in_billions: Optional[int] = None, model_format: Optional[str] = None, quantization: Optional[str] = None, @@ -705,6 +718,8 @@ def launch_model( ---------- model_name: str The name of model. + model_type: str + type of model. model_size_in_billions: Optional[int] The size (in billions) of the model. model_format: Optional[str] @@ -728,6 +743,7 @@ def launch_model( payload = { "model_uid": model_uid, "model_name": model_name, + "model_type": model_type, "model_size_in_billions": model_size_in_billions, "model_format": model_format, "quantization": quantization, diff --git a/xinference/core/api.py b/xinference/core/api.py deleted file mode 100644 index 69c52df507..0000000000 --- a/xinference/core/api.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -from typing import Any, Dict, Optional - -import xoscar as xo - -from ..isolation import Isolation -from . 
import ModelActor -from .supervisor import SupervisorActor - - -class AsyncSupervisorAPI: - def __init__(self, supervisor_address: str): - self._supervisor_address = supervisor_address - self._supervisor_ref = None - - async def _get_supervisor_ref(self) -> xo.ActorRefType["SupervisorActor"]: - if self._supervisor_ref is None: - self._supervisor_ref = await xo.actor_ref( - address=self._supervisor_address, uid=SupervisorActor.uid() - ) - return self._supervisor_ref - - async def launch_model( - self, - model_uid: str, - model_name: str, - model_size_in_billions: Optional[int] = None, - model_format: Optional[str] = None, - quantization: Optional[str] = None, - **kwargs, - ) -> str: - supervisor_ref = await self._get_supervisor_ref() - await supervisor_ref.launch_builtin_model( - model_uid=model_uid, - model_name=model_name, - model_size_in_billions=model_size_in_billions, - model_format=model_format, - quantization=quantization, - **kwargs, - ) - return model_uid - - async def terminate_model(self, model_uid: str): - supervisor_ref = await self._get_supervisor_ref() - await supervisor_ref.terminate_model(model_uid) - - async def list_models(self) -> Dict[str, Dict[str, Any]]: - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.list_models() - - async def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]: - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.get_model(model_uid) - - async def is_local_deployment(self) -> bool: - # TODO: temporary. - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.is_local_deployment() - - -class SyncSupervisorAPI: - def __init__(self, supervisor_address: str): - self._supervisor_address = supervisor_address - self._supervisor_ref = None - self._isolation = Isolation(asyncio.new_event_loop(), threaded=True) - self._isolation.start() - - async def _get_supervisor_ref(self) -> xo.ActorRefType["SupervisorActor"]: - if self._supervisor_ref is None: - self._supervisor_ref = await xo.actor_ref( - address=self._supervisor_address, uid=SupervisorActor.uid() - ) - return self._supervisor_ref - - def launch_model( - self, - model_uid: str, - model_name: str, - model_size_in_billions: Optional[int] = None, - model_format: Optional[str] = None, - quantization: Optional[str] = None, - **kwargs, - ) -> str: - async def _launch_model(): - supervisor_ref = await self._get_supervisor_ref() - await supervisor_ref.launch_builtin_model( - model_uid=model_uid, - model_name=model_name, - model_size_in_billions=model_size_in_billions, - model_format=model_format, - quantization=quantization, - **kwargs, - ) - return model_uid - - return self._isolation.call(_launch_model()) - - def terminate_model(self, model_uid: str): - async def _terminate_model(): - supervisor_ref = await self._get_supervisor_ref() - await supervisor_ref.terminate_model(model_uid) - - return self._isolation.call(_terminate_model()) - - def list_models(self) -> Dict[str, Dict[str, Any]]: - async def _list_models(): - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.list_models() - - return self._isolation.call(_list_models()) - - def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]: - async def _get_model(): - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.get_model(model_uid) - - return self._isolation.call(_get_model()) - - def is_local_deployment(self) -> bool: - # TODO: temporary. 
- async def _is_local_deployment(): - supervisor_ref = await self._get_supervisor_ref() - return await supervisor_ref.is_local_deployment() - - return self._isolation.call(_is_local_deployment()) diff --git a/xinference/core/model.py b/xinference/core/model.py index 64240f12cb..b01a0564a5 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -60,7 +60,13 @@ def gen_uid(cls, model: "LLM"): return f"{model.__class__}-model-actor" async def __pre_destroy__(self): - if self._model.model_spec.model_format == "pytorch": + from ..model.embedding.core import EmbeddingModel + from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel + + if ( + isinstance(self._model, LLMPytorchModel) + and self._model.model_spec.model_format == "pytorch" + ) or isinstance(self._model, EmbeddingModel): try: import gc diff --git a/xinference/core/restful_api.py b/xinference/core/restful_api.py index aa81456e45..007d7700f1 100644 --- a/xinference/core/restful_api.py +++ b/xinference/core/restful_api.py @@ -379,6 +379,7 @@ async def launch_model(self, request: Request) -> JSONResponse: model_size_in_billions = payload.get("model_size_in_billions") model_format = payload.get("model_format") quantization = payload.get("quantization") + model_type = payload.get("model_type") exclude_keys = { "model_uid", @@ -386,6 +387,7 @@ async def launch_model(self, request: Request) -> JSONResponse: "model_size_in_billions", "model_format", "quantization", + "model_type", } kwargs = { @@ -405,6 +407,7 @@ async def launch_model(self, request: Request) -> JSONResponse: model_size_in_billions=model_size_in_billions, model_format=model_format, quantization=quantization, + model_type=model_type, **kwargs, ) @@ -553,10 +556,8 @@ async def create_embedding(self, request: CreateEmbeddingRequest): logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) - input = request.input - try: - embedding = await model.create_embedding(input) + embedding = await model.create_embedding(request.input) return embedding except RuntimeError as re: logger.error(re, exc_info=True) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 105280cd0a..d9b57d4257 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -144,6 +144,7 @@ async def launch_builtin_model( model_size_in_billions: Optional[int], model_format: Optional[str], quantization: Optional[str], + model_type: Optional[str], **kwargs, ) -> xo.ActorRefType["ModelActor"]: logger.debug( @@ -162,12 +163,15 @@ async def launch_builtin_model( raise ValueError(f"Model is already in the model list, uid: {model_uid}") worker_ref = await self._choose_worker() + # LLM as default for compatibility + model_type = model_type or "LLM" model_ref = yield worker_ref.launch_builtin_model( model_uid=model_uid, model_name=model_name, model_size_in_billions=model_size_in_billions, model_format=model_format, quantization=quantization, + model_type=model_type, **kwargs, ) # TODO: not protected. diff --git a/xinference/core/tests/test_api.py b/xinference/core/tests/test_api.py deleted file mode 100644 index 79ff1f7664..0000000000 --- a/xinference/core/tests/test_api.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2022-2023 XProbe Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING - -import pytest -import xoscar as xo - -from xinference.core.api import AsyncSupervisorAPI, SyncSupervisorAPI - -if TYPE_CHECKING: - from xinference.core import ModelActor - - -@pytest.mark.asyncio -async def test_async_client(setup): - _, supervisor_address = setup - async_client = AsyncSupervisorAPI(supervisor_address) - assert len(await async_client.list_models()) == 0 - - model_uid = await async_client.launch_model( - model_uid="test_async_client", model_name="orca", quantization="q4_0" - ) - assert len(await async_client.list_models()) == 1 - - model_ref: xo.ActorRefType["ModelActor"] = await async_client.get_model( - model_uid=model_uid - ) - - completion = await model_ref.chat("write a poem.") - assert "content" in completion["choices"][0]["message"] - - await async_client.terminate_model(model_uid=model_uid) - assert len(await async_client.list_models()) == 0 - - -@pytest.mark.asyncio -async def test_sync_client(setup): - _, supervisor_address = setup - client = SyncSupervisorAPI(supervisor_address) - assert len(client.list_models()) == 0 - - model_uid = client.launch_model( - model_uid="test_sync_client", model_name="orca", quantization="q4_0" - ) - assert len(client.list_models()) == 1 - - model_ref: xo.ActorRefType["ModelActor"] = client.get_model(model_uid=model_uid) - - completion = await model_ref.chat("write a poem.") - assert "content" in completion["choices"][0]["message"] - - client.terminate_model(model_uid=model_uid) - assert len(client.list_models()) == 0 diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 1a781b13f6..b921441d39 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest import requests +from ...model.embedding import BUILTIN_EMBEDDING_MODELS -@pytest.mark.asyncio -async def test_restful_api(setup): + +def test_restful_api(setup): endpoint, _ = setup url = f"{endpoint}/v1/models" @@ -264,3 +264,70 @@ async def test_restful_api(setup): if model_reg["model_name"] == "custom_model": custom_model_reg = model_reg assert custom_model_reg is None + + +def test_restful_api_for_embedding(setup): + model_name = "gte-base" + model_spec = BUILTIN_EMBEDDING_MODELS[model_name] + + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + # list + response = requests.get(url) + response_data = response.json() + assert len(response_data) == 0 + + # launch + payload = { + "model_uid": "test_embedding", + "model_name": model_name, + "model_type": "embedding", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = response_data["model_uid"] + assert model_uid_res == "test_embedding" + + response = requests.get(url) + response_data = response.json() + assert len(response_data) == 1 + + # test embedding + url = f"{endpoint}/v1/embeddings" + payload = { + "model": "test_embedding", + "input": "The food was delicious and the waiter...", + } + response = requests.post(url, json=payload) + embedding_res = response.json() + + assert "embedding" in embedding_res["data"][0] + assert len(embedding_res["data"][0]["embedding"]) == model_spec.dimensions + + # test multiple + payload = { + "model": "test_embedding", + "input": [ + "The food was delicious and the waiter...", + "how to implement quick sort in python?", + "Beijing", + "sorting algorithms", + ], + } + response = requests.post(url, json=payload) + embedding_res = response.json() + + assert len(embedding_res["data"]) == 4 + for data in embedding_res["data"]: + assert len(data["embedding"]) == model_spec.dimensions + + # delete model + url = f"{endpoint}/v1/models/test_embedding" + response = requests.delete(url) + assert response.status_code == 200 + + response = requests.get(f"{endpoint}/v1/models") + response_data = response.json() + assert len(response_data) == 0 diff --git a/xinference/core/worker.py b/xinference/core/worker.py index 001c31aeef..cd759034f9 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -15,12 +15,12 @@ import asyncio import platform from logging import getLogger -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set import xoscar as xo from ..core import ModelActor -from ..model.llm import LLMFamilyV1, LLMSpecV1 +from ..model.core import ModelDescription, create_model_instance from .resource import gather_node_info from .utils import log_async, log_sync @@ -36,9 +36,7 @@ def __init__(self, supervisor_address: str, subpool_addresses: List[str]): self._supervisor_address = supervisor_address self._supervisor_ref = None self._model_uid_to_model: Dict[str, xo.ActorRefType["ModelActor"]] = {} - self._model_uid_to_model_spec: Dict[ - str, Tuple[LLMFamilyV1, LLMSpecV1, str] - ] = {} + self._model_uid_to_model_spec: Dict[str, ModelDescription] = {} self._subpool_address_to_model_uids: Dict[str, Set[str]] = dict( [(subpool_address, set()) for subpool_address in subpool_addresses] ) @@ -96,23 +94,6 @@ def _check_model_is_valid(self, model_name): if model_name in ["baichuan-base", "baichuan-chat"]: raise ValueError(f"{model_name} model can't run on Darwin system.") - @staticmethod - def _to_llm_description( - llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str - ) 
-> Dict[str, Any]: - return { - "model_type": "LLM", - "model_name": llm_family.model_name, - "model_lang": llm_family.model_lang, - "model_ability": llm_family.model_ability, - "model_description": llm_family.model_description, - "model_format": llm_spec.model_format, - "model_size_in_billions": llm_spec.model_size_in_billions, - "quantization": quantization, - "revision": llm_spec.model_revision, - "context_length": llm_family.context_length, - } - @log_sync(logger=logger) async def register_model(self, model_type: str, model: str, persist: bool): # TODO: centralized model registrations @@ -142,51 +123,32 @@ async def launch_builtin_model( model_size_in_billions: Optional[int], model_format: Optional[str], quantization: Optional[str], + model_type: str = "LLM", **kwargs, ) -> xo.ActorRefType["ModelActor"]: assert model_uid not in self._model_uid_to_model self._check_model_is_valid(model_name) - - from ..model.llm import match_llm, match_llm_cls - assert self._supervisor_ref is not None - match_result = match_llm( + is_local_deployment = await self._supervisor_ref.is_local_deployment() + + model, model_description = create_model_instance( + model_uid, + model_type, model_name, model_format, model_size_in_billions, quantization, - await self._supervisor_ref.is_local_deployment(), + is_local_deployment, + **kwargs, ) - if not match_result: - raise ValueError( - f"Model not found, name: {model_name}, format: {model_format}," - f" size: {model_size_in_billions}, quantization: {quantization}" - ) - llm_family, llm_spec, quantization = match_result - assert quantization is not None - from ..model.llm.llm_family import cache - - save_path = await asyncio.to_thread(cache, llm_family, llm_spec, quantization) - - llm_cls = match_llm_cls(llm_family, llm_spec) - logger.debug(f"Launching {model_uid} with {llm_cls.__name__}") - if not llm_cls: - raise ValueError( - f"Model not supported, name: {model_name}, format: {model_format}," - f" size: {model_size_in_billions}, quantization: {quantization}" - ) - - model = llm_cls( - model_uid, llm_family, llm_spec, quantization, save_path, kwargs - ) subpool_address = self._choose_subpool() model_ref = await xo.create_actor( ModelActor, address=subpool_address, uid=model_uid, model=model ) await model_ref.load() self._model_uid_to_model[model_uid] = model_ref - self._model_uid_to_model_spec[model_uid] = (llm_family, llm_spec, quantization) + self._model_uid_to_model_spec[model_uid] = model_description self._subpool_address_to_model_uids[subpool_address].add(model_uid) return model_ref @@ -208,7 +170,7 @@ async def terminate_model(self, model_uid: str): def list_models(self) -> Dict[str, Dict[str, Any]]: ret = {} for k, v in self._model_uid_to_model_spec.items(): - ret[k] = self._to_llm_description(v[0], v[1], v[2]) + ret[k] = v.to_dict() return ret @log_sync(logger=logger) @@ -223,8 +185,7 @@ def describe_model(self, model_uid: str) -> Dict[str, Any]: if model_uid not in self._model_uid_to_model: raise ValueError(f"Model not found in the model list, uid: {model_uid}") - llm_family, llm_spec, quantization = self._model_uid_to_model_spec[model_uid] - return self._to_llm_description(llm_family, llm_spec, quantization) + return self._model_uid_to_model_spec[model_uid].to_dict() async def report_status(self): status = await asyncio.to_thread(gather_node_info) diff --git a/xinference/deploy/cmdline.py b/xinference/deploy/cmdline.py index 2bb6aa8b43..f0ae048040 100644 --- a/xinference/deploy/cmdline.py +++ b/xinference/deploy/cmdline.py @@ -342,6 +342,13 @@ def 
list_model_registrations( required=True, help="Provide the name of the model to be launched.", ) +@click.option( + "--model-type", + "-t", + type=str, + default="LLM", + help="Specify type of model, LLM as default.", +) @click.option( "--size-in-billions", "-s", @@ -366,6 +373,7 @@ def list_model_registrations( def model_launch( endpoint: Optional[str], model_name: str, + model_type: str, size_in_billions: int, model_format: str, quantization: str, @@ -375,6 +383,7 @@ def model_launch( client = RESTfulClient(base_url=endpoint) model_uid = client.launch_model( model_name=model_name, + model_type=model_type, model_size_in_billions=size_in_billions, model_format=model_format, quantization=quantization, diff --git a/xinference/model/core.py b/xinference/model/core.py index 39bf55a7a0..db982de232 100644 --- a/xinference/model/core.py +++ b/xinference/model/core.py @@ -12,4 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. -# TODO: define a data structure for LLMs, Speech recognition models, and etc. +from abc import ABC +from typing import Any, Optional, Tuple + + +class ModelDescription(ABC): + def to_dict(self): + """ + Return a dict to describe some information about model. + :return: + """ + + +def create_model_instance( + model_uid: str, + model_type: str, + model_name: str, + model_format: Optional[str] = None, + model_size_in_billions: Optional[int] = None, + quantization: Optional[str] = None, + is_local_deployment: bool = False, + **kwargs, +) -> Tuple[Any, ModelDescription]: + from .embedding.core import create_embedding_model_instance + from .llm.core import create_llm_model_instance + + if model_type == "LLM": + return create_llm_model_instance( + model_uid, + model_name, + model_format, + model_size_in_billions, + quantization, + is_local_deployment, + **kwargs, + ) + elif model_type == "embedding": + return create_embedding_model_instance(model_uid, model_name, **kwargs) + else: + raise ValueError(f"Unsupported model type: {model_type}.") diff --git a/xinference/model/embedding/__init__.py b/xinference/model/embedding/__init__.py new file mode 100644 index 0000000000..a636717c09 --- /dev/null +++ b/xinference/model/embedding/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import codecs +import json +import os + +from .core import EmbeddingModelSpec + +_model_spec_json = os.path.join(os.path.dirname(__file__), "model_spec.json") +BUILTIN_EMBEDDING_MODELS = dict( + (spec["model_name"], EmbeddingModelSpec(**spec)) + for spec in json.load(codecs.open(_model_spec_json, "r", encoding="utf-8")) +) +del _model_spec_json diff --git a/xinference/model/embedding/core.py b/xinference/model/embedding/core.py new file mode 100644 index 0000000000..eb939b8cd4 --- /dev/null +++ b/xinference/model/embedding/core.py @@ -0,0 +1,284 @@ +# Copyright 2022-2023 XProbe Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import List, Optional, Tuple, Union, no_type_check + +import numpy as np +from pydantic import BaseModel + +from ...constants import XINFERENCE_CACHE_DIR +from ...types import Embedding, EmbeddingData, EmbeddingUsage +from ..core import ModelDescription + +MAX_ATTEMPTS = 3 + +logger = logging.getLogger(__name__) + + +class EmbeddingModelSpec(BaseModel): + model_name: str + dimensions: int + max_tokens: int + language: List[str] + model_id: str + model_revision: str + + +def cache(model_spec: EmbeddingModelSpec): + # TODO: cache from uri + import huggingface_hub + + cache_dir = os.path.realpath( + os.path.join(XINFERENCE_CACHE_DIR, model_spec.model_name) + ) + if not os.path.exists(cache_dir): + os.makedirs(cache_dir, exist_ok=True) + for current_attempt in range(1, MAX_ATTEMPTS + 1): + try: + huggingface_hub.snapshot_download( + model_spec.model_id, + revision=model_spec.model_revision, + local_dir=cache_dir, + local_dir_use_symlinks=True, + ) + break + except huggingface_hub.utils.LocalEntryNotFoundError: + remaining_attempts = MAX_ATTEMPTS - current_attempt + logger.warning( + f"Attempt {current_attempt} failed. Remaining attempts: {remaining_attempts}" + ) + else: + raise RuntimeError( + f"Failed to download model '{model_spec.model_name}' after {MAX_ATTEMPTS} attempts" + ) + return cache_dir + + +class EmbeddingModel: + def __init__(self, model_uid: str, model_path: str, device: Optional[str] = None): + self._model_uid = model_uid + self._model_path = model_path + self._device = device + self._model = None + + def load(self): + try: + from sentence_transformers import SentenceTransformer + except ImportError: + error_message = "Failed to import module 'SentenceTransformer'" + installation_guide = [ + "Please make sure 'sentence-transformers' is installed. ", + "You can install it by `pip install sentence-transformers`\n", + ] + + raise ImportError(f"{error_message}\n\n{''.join(installation_guide)}") + self._model = SentenceTransformer(self._model_path, device=self._device) + + def create_embedding(self, sentences: Union[str, List[str]], **kwargs): + from sentence_transformers import SentenceTransformer + + normalize_embeddings = kwargs.pop("normalize_embeddings", True) + + # copied from sentence-transformers, and modify it to return tokens num + @no_type_check + def encode( + model: SentenceTransformer, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = None, + output_value: str = "sentence_embedding", + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, + ): + """ + Computes sentence embeddings + + :param sentences: the sentences to embed + :param batch_size: the batch size used for the computation + :param show_progress_bar: Output a progress bar when encode sentences + :param output_value: Default sentence_embedding, to get sentence embeddings. 
Can be set to token_embeddings to get wordpiece token embeddings. Set to None, to get all output values + :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. + :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy + :param device: Which torch.device to use for the computation + :param normalize_embeddings: If set to true, returned vectors will have length 1. In that case, the faster dot-product (util.dot_score) instead of cosine similarity can be used. + + :return: + By default, a list of tensors is returned. If convert_to_tensor, a stacked tensor is returned. If convert_to_numpy, a numpy matrix is returned. + """ + import torch + from sentence_transformers.util import batch_to_device + from tqdm.autonotebook import trange + + model.eval() + if show_progress_bar is None: + show_progress_bar = ( + logger.getEffectiveLevel() == logging.INFO + or logger.getEffectiveLevel() == logging.DEBUG + ) + + if convert_to_tensor: + convert_to_numpy = False + + if output_value != "sentence_embedding": + convert_to_tensor = False + convert_to_numpy = False + + input_was_string = False + if isinstance(sentences, str) or not hasattr( + sentences, "__len__" + ): # Cast an individual sentence to a list with length 1 + sentences = [sentences] + input_was_string = True + + if device is None: + device = model._target_device + + model.to(device) + + all_embeddings = [] + all_token_nums = 0 + length_sorted_idx = np.argsort( + [-model._text_length(sen) for sen in sentences] + ) + sentences_sorted = [sentences[idx] for idx in length_sorted_idx] + + for start_index in trange( + 0, + len(sentences), + batch_size, + desc="Batches", + disable=not show_progress_bar, + ): + sentences_batch = sentences_sorted[ + start_index : start_index + batch_size + ] + features = model.tokenize(sentences_batch) + features = batch_to_device(features, device) + all_token_nums += sum([len(f) for f in features]) + + with torch.no_grad(): + out_features = model.forward(features) + + if output_value == "token_embeddings": + embeddings = [] + for token_emb, attention in zip( + out_features[output_value], out_features["attention_mask"] + ): + last_mask_id = len(attention) - 1 + while ( + last_mask_id > 0 and attention[last_mask_id].item() == 0 + ): + last_mask_id -= 1 + + embeddings.append(token_emb[0 : last_mask_id + 1]) + elif output_value is None: # Return all outputs + embeddings = [] + for sent_idx in range(len(out_features["sentence_embedding"])): + row = { + name: out_features[name][sent_idx] + for name in out_features + } + embeddings.append(row) + else: # Sentence embeddings + embeddings = out_features[output_value] + embeddings = embeddings.detach() + if normalize_embeddings: + embeddings = torch.nn.functional.normalize( + embeddings, p=2, dim=1 + ) + + # fixes for #522 and #487 to avoid oom problems on gpu with large datasets + if convert_to_numpy: + embeddings = embeddings.cpu() + + all_embeddings.extend(embeddings) + + all_embeddings = [ + all_embeddings[idx] for idx in np.argsort(length_sorted_idx) + ] + + if convert_to_tensor: + all_embeddings = torch.stack(all_embeddings) + elif convert_to_numpy: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + + if input_was_string: + all_embeddings = all_embeddings[0] + + return all_embeddings, all_token_nums + + all_embeddings, all_token_nums = encode( + self._model, + sentences, + convert_to_numpy=False, + 
normalize_embeddings=normalize_embeddings, + **kwargs, + ) + if isinstance(sentences, str): + all_embeddings = [all_embeddings] + embedding_list = [] + for index, data in enumerate(all_embeddings): + embedding_list.append( + EmbeddingData(index=index, object="embedding", embedding=data.tolist()) + ) + usage = EmbeddingUsage( + prompt_tokens=all_token_nums, total_tokens=all_token_nums + ) + return Embedding( + object="list", + model=self._model_uid, + data=embedding_list, + usage=usage, + ) + + +class EmbeddingModelDescription(ModelDescription): + def __init__(self, model_spec: EmbeddingModelSpec): + self._model_spec = model_spec + + def to_dict(self): + return { + "model_type": "embedding", + "model_name": self._model_spec.model_name, + "dimensions": self._model_spec.dimensions, + "max_tokens": self._model_spec.max_tokens, + "language": self._model_spec.language, + "model_revision": self._model_spec.model_revision, + } + + +def match_embedding(model_name: str) -> EmbeddingModelSpec: + from . import BUILTIN_EMBEDDING_MODELS + + if model_name in BUILTIN_EMBEDDING_MODELS: + return BUILTIN_EMBEDDING_MODELS[model_name] + else: + raise ValueError( + f"Embedding model {model_name} not found, available" + f"model list: {BUILTIN_EMBEDDING_MODELS.keys()}" + ) + + +def create_embedding_model_instance( + model_uid: str, model_name: str, **kwargs +) -> Tuple[EmbeddingModel, EmbeddingModelDescription]: + model_spec = match_embedding(model_name) + model_path = cache(model_spec) + model = EmbeddingModel(model_uid, model_path, **kwargs) + model_description = EmbeddingModelDescription(model_spec) + return model, model_description diff --git a/xinference/model/embedding/model_spec.json b/xinference/model/embedding/model_spec.json new file mode 100644 index 0000000000..7d2e5f1e91 --- /dev/null +++ b/xinference/model/embedding/model_spec.json @@ -0,0 +1,82 @@ +[ + { + "model_name": "bge-large-en", + "dimensions": 1024, + "max_tokens": 512, + "language": ["en"], + "model_id": "BAAI/bge-large-en", + "model_revision": "d57a0d82f0d0884de76bbce093f201364d9b720e" + }, + { + "model_name": "bge-base-en", + "dimensions": 768, + "max_tokens": 512, + "language": ["en"], + "model_id": "BAAI/bge-base-en", + "model_revision": "90e113f4f9cd0c83220c873b94ca7bc37f85de97" + }, + { + "model_name": "gte-large", + "dimensions": 1024, + "max_tokens": 512, + "language": ["en"], + "model_id": "thenlper/gte-large", + "model_revision": "2b5163b62ed28492dc70eb19a882b71c81dbc7c8" + }, + { + "model_name": "gte-base", + "dimensions": 768, + "max_tokens": 512, + "language": ["en"], + "model_id": "thenlper/gte-base", + "model_revision": "792749e8178be77f479b26788a0e1adb4ec9c8a9" + }, + { + "model_name": "e5-large-v2", + "dimensions": 1024, + "max_tokens": 512, + "language": ["en"], + "model_id": "intfloat/e5-large-v2", + "model_revision": "b322e09026e4ea05f42beadf4d661fb4e101d311" + }, + { + "model_name": "bge-large-zh", + "dimensions": 1024, + "max_tokens": 512, + "language": ["zh"], + "model_id": "BAAI/bge-large-zh", + "model_revision": "1b543b301eb63dd32914b56d939db2a972df15d5" + }, + { + "model_name": "bge-large-zh-noinstruct", + "dimensions": 1024, + "max_tokens": 512, + "language": ["zh"], + "model_id": "BAAI/bge-large-zh-noinstruct", + "model_revision": "d971248454d6267756fab9caa431c2c2fc5f0f35" + }, + { + "model_name": "bge-base-zh", + "dimensions": 768, + "max_tokens": 512, + "language": ["zh"], + "model_id": "BAAI/bge-base-zh", + "model_revision": "faefe5952238b3d28bc35d1c8fe63eb269d8cee0" + }, + { + "model_name": 
"multilingual-e5-large", + "dimensions": 1024, + "max_tokens": 514, + "language": ["zh"], + "model_id": "intfloat/multilingual-e5-large", + "model_revision": "c505dce3578a12ec54e47bdc72bef5cd0eacb085" + }, + { + "model_name": "bge-small-zh", + "dimensions": 512, + "max_tokens": 512, + "language": ["zh"], + "model_id": "BAAI/bge-small-zh", + "model_revision": "52185a5f4aa5bb1fe80c0671b7303161880a2d79" + } +] diff --git a/xinference/model/embedding/tests/__init__.py b/xinference/model/embedding/tests/__init__.py new file mode 100644 index 0000000000..37f6558d95 --- /dev/null +++ b/xinference/model/embedding/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/xinference/model/embedding/tests/test_embedding_models.py b/xinference/model/embedding/tests/test_embedding_models.py new file mode 100644 index 0000000000..532439e61c --- /dev/null +++ b/xinference/model/embedding/tests/test_embedding_models.py @@ -0,0 +1,49 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..core import EmbeddingModel, EmbeddingModelSpec, cache + +TEST_MODEL_SPEC = EmbeddingModelSpec( + model_name="gte-small", + dimensions=384, + max_tokens=512, + language=["en"], + model_id="thenlper/gte-small", + model_revision="d8e2604cadbeeda029847d19759d219e0ce2e6d8", +) + + +def test_model(): + model_path = cache(TEST_MODEL_SPEC) + model = EmbeddingModel("mock", model_path) + # input is a string + input_text = "what is the capital of China?" 
+    model.load()
+    r = model.create_embedding(input_text)
+    assert len(r["data"]) == 1
+    for d in r["data"]:
+        assert len(d["embedding"]) == 384
+
+    # input is a list
+    input_texts = [
+        "what is the capital of China?",
+        "how to implement quick sort in python?",
+        "Beijing",
+        "sorting algorithms",
+    ]
+    model.load()
+    r = model.create_embedding(input_texts)
+    assert len(r["data"]) == 4
+    for d in r["data"]:
+        assert len(d["embedding"]) == 384
diff --git a/xinference/model/llm/core.py b/xinference/model/llm/core.py
index 5e46843295..dba78f8e13 100644
--- a/xinference/model/llm/core.py
+++ b/xinference/model/llm/core.py
@@ -16,7 +16,9 @@
 import logging
 import platform
 from abc import abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from ..core import ModelDescription
 
 if TYPE_CHECKING:
     from .llm_family import LLMFamilyV1, LLMSpecV1
@@ -60,3 +62,67 @@ def load(self):
     @classmethod
     def match(cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1") -> bool:
         raise NotImplementedError
+
+
+class LLMModelDescription(ModelDescription):
+    def __init__(
+        self, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+    ):
+        self._llm_family = llm_family
+        self._llm_spec = llm_spec
+        self._quantization = quantization
+
+    def to_dict(self):
+        return {
+            "model_type": "LLM",
+            "model_name": self._llm_family.model_name,
+            "model_lang": self._llm_family.model_lang,
+            "model_ability": self._llm_family.model_ability,
+            "model_description": self._llm_family.model_description,
+            "model_format": self._llm_spec.model_format,
+            "model_size_in_billions": self._llm_spec.model_size_in_billions,
+            "quantization": self._quantization,
+            "revision": self._llm_spec.model_revision,
+            "context_length": self._llm_family.context_length,
+        }
+
+
+def create_llm_model_instance(
+    model_uid: str,
+    model_name: str,
+    model_format: Optional[str] = None,
+    model_size_in_billions: Optional[int] = None,
+    quantization: Optional[str] = None,
+    is_local_deployment: bool = False,
+    **kwargs,
+) -> Tuple[LLM, LLMModelDescription]:
+    from . import match_llm, match_llm_cls
+    from .llm_family import cache
+
+    match_result = match_llm(
+        model_name,
+        model_format,
+        model_size_in_billions,
+        quantization,
+        is_local_deployment,
+    )
+    if not match_result:
+        raise ValueError(
+            f"Model not found, name: {model_name}, format: {model_format},"
+            f" size: {model_size_in_billions}, quantization: {quantization}"
+        )
+    llm_family, llm_spec, quantization = match_result
+
+    assert quantization is not None
+    save_path = cache(llm_family, llm_spec, quantization)
+
+    llm_cls = match_llm_cls(llm_family, llm_spec)
+    if not llm_cls:
+        raise ValueError(
+            f"Model not supported, name: {model_name}, format: {model_format},"
+            f" size: {model_size_in_billions}, quantization: {quantization}"
+        )
+    logger.debug(f"Launching {model_uid} with {llm_cls.__name__}")
+
+    model = llm_cls(model_uid, llm_family, llm_spec, quantization, save_path, kwargs)
+    return model, LLMModelDescription(llm_family, llm_spec, quantization)
diff --git a/xinference/tests/test_client.py b/xinference/tests/test_client.py
index a1dc958d4e..3e18c604b5 100644
--- a/xinference/tests/test_client.py
+++ b/xinference/tests/test_client.py
@@ -14,7 +14,13 @@
 
 import pytest
 
-from ..client import ChatModelHandle, Client, RESTfulChatModelHandle, RESTfulClient
+from ..client import (
+    ChatModelHandle,
+    Client,
+    EmbeddingModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulClient,
+)
 
 
 def test_client(setup):
@@ -51,6 +57,24 @@ def test_client(setup):
     assert len(client.list_models()) == 0
 
+
+def test_client_for_embedding(setup):
+    endpoint, _ = setup
+    client = Client(endpoint)
+    assert len(client.list_models()) == 0
+
+    model_uid = client.launch_model(model_name="gte-base", model_type="embedding")
+    assert len(client.list_models()) == 1
+
+    model = client.get_model(model_uid=model_uid)
+    assert isinstance(model, EmbeddingModelHandle)
+
+    completion = model.create_embedding("write a poem.")
+    assert len(completion["data"][0]["embedding"]) == 768
+
+    client.terminate_model(model_uid=model_uid)
+    assert len(client.list_models()) == 0
+
 
 def test_client_custom_model(setup):
     endpoint, _ = setup
     client = Client(endpoint)
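For context, a minimal usage sketch of the embedding workflow this change enables, mirroring test_client_for_embedding above. It is illustrative only: the supervisor address is a placeholder for an already-running local xinference deployment, and it assumes the optional sentence-transformers dependency (added to setup.cfg above) is installed.

from xinference.client import Client

# Placeholder address of a running xinference supervisor; adjust to your deployment.
client = Client("127.0.0.1:9997")

# model_type="embedding" routes the launch through the new embedding code path;
# omitting it keeps the previous behaviour, since model_type defaults to "LLM".
# "gte-base" is one of the built-in specs in xinference/model/embedding/model_spec.json.
model_uid = client.launch_model(model_name="gte-base", model_type="embedding")

# For embedding models, get_model() now returns an EmbeddingModelHandle.
model = client.get_model(model_uid=model_uid)
result = model.create_embedding("The food was delicious and the waiter...")
print(len(result["data"][0]["embedding"]))  # 768, the dimensionality of gte-base

client.terminate_model(model_uid=model_uid)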