From 59fb002cd50f4d7651a99b6dbbba40cd5750eaac Mon Sep 17 00:00:00 2001 From: codingl2k1 Date: Mon, 22 Jan 2024 12:08:26 +0800 Subject: [PATCH 1/3] Implement events --- xinference/api/restful_api.py | 29 ++++++++++++ xinference/core/event.py | 56 +++++++++++++++++++++++ xinference/core/supervisor.py | 8 ++++ xinference/core/tests/test_restful_api.py | 34 ++++++++++++++ xinference/core/worker.py | 34 ++++++++++++++ 5 files changed, 161 insertions(+) create mode 100644 xinference/core/event.py diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index accf17eb98..45bf91eec4 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -54,6 +54,7 @@ from xoscar.utils import get_next_port from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT +from ..core.event import EventCollectorActor from ..core.supervisor import SupervisorActor from ..core.utils import json_dumps from ..types import ( @@ -157,6 +158,7 @@ def __init__( self._host = host self._port = port self._supervisor_ref = None + self._event_collector_ref = None self._auth_config: AuthStartupConfig = self.init_auth_config(auth_config_file) self._router = APIRouter() self._app = FastAPI() @@ -189,6 +191,13 @@ async def _get_supervisor_ref(self) -> xo.ActorRefType[SupervisorActor]: ) return self._supervisor_ref + async def _get_event_collector_ref(self) -> xo.ActorRefType[EventCollectorActor]: + if self._event_collector_ref is None: + self._event_collector_ref = await xo.actor_ref( + address=self._supervisor_address, uid=EventCollectorActor.uid() + ) + return self._event_collector_ref + async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse: user = authenticate_user( self._auth_config.user_config, form_data.username, form_data.password @@ -280,6 +289,14 @@ def serve(self, logging_conf: Optional[dict] = None): if self.is_authenticated() else None, ) + self._router.add_api_route( + "/v1/models/{model_uid}/events", + self.get_model_events, + methods=["GET"], + dependencies=[Security(verify_token, scopes=["models:events"])] + if self.is_authenticated() + else None, + ) self._router.add_api_route( "/v1/models", self.launch_model, @@ -1096,6 +1113,18 @@ async def get_model_registrations( logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def get_model_events(self, model_uid: str) -> JSONResponse: + try: + event_collector_ref = await self._get_event_collector_ref() + events = await event_collector_ref.get_model_events(model_uid) + return JSONResponse(content=events) + except ValueError as re: + logger.error(re, exc_info=True) + raise HTTPException(status_code=400, detail=str(re)) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + def run( supervisor_address: str, diff --git a/xinference/core/event.py b/xinference/core/event.py new file mode 100644 index 0000000000..7361c179b2 --- /dev/null +++ b/xinference/core/event.py @@ -0,0 +1,56 @@ +# Copyright 2022-2024 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +from collections import defaultdict +from enum import Enum +from typing import Dict, List, TypedDict + +import xoscar as xo + +MAX_EVENT_COUNT_PER_MODEL = 100 + + +class EventType(Enum): + INFO = 1 + WARNING = 2 + ERROR = 3 + + +class Event(TypedDict): + event_type: EventType + event_ts: int + event_content: str + + +class EventCollectorActor(xo.StatelessActor): + def __init__(self): + super().__init__() + self._model_uid_to_events: Dict[str, queue.Queue] = defaultdict( + lambda: queue.Queue(maxsize=MAX_EVENT_COUNT_PER_MODEL) + ) + + @classmethod + def uid(cls) -> str: + return "event_collector" + + def get_model_events(self, model_uid: str) -> List[Dict]: + event_queue = self._model_uid_to_events.get(model_uid) + if event_queue is None: + return [] + else: + return [dict(e, event_type=e["event_type"].name) for e in event_queue.queue] + + def report_event(self, model_uid: str, event: Event): + self._model_uid_to_events[model_uid].put(event) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 2fce1496dc..3f2271309e 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -97,6 +97,14 @@ async def __post_create__(self): StatusGuardActor, address=self.address, uid=StatusGuardActor.uid() ) + from .event import EventCollectorActor + + self._event_collector_ref: xo.ActorRefType[ + EventCollectorActor + ] = await xo.create_actor( + EventCollectorActor, address=self.address, uid=EventCollectorActor.uid() + ) + from ..model.embedding import ( CustomEmbeddingModelSpec, register_embedding, diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index d8db6eb1ce..33682d1562 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1054,3 +1054,37 @@ def test_launch_model_async(setup): response = requests.get(status_url) assert len(response.json()) == 0 + + +def test_events(setup): + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + payload = { + "model_uid": "test_orca", + "model_name": "orca", + "quantization": "q4_0", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = response_data["model_uid"] + assert model_uid_res == "test_orca" + + events_url = f"{endpoint}/v1/models/test_orca/events" + response = requests.get(events_url) + response_data = response.json() + # [{'event_type': 'INFO', 'event_ts': 1705896156, 'event_content': 'Launch model'}] + assert len(response_data) == 1 + assert "Launch" in response_data[0]["event_content"] + + # delete again + url = f"{endpoint}/v1/models/test_orca" + requests.delete(url) + + response = requests.get(events_url) + response_data = response.json() + # [{'event_type': 'INFO', 'event_ts': 1705896215, 'event_content': 'Launch model'}, + # {'event_type': 'INFO', 'event_ts': 1705896215, 'event_content': 'Terminate model'}] + assert len(response_data) == 2 + assert "Terminate" in response_data[1]["event_content"] diff --git a/xinference/core/worker.py b/xinference/core/worker.py index 2a936fe450..f66ab8eec2 100644 --- a/xinference/core/worker.py +++ b/xinference/core/worker.py @@ -18,6 +18,7 @@ import queue import signal import threading +import time from collections import defaultdict from logging import getLogger from typing import Any, Dict, List, Optional, Set, Tuple, Union @@ -30,6 +31,7 @@ from ..core.status_guard import LaunchStatus from ..model.core import ModelDescription, create_model_instance from ..utils import cuda_count +from .event import Event, EventCollectorActor, EventType from .metrics import launch_metrics_export_server, record_metrics from .resource import gather_node_info from .utils import log_async, log_sync, parse_replica_model_uid, purge_dir @@ -125,6 +127,15 @@ async def recover_sub_pool(self, address): model_uid, recover_count - 1, ) + event_model_uid, _, __ = parse_replica_model_uid(model_uid) + await self._event_collector_ref.report_event( + event_model_uid, + Event( + event_type=EventType.WARNING, + event_ts=int(time.time()), + event_content="Recreate model", + ), + ) self._model_uid_to_recover_count[model_uid] = ( recover_count - 1 ) @@ -149,6 +160,11 @@ async def __post_create__(self): ] = await xo.actor_ref( address=self._supervisor_address, uid=StatusGuardActor.uid() ) + self._event_collector_ref: xo.ActorRefType[ + EventCollectorActor + ] = await xo.actor_ref( + address=self._supervisor_address, uid=EventCollectorActor.uid() + ) self._supervisor_ref: xo.ActorRefType["SupervisorActor"] = await xo.actor_ref( address=self._supervisor_address, uid=SupervisorActor.uid() ) @@ -427,6 +443,15 @@ async def launch_builtin_model( request_limits: Optional[int] = None, **kwargs, ): + event_model_uid, _, __ = parse_replica_model_uid(model_uid) + await self._event_collector_ref.report_event( + event_model_uid, + Event( + event_type=EventType.INFO, + event_ts=int(time.time()), + event_content="Launch model", + ), + ) launch_args = locals() launch_args.pop("self") launch_args.pop("kwargs") @@ -497,6 +522,15 @@ async def launch_builtin_model( @log_async(logger=logger) async def terminate_model(self, model_uid: str): + event_model_uid, _, __ = parse_replica_model_uid(model_uid) + await self._event_collector_ref.report_event( + event_model_uid, + Event( + event_type=EventType.INFO, + event_ts=int(time.time()), + event_content="Terminate model", + ), + ) origin_uid, _, _ = parse_replica_model_uid(model_uid) await self._status_guard_ref.update_instance_info( origin_uid, {"status": LaunchStatus.TERMINATING.name} From 7d77de2fab08a4b61a59b2934dc8ade22eaa4777 Mon Sep 17 00:00:00 2001 From: codingl2k1 Date: Mon, 22 Jan 2024 14:08:03 +0800 Subject: [PATCH 2/3] Add more events --- xinference/api/restful_api.py | 48 +++++++++++++++++++++++++++++++++-- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 45bf91eec4..f219393543 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -20,6 +20,7 @@ import os import pprint import sys +import time import warnings from datetime import timedelta from typing import Any, List, Optional, Union @@ -54,7 +55,7 @@ from xoscar.utils import get_next_port from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT -from ..core.event import EventCollectorActor +from ..core.event import EventCollectorActor, EventType, Event from ..core.supervisor import SupervisorActor from ..core.utils import json_dumps from ..types import ( @@ -198,6 +199,22 @@ async def _get_event_collector_ref(self) -> xo.ActorRefType[EventCollectorActor] ) return self._event_collector_ref + async def _report_error_event(self, model_uid: str, content: str): + try: + event_collector_ref = await self._get_event_collector_ref() + await event_collector_ref.report_event( + model_uid, + Event( + event_type=EventType.ERROR, + event_ts=int(time.time()), + event_content=content, + ), + ) + except Exception: + logger.exception( + "Report error event failed, model: %s, content: %s", model_uid, content + ) + async def login_for_access_token(self, form_data: LoginUserForm) -> JSONResponse: user = authenticate_user( self._auth_config.user_config, form_data.username, form_data.password @@ -293,7 +310,7 @@ def serve(self, logging_conf: Optional[dict] = None): "/v1/models/{model_uid}/events", self.get_model_events, methods=["GET"], - dependencies=[Security(verify_token, scopes=["models:events"])] + dependencies=[Security(verify_token, scopes=["models:read"])] if self.is_authenticated() else None, ) @@ -757,10 +774,12 @@ async def create_completion( model = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) if body.stream: @@ -776,6 +795,7 @@ async def stream_results(): yield item except Exception as ex: logger.exception("Completion stream got an error: %s", ex) + await self._report_error_event(model_uid, str(ex)) # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107 yield dict(data=json.dumps({"error": str(ex)})) @@ -786,6 +806,7 @@ async def stream_results(): return Response(data, media_type="application/json") except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) self.handle_request_limit_error(e) raise HTTPException(status_code=500, detail=str(e)) @@ -796,9 +817,11 @@ async def create_embedding(self, request: CreateEmbeddingRequest) -> Response: model = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) try: @@ -806,10 +829,12 @@ async def create_embedding(self, request: CreateEmbeddingRequest) -> Response: return Response(embedding, media_type="application/json") except RuntimeError as re: logger.error(re, exc_info=True) + await self._report_error_event(model_uid, str(re)) self.handle_request_limit_error(re) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) async def rerank(self, request: RerankRequest) -> Response: @@ -818,9 +843,11 @@ async def rerank(self, request: RerankRequest) -> Response: model = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) try: @@ -834,10 +861,12 @@ async def rerank(self, request: RerankRequest) -> Response: return Response(scores, media_type="application/json") except RuntimeError as re: logger.error(re, exc_info=True) + await self._report_error_event(model_uid, str(re)) self.handle_request_limit_error(re) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) async def create_images(self, request: TextToImageRequest) -> Response: @@ -846,9 +875,11 @@ async def create_images(self, request: TextToImageRequest) -> Response: model = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) try: @@ -863,10 +894,12 @@ async def create_images(self, request: TextToImageRequest) -> Response: return Response(content=image_list, media_type="application/json") except RuntimeError as re: logger.error(re, exc_info=True) + await self._report_error_event(model_uid, str(re)) self.handle_request_limit_error(re) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) async def create_variations( @@ -885,9 +918,11 @@ async def create_variations( model_ref = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) try: @@ -905,9 +940,11 @@ async def create_variations( return Response(content=image_list, media_type="application/json") except RuntimeError as re: logger.error(re, exc_info=True) + await self._report_error_event(model_uid, str(re)) raise HTTPException(status_code=400, detail=str(re)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) async def create_chat_completion( @@ -975,18 +1012,22 @@ async def create_chat_completion( model = await (await self._get_supervisor_ref()).get_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) try: desc = await (await self._get_supervisor_ref()).describe_model(model_uid) except ValueError as ve: logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) raise HTTPException(status_code=400, detail=str(ve)) except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) raise HTTPException(status_code=500, detail=str(e)) model_name = desc.get("model_name", "") @@ -1031,11 +1072,13 @@ async def stream_results(): prompt, system_prompt, chat_history, kwargs ) except RuntimeError as re: + await self._report_error_event(model_uid, str(re)) self.handle_request_limit_error(re) async for item in iterator: yield item except Exception as ex: logger.exception("Chat completion stream got an error: %s", ex) + await self._report_error_event(model_uid, str(ex)) # https://github.com/openai/openai-python/blob/e0aafc6c1a45334ac889fe3e54957d309c3af93f/src/openai/_streaming.py#L107 yield dict(data=json.dumps({"error": str(ex)})) @@ -1049,6 +1092,7 @@ async def stream_results(): return Response(content=data, media_type="application/json") except Exception as e: logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) self.handle_request_limit_error(e) raise HTTPException(status_code=500, detail=str(e)) From 18a15b7be44a1138d81bec2d366315adbcd8712d Mon Sep 17 00:00:00 2001 From: codingl2k1 Date: Mon, 22 Jan 2024 14:08:39 +0800 Subject: [PATCH 3/3] Fix lint --- xinference/api/restful_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index f219393543..6394cf542c 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -55,7 +55,7 @@ from xoscar.utils import get_next_port from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT -from ..core.event import EventCollectorActor, EventType, Event +from ..core.event import Event, EventCollectorActor, EventType from ..core.supervisor import SupervisorActor from ..core.utils import json_dumps from ..types import (