Skip to content

Commit

Permalink
ENH: Offload all response serialization to ModelActor (#837)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Jan 3, 2024
1 parent 11f80bb commit 3a38ab5
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 25 deletions.
16 changes: 8 additions & 8 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ async def stream_results():
except RuntimeError as re:
self.handle_request_limit_error(re)
async for item in iterator:
yield dict(data=json.dumps(item))
yield item
except Exception as ex:
if iterator is not None:
await iterator.destroy()
Expand All @@ -577,7 +577,7 @@ async def stream_results():
else:
try:
data = await model.generate(body.prompt, kwargs)
return JSONResponse(content=data)
return Response(data, media_type="application/json")
except Exception as e:
logger.error(e, exc_info=True)
self.handle_request_limit_error(e)
Expand Down Expand Up @@ -634,7 +634,7 @@ async def rerank(self, request: RerankRequest) -> Response:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def create_images(self, request: TextToImageRequest) -> JSONResponse:
async def create_images(self, request: TextToImageRequest) -> Response:
model_uid = request.model
try:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
Expand All @@ -655,7 +655,7 @@ async def create_images(self, request: TextToImageRequest) -> JSONResponse:
response_format=request.response_format,
**kwargs,
)
return JSONResponse(content=image_list)
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
self.handle_request_limit_error(re)
Expand All @@ -674,7 +674,7 @@ async def create_variations(
response_format: Optional[str] = Form("url"),
size: Optional[str] = Form("1024*1024"),
kwargs: Optional[str] = Form(None),
) -> JSONResponse:
) -> Response:
model_uid = model
try:
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
Expand All @@ -697,7 +697,7 @@ async def create_variations(
response_format=response_format,
**kwargs,
)
return JSONResponse(content=image_list)
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
raise HTTPException(status_code=400, detail=str(re))
Expand Down Expand Up @@ -828,7 +828,7 @@ async def stream_results():
except RuntimeError as re:
self.handle_request_limit_error(re)
async for item in iterator:
yield dict(data=json.dumps(item))
yield item
except Exception as ex:
if iterator is not None:
await iterator.destroy()
Expand All @@ -843,7 +843,7 @@ async def stream_results():
data = await model.chat(prompt, chat_history, kwargs)
else:
data = await model.chat(prompt, system_prompt, chat_history, kwargs)
return JSONResponse(content=data)
return Response(content=data, media_type="application/json")
except Exception as e:
logger.error(e, exc_info=True)
self.handle_request_limit_error(e)
Expand Down
86 changes: 78 additions & 8 deletions xinference/client/oscar/actor_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
# limitations under the License.

import asyncio
import re
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union

import orjson
import xoscar as xo

from ...core.model import ModelActor
from ...core.model import IteratorWrapper, ModelActor
from ...core.supervisor import SupervisorActor
from ...isolation import Isolation
from ..restful.restful_client import Client
Expand All @@ -38,6 +40,52 @@
)


class SSEEvent(object):
    """A single Server-Sent Events (SSE) message.

    Holds the four standard SSE fields: ``data``, ``event`` (defaults to
    ``"message"``), ``id`` and ``retry``. Parsing logic adapted from
    https://github.com/btubbs/sseclient/blob/master/sseclient.py
    """

    # Matches one "<name>: <value>" SSE field line; an empty name means the
    # line began with ":" and is a comment.
    sse_line_pattern = re.compile(r"(?P<name>[^:]*):?( ?(?P<value>.*))?")

    def __init__(self, data="", event="message", id=None, retry=None):
        self.data = data      # concatenated payload of all "data:" lines
        self.event = event    # event type, "message" unless overridden
        self.id = id          # last-seen event id (kept as a string)
        self.retry = retry    # reconnection delay in ms, as an int

    @classmethod
    def parse(cls, raw):
        """Parse a possibly multi-line SSE message string into an event.

        Malformed lines and comment lines (leading ``:``) are silently
        skipped; multiple ``data:`` lines are joined with newlines.
        """
        msg = cls()
        for line in raw.splitlines():
            match = cls.sse_line_pattern.match(line)
            if match is None:
                # Malformed line: discard it.
                continue

            field = match.group("name")
            if not field:
                # Line started with ":" — a comment per the SSE spec. Skip.
                continue
            value = match.group("value")

            if field == "data":
                # Join successive data lines with "\n"; first one stands alone.
                msg.data = value if not msg.data else "%s\n%s" % (msg.data, value)
            elif field == "event":
                msg.event = value
            elif field == "id":
                msg.id = value
            elif field == "retry":
                msg.retry = int(value)

        return msg


class ModelHandle:
"""
A sync model interface (for rpc client) which provides type hints that makes it much easier to use xinference
Expand All @@ -49,6 +97,19 @@ def __init__(self, model_ref: xo.ActorRefType["ModelActor"], isolation: Isolatio
self._isolation = isolation


class ClientIteratorWrapper(IteratorWrapper):
    """Client-side stream iterator that decodes SSE-framed bytes into dicts.

    The server side yields serialized SSE chunks; each ``__anext__`` here
    decodes one chunk and returns the JSON payload of its ``data`` field.
    """

    async def __anext__(self):
        raw = await super().__anext__()
        # raw is a bytes SSE chunk; extract the "data" field and deserialize.
        event = SSEEvent.parse(raw.decode("utf-8"))
        return orjson.loads(event.data)

    @classmethod
    def wrap(cls, iterator_wrapper):
        """Re-type an existing IteratorWrapper as this class without re-init."""
        wrapped = cls.__new__(cls)
        vars(wrapped).update(vars(iterator_wrapper))
        return wrapped


class EmbeddingModelHandle(ModelHandle):
def create_embedding(self, input: Union[str, List[str]]) -> bytes:
"""
Expand All @@ -68,7 +129,7 @@ def create_embedding(self, input: Union[str, List[str]]) -> bytes:
"""

coro = self._model_ref.create_embedding(input)
return self._isolation.call(coro)
return orjson.loads(self._isolation.call(coro))


class RerankModelHandle(ModelHandle):
Expand Down Expand Up @@ -104,7 +165,7 @@ def rerank(
coro = self._model_ref.rerank(
documents, query, top_n, max_chunks_per_doc, return_documents
)
results = self._isolation.call(coro)
results = orjson.loads(self._isolation.call(coro))
for r in results["results"]:
r["document"] = documents[r["index"]]
return results
Expand Down Expand Up @@ -140,7 +201,10 @@ def generate(
"""

coro = self._model_ref.generate(prompt, generate_config)
return self._isolation.call(coro)
r = self._isolation.call(coro)
if isinstance(r, bytes):
return orjson.loads(r)
return ClientIteratorWrapper.wrap(r)


class ChatModelHandle(GenerateModelHandle):
Expand Down Expand Up @@ -185,7 +249,10 @@ def chat(
coro = self._model_ref.chat(
prompt, system_prompt, chat_history, generate_config
)
return self._isolation.call(coro)
r = self._isolation.call(coro)
if isinstance(r, bytes):
return orjson.loads(r)
return ClientIteratorWrapper.wrap(r)


class ChatglmCppChatModelHandle(EmbeddingModelHandle):
Expand Down Expand Up @@ -217,7 +284,10 @@ def chat(
"""

coro = self._model_ref.chat(prompt, chat_history, generate_config)
return self._isolation.call(coro)
r = self._isolation.call(coro)
if isinstance(r, bytes):
return orjson.loads(r)
return ClientIteratorWrapper.wrap(r)


class ImageModelHandle(ModelHandle):
Expand Down Expand Up @@ -249,7 +319,7 @@ def text_to_image(
"""

coro = self._model_ref.text_to_image(prompt, n, size, response_format, **kwargs)
return self._isolation.call(coro)
return orjson.loads(self._isolation.call(coro))

def image_to_image(
self,
Expand Down Expand Up @@ -294,7 +364,7 @@ def image_to_image(
coro = self._model_ref.image_to_image(
image, prompt, negative_prompt, n, size, response_format, **kwargs
)
return self._isolation.call(coro)
return orjson.loads(self._isolation.call(coro))


class ActorClient:
Expand Down
4 changes: 1 addition & 3 deletions xinference/client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def test_client(setup):
completion = model.chat("write a poem.", generate_config={"stream": True})
async for chunk in completion:
assert chunk
assert isinstance(chunk, dict)

client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0
Expand All @@ -69,7 +70,6 @@ async def test_client(setup):
model = client.get_model(model_uid=model_uid)

embedding_res = model.create_embedding("The food was delicious and the waiter...")
embedding_res = json.loads(embedding_res)
assert "embedding" in embedding_res["data"][0]

client.terminate_model(model_uid=model_uid)
Expand Down Expand Up @@ -126,7 +126,6 @@ def test_client_for_embedding(setup):
assert isinstance(model, EmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
completion = json.loads(completion)
assert len(completion["data"][0]["embedding"]) == 512

client.terminate_model(model_uid=model_uid)
Expand Down Expand Up @@ -156,7 +155,6 @@ def test_replica_model(setup):
replica_uids.add(model._model_ref.uid)

embedding_res = model.create_embedding("The food was delicious and the waiter...")
embedding_res = json.loads(embedding_res)
assert "embedding" in embedding_res["data"][0]

client2 = RESTfulClient(endpoint)
Expand Down
20 changes: 14 additions & 6 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import inspect
import json
import os
import uuid
from typing import (
Expand All @@ -30,6 +31,7 @@
Union,
)

import sse_starlette.sse
import xoscar as xo

if TYPE_CHECKING:
Expand Down Expand Up @@ -186,7 +188,7 @@ def model_uid(self):
)
)

async def _wrap_generator(self, ret: Any):
def _wrap_generator(self, ret: Any):
if inspect.isgenerator(ret) or inspect.isasyncgen(ret):
if self._lock is not None and self._generators:
raise Exception("Parallel generation is not supported by ggml.")
Expand All @@ -199,7 +201,7 @@ async def _wrap_generator(self, ret: Any):
model_actor_uid=self.uid,
)
else:
return ret
return json_dumps(ret)

async def _call_wrapper(self, _wrapper: Callable):
try:
Expand Down Expand Up @@ -335,9 +337,10 @@ async def text_to_image(
)

def _wrapper():
return getattr(self._model, "text_to_image")(
r = getattr(self._model, "text_to_image")(
prompt, n, size, response_format, *args, **kwargs
)
return json_dumps(r)

return await self._call_wrapper(_wrapper)

Expand All @@ -358,7 +361,7 @@ async def image_to_image(
)

def _wrapper():
return getattr(self._model, "image_to_image")(
r = getattr(self._model, "image_to_image")(
image,
prompt,
negative_prompt,
Expand All @@ -368,6 +371,7 @@ def _wrapper():
*args,
**kwargs,
)
return json_dumps(r)

return await self._call_wrapper(_wrapper)

Expand All @@ -381,14 +385,18 @@ async def next(

def _wrapper():
try:
return next(gen)
v = dict(data=json.dumps(next(gen)))
return sse_starlette.sse.ensure_bytes(v, None)
except StopIteration:
return stop

async def _async_wrapper():
try:
# anext is only available for Python >= 3.10
return await gen.__anext__() # noqa: F821
v = await gen.__anext__()
v = await asyncio.to_thread(json.dumps, v)
v = dict(data=v) # noqa: F821
return await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None)
except StopAsyncIteration:
return stop

Expand Down
2 changes: 2 additions & 0 deletions xinference/model/llm/pytorch/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
import os
import threading
import time
Expand Down Expand Up @@ -146,4 +147,5 @@ async def test_concurrent_pytorch_model(setup):
)
coros.append(co)
r = await asyncio.gather(*coros)
r = [json.loads(i) for i in r]
assert not any(r)

0 comments on commit 3a38ab5

Please sign in to comment.