ENH: Offload all response serialization to ModelActor #837

Merged: 7 commits, Jan 3, 2024
xinference/api/restful_api.py (16 changes: 8 additions & 8 deletions)
@@ -565,7 +565,7 @@ async def stream_results():
                 except RuntimeError as re:
                     self.handle_request_limit_error(re)
                 async for item in iterator:
-                    yield dict(data=json.dumps(item))
+                    yield item
             except Exception as ex:
                 if iterator is not None:
                     await iterator.destroy()
@@ -577,7 +577,7 @@ async def stream_results():
         else:
             try:
                 data = await model.generate(body.prompt, kwargs)
-                return JSONResponse(content=data)
+                return Response(data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
                 self.handle_request_limit_error(e)
@@ -634,7 +634,7 @@ async def rerank(self, request: RerankRequest) -> Response:
             logger.error(e, exc_info=True)
             raise HTTPException(status_code=500, detail=str(e))

-    async def create_images(self, request: TextToImageRequest) -> JSONResponse:
+    async def create_images(self, request: TextToImageRequest) -> Response:
         model_uid = request.model
         try:
             model = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -655,7 +655,7 @@ async def create_images(self, request: TextToImageRequest) -> JSONResponse:
                 response_format=request.response_format,
                 **kwargs,
             )
-            return JSONResponse(content=image_list)
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             self.handle_request_limit_error(re)
@@ -674,7 +674,7 @@ async def create_variations(
         response_format: Optional[str] = Form("url"),
         size: Optional[str] = Form("1024*1024"),
         kwargs: Optional[str] = Form(None),
-    ) -> JSONResponse:
+    ) -> Response:
         model_uid = model
         try:
             model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
@@ -697,7 +697,7 @@ async def create_variations(
                 response_format=response_format,
                 **kwargs,
             )
-            return JSONResponse(content=image_list)
+            return Response(content=image_list, media_type="application/json")
         except RuntimeError as re:
             logger.error(re, exc_info=True)
             raise HTTPException(status_code=400, detail=str(re))
@@ -828,7 +828,7 @@ async def stream_results():
                 except RuntimeError as re:
                     self.handle_request_limit_error(re)
                 async for item in iterator:
-                    yield dict(data=json.dumps(item))
+                    yield item
             except Exception as ex:
                 if iterator is not None:
                     await iterator.destroy()
@@ -843,7 +843,7 @@ async def stream_results():
                     data = await model.chat(prompt, chat_history, kwargs)
                 else:
                     data = await model.chat(prompt, system_prompt, chat_history, kwargs)
-                return JSONResponse(content=data)
+                return Response(content=data, media_type="application/json")
             except Exception as e:
                 logger.error(e, exc_info=True)
                 self.handle_request_limit_error(e)
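Every hunk in this file follows the same pattern: the model actor now returns pre-serialized JSON bytes (and, for streaming, pre-framed SSE bytes), so the REST layer passes them through with a raw Response instead of re-encoding with JSONResponse. A minimal sketch of that pattern, with toy endpoints standing in for the actual xinference handlers:

    # Hypothetical endpoints illustrating the before/after of this diff.
    from fastapi import FastAPI, Response
    from fastapi.responses import JSONResponse

    app = FastAPI()

    @app.get("/before")
    async def before() -> JSONResponse:
        data = {"text": "hello"}           # plain dict from the model
        return JSONResponse(content=data)  # serialized again in the API process

    @app.get("/after")
    async def after() -> Response:
        data = b'{"text": "hello"}'        # bytes already serialized by the actor
        return Response(content=data, media_type="application/json")

Skipping the second serialization pass also means the RESTful API and the in-process actor client hand out byte-identical payloads.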
xinference/client/oscar/actor_client.py (86 changes: 78 additions & 8 deletions)
@@ -13,11 +13,13 @@
 # limitations under the License.

 import asyncio
+import re
 from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

+import orjson
 import xoscar as xo

-from ...core.model import ModelActor
+from ...core.model import IteratorWrapper, ModelActor
 from ...core.supervisor import SupervisorActor
 from ...isolation import Isolation
 from ..restful.restful_client import Client
@@ -38,6 +40,52 @@
 )


+class SSEEvent(object):
+    # https://github.com/btubbs/sseclient/blob/master/sseclient.py
+    sse_line_pattern = re.compile("(?P<name>[^:]*):?( ?(?P<value>.*))?")
+
+    def __init__(self, data="", event="message", id=None, retry=None):
+        self.data = data
+        self.event = event
+        self.id = id
+        self.retry = retry
+
+    @classmethod
+    def parse(cls, raw):
+        """
+        Given a possibly-multiline string representing an SSE message, parse it
+        and return a Event object.
+        """
+        msg = cls()
+        for line in raw.splitlines():
+            m = cls.sse_line_pattern.match(line)
+            if m is None:
+                # Malformed line. Discard but warn.
+                continue
+
+            name = m.group("name")
+            if name == "":
+                # line began with a ":", so is a comment. Ignore
+                continue
+            value = m.group("value")
+
+            if name == "data":
+                # If we already have some data, then join to it with a newline.
+                # Else this is it.
+                if msg.data:
+                    msg.data = "%s\n%s" % (msg.data, value)
+                else:
+                    msg.data = value
+            elif name == "event":
+                msg.event = value
+            elif name == "id":
+                msg.id = value
+            elif name == "retry":
+                msg.retry = int(value)
+
+        return msg
+
+
 class ModelHandle:
     """
     A sync model interface (for rpc client) which provides type hints that makes it much easier to use xinference
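The vendored parser above splits a raw SSE message into its named fields and joins multi-line data with newlines. A quick illustration, using made-up frames rather than real model output:

    # Assumes the SSEEvent class defined above; the payloads are placeholders.
    raw = 'event: completion\ndata: {"text": "hel"}\nid: 42'
    evt = SSEEvent.parse(raw)
    assert evt.event == "completion"
    assert evt.id == "42"
    assert evt.data == '{"text": "hel"}'

    # Multiple data lines are joined with a newline, per the SSE spec.
    evt = SSEEvent.parse("data: first\ndata: second")
    assert evt.data == "first\nsecond"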
@@ -49,6 +97,19 @@ def __init__(self, model_ref: xo.ActorRefType["ModelActor"], isolation: Isolatio
         self._isolation = isolation


+class ClientIteratorWrapper(IteratorWrapper):
+    async def __anext__(self):
+        r = await super().__anext__()
+        text = r.decode("utf-8")
+        return orjson.loads(SSEEvent.parse(text).data)
+
+    @classmethod
+    def wrap(cls, iterator_wrapper):
+        c = cls.__new__(cls)
+        c.__dict__.update(iterator_wrapper.__dict__)
+        return c
+
+
 class EmbeddingModelHandle(ModelHandle):
     def create_embedding(self, input: Union[str, List[str]]) -> bytes:
         """
@@ -68,7 +129,7 @@ def create_embedding(self, input: Union[str, List[str]]) -> bytes:
         """

         coro = self._model_ref.create_embedding(input)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))


 class RerankModelHandle(ModelHandle):
@@ -104,7 +165,7 @@ def rerank(
         coro = self._model_ref.rerank(
             documents, query, top_n, max_chunks_per_doc, return_documents
         )
-        results = self._isolation.call(coro)
+        results = orjson.loads(self._isolation.call(coro))
         for r in results["results"]:
             r["document"] = documents[r["index"]]
         return results
@@ -140,7 +201,10 @@ def generate(
         """

         coro = self._model_ref.generate(prompt, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ChatModelHandle(GenerateModelHandle):
@@ -185,7 +249,10 @@ def chat(
         coro = self._model_ref.chat(
             prompt, system_prompt, chat_history, generate_config
         )
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ChatglmCppChatModelHandle(EmbeddingModelHandle):
@@ -217,7 +284,10 @@ def chat(
         """

         coro = self._model_ref.chat(prompt, chat_history, generate_config)
-        return self._isolation.call(coro)
+        r = self._isolation.call(coro)
+        if isinstance(r, bytes):
+            return orjson.loads(r)
+        return ClientIteratorWrapper.wrap(r)


 class ImageModelHandle(ModelHandle):
@@ -249,7 +319,7 @@ def text_to_image(
         """

         coro = self._model_ref.text_to_image(prompt, n, size, response_format, **kwargs)
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))

     def image_to_image(
         self,
@@ -294,7 +364,7 @@ def image_to_image(
         coro = self._model_ref.image_to_image(
             image, prompt, negative_prompt, n, size, response_format, **kwargs
         )
-        return self._isolation.call(coro)
+        return orjson.loads(self._isolation.call(coro))


 class ActorClient:
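After these changes the handles dispatch on the actor's return type: bytes is a complete JSON response to decode, anything else is an IteratorWrapper streaming SSE-framed bytes. A hedged usage sketch (the model handle, prompts, and response shape are placeholders):

    # Assumes `model` is a GenerateModelHandle obtained from the actor client.
    async def consume(model):
        # Non-streaming: the actor returns JSON bytes; the handle decodes them.
        result = model.generate("Hello", generate_config={"stream": False})
        print(result["choices"][0]["text"])

        # Streaming: the handle returns a ClientIteratorWrapper whose
        # __anext__ parses each SSE event back into a dict.
        completion = model.generate("Hello", generate_config={"stream": True})
        async for chunk in completion:
            print(chunk["choices"][0]["text"])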
xinference/client/tests/test_client.py (4 changes: 1 addition & 3 deletions)
@@ -56,6 +56,7 @@ async def test_client(setup):
     completion = model.chat("write a poem.", generate_config={"stream": True})
     async for chunk in completion:
         assert chunk
+        assert isinstance(chunk, dict)

     client.terminate_model(model_uid=model_uid)
     assert len(client.list_models()) == 0
@@ -69,7 +70,6 @@ async def test_client(setup):
     model = client.get_model(model_uid=model_uid)

     embedding_res = model.create_embedding("The food was delicious and the waiter...")
-    embedding_res = json.loads(embedding_res)
     assert "embedding" in embedding_res["data"][0]

     client.terminate_model(model_uid=model_uid)
@@ -126,7 +126,6 @@ def test_client_for_embedding(setup):
     assert isinstance(model, EmbeddingModelHandle)

     completion = model.create_embedding("write a poem.")
-    completion = json.loads(completion)
     assert len(completion["data"][0]["embedding"]) == 512

     client.terminate_model(model_uid=model_uid)
@@ -156,7 +155,6 @@ def test_replica_model(setup):
     replica_uids.add(model._model_ref.uid)

     embedding_res = model.create_embedding("The food was delicious and the waiter...")
-    embedding_res = json.loads(embedding_res)
     assert "embedding" in embedding_res["data"][0]

     client2 = RESTfulClient(endpoint)
xinference/core/model.py (20 changes: 14 additions & 6 deletions)
@@ -14,6 +14,7 @@

 import asyncio
 import inspect
+import json
 import os
 import uuid
 from typing import (
@@ -30,6 +31,7 @@
     Union,
 )

+import sse_starlette.sse
 import xoscar as xo

 if TYPE_CHECKING:
@@ -186,7 +188,7 @@ def model_uid(self):
             )
         )

-    async def _wrap_generator(self, ret: Any):
+    def _wrap_generator(self, ret: Any):
         if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
             if self._lock is not None and self._generators:
                 raise Exception("Parallel generation is not supported by ggml.")
@@ -199,7 +201,7 @@ async def _wrap_generator(self, ret: Any):
                 model_actor_uid=self.uid,
             )
         else:
-            return ret
+            return json_dumps(ret)

     async def _call_wrapper(self, _wrapper: Callable):
         try:
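Non-generator results are now serialized inside the model actor. json_dumps is xinference's own helper and its definition is outside this diff; the sketch below assumes it is a thin wrapper over orjson.dumps:

    # Assumed equivalent of the json_dumps helper used above.
    import orjson

    def json_dumps(obj) -> bytes:
        return orjson.dumps(obj)

    payload = json_dumps({"text": "hello"})
    assert payload == b'{"text":"hello"}'  # compact JSON, as bytes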
@@ -335,9 +337,10 @@ async def text_to_image(
         )

         def _wrapper():
-            return getattr(self._model, "text_to_image")(
+            r = getattr(self._model, "text_to_image")(
                 prompt, n, size, response_format, *args, **kwargs
             )
+            return json_dumps(r)

         return await self._call_wrapper(_wrapper)

@@ -358,7 +361,7 @@ async def image_to_image(
         )

         def _wrapper():
-            return getattr(self._model, "image_to_image")(
+            r = getattr(self._model, "image_to_image")(
                 image,
                 prompt,
                 negative_prompt,
@@ -368,6 +371,7 @@ def _wrapper():
                 *args,
                 **kwargs,
             )
+            return json_dumps(r)

         return await self._call_wrapper(_wrapper)

@@ -381,14 +385,18 @@ async def next(

         def _wrapper():
             try:
-                return next(gen)
+                v = dict(data=json.dumps(next(gen)))
+                return sse_starlette.sse.ensure_bytes(v, None)
             except StopIteration:
                 return stop

         async def _async_wrapper():
             try:
                 # anext is only available for Python >= 3.10
-                return await gen.__anext__()  # noqa: F821
+                v = await gen.__anext__()
+                v = await asyncio.to_thread(json.dumps, v)
+                v = dict(data=v)  # noqa: F821
+                return await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
             except StopAsyncIteration:
                 return stop

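Each chunk is framed as a complete SSE event inside the actor, so the REST API can yield the bytes untouched and the actor client can parse them back with SSEEvent. A small demo of the framing (the exact bytes may vary across sse-starlette versions):

    # The chunk payload is a placeholder; ensure_bytes is the same call the
    # wrappers above use.
    import json
    import sse_starlette.sse

    chunk = {"choices": [{"text": "hel"}]}
    event = dict(data=json.dumps(chunk))
    payload = sse_starlette.sse.ensure_bytes(event, None)
    print(payload)  # e.g. b'data: {"choices": [{"text": "hel"}]}\r\n\r\n'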
xinference/model/llm/pytorch/tests/test_opt.py (2 changes: 2 additions & 0 deletions)

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
+import json
 import os
 import threading
 import time
@@ -146,4 +147,5 @@ async def test_concurrent_pytorch_model(setup):
         )
         coros.append(co)
     r = await asyncio.gather(*coros)
+    r = [json.loads(i) for i in r]
     assert not any(r)