FEAT: incorporate vLLM
UranusSeven committed Sep 13, 2023
1 parent 14364aa commit bc3441e
Showing 9 changed files with 437 additions and 59 deletions.
2 changes: 2 additions & 0 deletions setup.cfg
@@ -88,6 +88,8 @@ pytorch =
protobuf
einops
tiktoken
vllm =
vllm
embedding =
sentence-transformers
doc =
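The new optional dependency group means vLLM support can be pulled in at install time (e.g. pip install "xinference[vllm]", assuming the package is published under the name xinference), while the default install stays unchanged.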
57 changes: 44 additions & 13 deletions xinference/core/model.py
@@ -13,7 +13,17 @@
# limitations under the License.

import inspect
from typing import TYPE_CHECKING, Any, Generic, Iterator, List, Optional, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Generic,
Iterator,
List,
Optional,
TypeVar,
Union,
)

import xoscar as xo

@@ -62,9 +72,10 @@ def gen_uid(cls, model: "LLM"):
async def __pre_destroy__(self):
from ..model.embedding.core import EmbeddingModel
from ..model.llm.pytorch.core import PytorchModel as LLMPytorchModel
from ..model.llm.vllm.core import VLLMModel as LLMVLLMModel

if (
isinstance(self._model, LLMPytorchModel)
isinstance(self._model, (LLMPytorchModel, LLMVLLMModel))
and self._model.model_spec.model_format == "pytorch"
) or isinstance(self._model, EmbeddingModel):
try:
@@ -86,13 +97,13 @@ async def __pre_destroy__(self):
def __init__(self, model: "LLM"):
super().__init__()
self._model = model
self._generator: Optional[Iterator] = None
self._generator: Optional[Union[Iterator, AsyncGenerator]] = None

def load(self):
self._model.load()

async def _wrap_generator(self, ret: Any):
if inspect.isgenerator(ret):
if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
self._generator = ret
return IteratorWrapper(
model_actor_addr=self.address, model_actor_uid=self.uid
@@ -101,20 +112,32 @@ async def _wrap_generator(self, ret: Any):
return ret

async def generate(self, prompt: str, *args, **kwargs):
if not hasattr(self._model, "generate"):
if not hasattr(self._model, "generate") and not hasattr(
self._model, "async_generate"
):
raise AttributeError(f"Model {self._model.model_spec} is not for generate.")

return self._wrap_generator(
getattr(self._model, "generate")(prompt, *args, **kwargs)
)
if hasattr(self._model, "generate"):
return self._wrap_generator(
getattr(self._model, "generate")(prompt, *args, **kwargs)
)
else:
return self._wrap_generator(
await getattr(self._model, "async_generate")(prompt, *args, **kwargs)
)

async def chat(self, prompt: str, *args, **kwargs):
if not hasattr(self._model, "chat"):
if not hasattr(self._model, "chat") and not hasattr(self._model, "async_chat"):
raise AttributeError(f"Model {self._model.model_spec} is not for chat.")

return self._wrap_generator(
getattr(self._model, "chat")(prompt, *args, **kwargs)
)
if hasattr(self._model, "chat"):
return self._wrap_generator(
getattr(self._model, "chat")(prompt, *args, **kwargs)
)
else:
return self._wrap_generator(
await getattr(self._model, "async_chat")(prompt, *args, **kwargs)
)

async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
if not hasattr(self._model, "create_embedding"):
@@ -127,7 +150,15 @@ async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
async def next(self) -> Union["ChatCompletionChunk", "CompletionChunk"]:
try:
assert self._generator is not None
return next(self._generator)
if inspect.isgenerator(self._generator):
ret = next(self._generator)
elif inspect.isasyncgen(self._generator):
ret = await anext(self._generator)
else:
raise TypeError(
f"Unexpected type {type(self._generator)}, expecting generator or async generator"
)
return ret
except StopIteration:
self._generator = None
raise Exception("StopIteration")
2 changes: 2 additions & 0 deletions xinference/model/llm/__init__.py
@@ -43,6 +43,7 @@ def _install():
from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel
from .pytorch.llama_2 import LlamaPytorchChatModel, LlamaPytorchModel
from .pytorch.vicuna import VicunaPytorchChatModel
from .vllm.core import VLLMChatModel, VLLMModel

# register llm classes.
LLM_CLASSES.extend(
@@ -61,6 +62,7 @@ def _install():
CtransformersModel,
]
)
LLM_CLASSES.extend([VLLMModel, VLLMChatModel])
LLM_CLASSES.extend(
[
BaichuanPytorchChatModel,
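Appending VLLMModel and VLLMChatModel to LLM_CLASSES reuses the existing registry pattern: every backend class goes into one list, and model creation later walks that list for a class that accepts the requested spec. A generic sketch of the idea with hypothetical match() signatures (not xinference's actual selection API):

    class DummyPytorchModel:
        @classmethod
        def match(cls, model_format: str) -> bool:
            return model_format == "pytorch"

    class DummyVLLMModel:
        @classmethod
        def match(cls, model_format: str) -> bool:
            # A real check would also consider GPU availability and model family.
            return model_format == "pytorch"

    LLM_CLASSES = [DummyVLLMModel, DummyPytorchModel]  # registration order sets priority

    def select_cls(model_format: str):
        for model_cls in LLM_CLASSES:
            if model_cls.match(model_format):
                return model_cls
        raise ValueError(f"No backend for format {model_format!r}")

    print(select_cls("pytorch").__name__)  # DummyVLLMModel
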
4 changes: 2 additions & 2 deletions xinference/model/llm/ggml/llamacpp.py
@@ -299,8 +299,8 @@ def chat(
if stream:
it = self.generate(full_prompt, generate_config)
assert isinstance(it, Iterator)
return self._convert_chat_completion_chunks_to_chat(it)
return self._to_chat_completion_chunks(it)
else:
c = self.generate(full_prompt, generate_config)
assert not isinstance(c, Iterator)
return self._convert_text_completion_to_chat(c)
return self._to_chat_completion(c)
10 changes: 5 additions & 5 deletions xinference/model/llm/llm_family.json
@@ -703,8 +703,8 @@
"8-bit",
"none"
],
"model_id": "meta-llama/Llama-2-7b",
"model_revision": "365ffa8f1a6c455d3e2028ae658236b4b85ba824"
"model_id": "meta-llama/Llama-2-7b-hf",
"model_revision": "6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9"
},
{
"model_format": "pytorch",
@@ -714,8 +714,8 @@
"8-bit",
"none"
],
"model_id": "meta-llama/Llama-2-13b",
"model_revision": "bd5c881755fa1f0518506e60207229b13b0f67e1"
"model_id": "meta-llama/Llama-2-13b-hf",
"model_revision": "db6b8eb1feabb38985fdf785a89895959e944936"
},
{
"model_format": "pytorch",
@@ -726,7 +726,7 @@
"none"
],
"model_id": "meta-llama/Llama-2-70b",
"model_revision": "fce501427e806d830acbd5e0a697a7924dc49278"
"model_revision": "cc8aa03a000ff08b4d5c5b39673321a2a396c396"
}
]
},
4 changes: 2 additions & 2 deletions xinference/model/llm/pytorch/core.py
@@ -478,8 +478,8 @@ def chat(
if stream:
it = self.generate(full_prompt, generate_config)
assert isinstance(it, Iterator)
return self._convert_chat_completion_chunks_to_chat(it)
return self._to_chat_completion_chunks(it)
else:
c = self.generate(full_prompt, generate_config)
assert not isinstance(c, Iterator)
return self._convert_text_completion_to_chat(c)
return self._to_chat_completion(c)
100 changes: 63 additions & 37 deletions xinference/model/llm/utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterator, List
from typing import AsyncGenerator, Iterator, List

from xinference.model.llm.llm_family import PromptStyleV1

@@ -177,59 +177,85 @@ def get_prompt(
else:
raise ValueError(f"Invalid prompt style: {prompt_style.style_name}")

@staticmethod
def _convert_chat_completion_chunks_to_chat(
@classmethod
def _to_chat_completion_chunk(cls, chunk: CompletionChunk) -> ChatCompletionChunk:
return {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": i,
"delta": {
"content": choice["text"],
},
"finish_reason": choice["finish_reason"],
}
for i, choice in enumerate(chunk["choices"])
],
}

@classmethod
def _get_first_chat_completion_chunk(
cls, chunk: CompletionChunk
) -> ChatCompletionChunk:
return {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": i,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
for i, choice in enumerate(chunk["choices"])
],
}

@classmethod
def _to_chat_completion_chunks(
cls,
chunks: Iterator[CompletionChunk],
) -> Iterator[ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
],
}
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"content": chunk["choices"][0]["text"],
},
"finish_reason": chunk["choices"][0]["finish_reason"],
}
],
}
yield cls._get_first_chat_completion_chunk(chunk)
yield cls._to_chat_completion_chunk(chunk)

@classmethod
async def _async_to_chat_completion_chunks(
cls,
chunks: AsyncGenerator[CompletionChunk, None],
) -> AsyncGenerator[ChatCompletionChunk, None]:
i = 0
async for chunk in chunks:
if i == 0:
yield cls._get_first_chat_completion_chunk(chunk)
yield cls._to_chat_completion_chunk(chunk)
i += 1

@staticmethod
def _convert_text_completion_to_chat(completion: Completion) -> ChatCompletion:
def _to_chat_completion(completion: Completion) -> ChatCompletion:
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"index": i,
"message": {
"role": "assistant",
"content": completion["choices"][0]["text"],
"content": choice["text"],
},
"finish_reason": completion["choices"][0]["finish_reason"],
"finish_reason": choice["finish_reason"],
}
for i, choice in enumerate(completion["choices"])
],
"usage": completion["usage"],
}
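_async_to_chat_completion_chunks exists so the vLLM backend can stream its completion chunks through the same chat-format conversion as the sync backends. A rough usage sketch, assuming the helpers above live on a ChatModelMixin class (the class name is not visible in this hunk) and using a hand-built dict as a stand-in CompletionChunk:

    import asyncio

    from xinference.model.llm.utils import ChatModelMixin  # assumed class name

    async def fake_completion_chunks():
        # Stand-in for what a streaming generate() would yield.
        yield {
            "id": "cmpl-0",
            "model": "llama-2-chat",
            "created": 0,
            "object": "text_completion",
            "choices": [{"index": 0, "text": "Hello", "finish_reason": None}],
        }

    async def show():
        converted = ChatModelMixin._async_to_chat_completion_chunks(fake_completion_chunks())
        async for chat_chunk in converted:
            # First yields the assistant role delta, then the content delta for the same chunk.
            print(chat_chunk["object"], chat_chunk["choices"][0]["delta"])

    asyncio.run(show())
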
13 changes: 13 additions & 0 deletions xinference/model/llm/vllm/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
