Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Support events #916

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import pprint
import sys
import time
import warnings
from datetime import timedelta
from typing import Any, List, Optional, Union
Expand Down Expand Up @@ -54,6 +55,7 @@
from xoscar.utils import get_next_port

from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
from ..core.event import Event, EventCollectorActor, EventType
from ..core.supervisor import SupervisorActor
from ..core.utils import json_dumps
from ..types import (
Expand Down Expand Up @@ -157,6 +159,7 @@ def __init__(
self._host = host
self._port = port
self._supervisor_ref = None
self._event_collector_ref = None
self._auth_config: AuthStartupConfig = self.init_auth_config(auth_config_file)
self._router = APIRouter()
self._app = FastAPI()
Expand Down Expand Up @@ -189,6 +192,29 @@ async def _get_supervisor_ref(self) -> xo.ActorRefType[SupervisorActor]:
)
return self._supervisor_ref

async def _get_event_collector_ref(self) -> xo.ActorRefType[EventCollectorActor]:
    """Return the event collector actor reference, resolving it on first use.

    The reference is looked up at the supervisor address under
    ``EventCollectorActor.uid()`` and then cached on the instance, so later
    calls return the cached reference without another lookup.
    """
    cached_ref = self._event_collector_ref
    if cached_ref is not None:
        return cached_ref

    resolved_ref = await xo.actor_ref(
        address=self._supervisor_address, uid=EventCollectorActor.uid()
    )
    self._event_collector_ref = resolved_ref
    return resolved_ref

async def _report_error_event(self, model_uid: str, content: str):
    """Record an ERROR event for ``model_uid`` on the event collector.

    Reporting is best-effort: any failure while resolving the collector or
    sending the event is logged and swallowed, so it never masks the error
    the caller is about to surface.

    Args:
        model_uid: UID of the model the error relates to.
        content: Human-readable error text stored as the event content.
    """
    try:
        collector = await self._get_event_collector_ref()
        # Timestamp is taken in whole seconds at reporting time.
        error_event = Event(
            event_type=EventType.ERROR,
            event_ts=int(time.time()),
            event_content=content,
        )
        await collector.report_event(model_uid, error_event)
    except Exception:
        logger.exception(
            "Report error event failed, model: %s, content: %s", model_uid, content
        )

async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse:
user = authenticate_user(
self._auth_config.user_config, form_data.username, form_data.password
Expand Down Expand Up @@ -288,6 +314,14 @@ def serve(self, logging_conf: Optional[dict] = None):
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models/{model_uid}/events",
self.get_model_events,
methods=["GET"],
dependencies=[Security(verify_token, scopes=["models:read"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/models/instance",
self.launch_model_by_version,
Expand Down Expand Up @@ -794,10 +828,12 @@ async def create_completion(
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))

except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

if body.stream:
Expand All @@ -813,6 +849,7 @@ async def stream_results():
yield item
except Exception as ex:
logger.exception("Completion stream got an error: %s", ex)
await self._report_error_event(model_uid, str(ex))
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
yield dict(data=json.dumps({"error": str(ex)}))

Expand All @@ -823,6 +860,7 @@ async def stream_results():
return Response(data, media_type="application/json")
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

Expand All @@ -833,20 +871,24 @@ async def create_embedding(self, request: CreateEmbeddingRequest) -> Response:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
embedding = await model.create_embedding(request.input)
return Response(embedding, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def rerank(self, request: RerankRequest) -> Response:
Expand All @@ -855,9 +897,11 @@ async def rerank(self, request: RerankRequest) -> Response:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
Expand All @@ -871,10 +915,12 @@ async def rerank(self, request: RerankRequest) -> Response:
return Response(scores, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_images(self, request: TextToImageRequest) -> Response:
Expand All @@ -883,9 +929,11 @@ async def create_images(self, request: TextToImageRequest) -> Response:
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
Expand All @@ -900,10 +948,12 @@ async def create_images(self, request: TextToImageRequest) -> Response:
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_variations(
Expand All @@ -922,9 +972,11 @@ async def create_variations(
model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
Expand All @@ -942,9 +994,11 @@ async def create_variations(
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
logger.error(re, exc_info=True)
await self._report_error_event(model_uid, str(re))
raise HTTPException(status_code=400, detail=str(re))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_chat_completion(
Expand Down Expand Up @@ -1012,18 +1066,22 @@ async def create_chat_completion(
model = await (await self._get_supervisor_ref()).get_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

try:
desc = await (await self._get_supervisor_ref()).describe_model(model_uid)
except ValueError as ve:
logger.error(str(ve), exc_info=True)
await self._report_error_event(model_uid, str(ve))
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

model_name = desc.get("model_name", "")
Expand Down Expand Up @@ -1068,11 +1126,13 @@ async def stream_results():
prompt, system_prompt, chat_history, kwargs
)
except RuntimeError as re:
await self._report_error_event(model_uid, str(re))
self.handle_request_limit_error(re)
async for item in iterator:
yield item
except Exception as ex:
logger.exception("Chat completion stream got an error: %s", ex)
await self._report_error_event(model_uid, str(ex))
# https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107
yield dict(data=json.dumps({"error": str(ex)}))

Expand All @@ -1086,6 +1146,7 @@ async def stream_results():
return Response(content=data, media_type="application/json")
except Exception as e:
logger.error(e, exc_info=True)
await self._report_error_event(model_uid, str(e))
self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))

Expand Down Expand Up @@ -1150,6 +1211,18 @@ async def get_model_registrations(
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

async def get_model_events(self, model_uid: str) -> JSONResponse:
    """Return the events recorded for ``model_uid`` as a JSON response.

    Raises:
        HTTPException: with status 400 when a ``ValueError`` is raised while
            fetching the events, and status 500 for any other exception.
    """
    try:
        collector = await self._get_event_collector_ref()
        model_events = await collector.get_model_events(model_uid)
        return JSONResponse(content=model_events)
    except Exception as exc:
        logger.error(exc, exc_info=True)
        # ValueError is reported as a client error; everything else as a server error.
        status_code = 400 if isinstance(exc, ValueError) else 500
        raise HTTPException(status_code=status_code, detail=str(exc))


def run(
supervisor_address: str,
Expand Down
56 changes: 56 additions & 0 deletions xinference/core/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2022-2024 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import queue
from collections import defaultdict
from enum import Enum
from typing import Dict, List, TypedDict

import xoscar as xo

MAX_EVENT_COUNT_PER_MODEL = 100


class EventType(Enum):
    """Kind of an event recorded for a model."""

    # Integer values are explicit; members are exposed by name when events
    # are listed.
    INFO = 1
    WARNING = 2
    ERROR = 3


class Event(TypedDict):
    """A single event entry attached to a model."""

    # Kind of the event (see EventType).
    event_type: EventType
    # Integer timestamp of the event.
    event_ts: int
    # Free-form text describing the event.
    event_content: str


class EventCollectorActor(xo.StatelessActor):
    """Actor that keeps a bounded, per-model history of events.

    Each model UID maps to a queue holding at most
    ``MAX_EVENT_COUNT_PER_MODEL`` events. When a queue is full, the oldest
    event is discarded to make room for the new one.
    """

    def __init__(self):
        super().__init__()
        # A queue is created lazily the first time an event is reported for
        # a given model UID.
        self._model_uid_to_events: Dict[str, queue.Queue] = defaultdict(
            lambda: queue.Queue(maxsize=MAX_EVENT_COUNT_PER_MODEL)
        )

    @classmethod
    def uid(cls) -> str:
        """Return the fixed actor UID used to create and look up this actor."""
        return "event_collector"

    def get_model_events(self, model_uid: str) -> List[Dict]:
        """Return the events recorded for ``model_uid``, oldest first.

        The ``event_type`` of each returned dict is replaced by the enum
        member's name (e.g. ``"ERROR"``). An unknown ``model_uid`` yields an
        empty list; ``.get`` is used so no queue is created for it.
        """
        event_queue = self._model_uid_to_events.get(model_uid)
        if event_queue is None:
            return []
        return [dict(e, event_type=e["event_type"].name) for e in event_queue.queue]

    def report_event(self, model_uid: str, event: Event):
        """Append ``event`` to the history of ``model_uid`` without blocking.

        ``queue.Queue.put`` blocks while the queue is full, and nothing ever
        consumes from these queues, so a blocking put would hang on the first
        event past the capacity. Instead, the oldest event is evicted first.
        """
        event_queue = self._model_uid_to_events[model_uid]
        if event_queue.full():
            try:
                event_queue.get_nowait()
            except queue.Empty:
                # Nothing left to evict; fall through to the insert.
                pass
        event_queue.put_nowait(event)
8 changes: 8 additions & 0 deletions xinference/core/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ async def __post_create__(self):
CacheTrackerActor, address=self.address, uid=CacheTrackerActor.uid()
)

from .event import EventCollectorActor

self._event_collector_ref: xo.ActorRefType[
EventCollectorActor
] = await xo.create_actor(
EventCollectorActor, address=self.address, uid=EventCollectorActor.uid()
)

from ..model.embedding import (
CustomEmbeddingModelSpec,
generate_embedding_description,
Expand Down
34 changes: 34 additions & 0 deletions xinference/core/tests/test_restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,6 +1069,40 @@ def test_launch_model_async(setup):
assert len(response.json()) == 0


def test_events(setup):
    """Launching and then terminating a model each record one event."""
    endpoint, _ = setup

    launch_resp = requests.post(
        f"{endpoint}/v1/models",
        json={
            "model_uid": "test_orca",
            "model_name": "orca",
            "quantization": "q4_0",
        },
    )
    assert launch_resp.json()["model_uid"] == "test_orca"

    events_url = f"{endpoint}/v1/models/test_orca/events"

    # After launch, e.g.:
    # [{'event_type': 'INFO', 'event_ts': 1705896156, 'event_content': 'Launch model'}]
    events_after_launch = requests.get(events_url).json()
    assert len(events_after_launch) == 1
    assert "Launch" in events_after_launch[0]["event_content"]

    # Terminate the model; its event history should still be retrievable.
    requests.delete(f"{endpoint}/v1/models/test_orca")

    # After termination, e.g.:
    # [{'event_type': 'INFO', 'event_ts': 1705896215, 'event_content': 'Launch model'},
    #  {'event_type': 'INFO', 'event_ts': 1705896215, 'event_content': 'Terminate model'}]
    events_after_delete = requests.get(events_url).json()
    assert len(events_after_delete) == 2
    assert "Terminate" in events_after_delete[1]["event_content"]


def test_launch_model_by_version(setup):
from ...model.llm import get_llm_model_descriptions

Expand Down
Loading
Loading