def set_logger_queue(self, queue: Queue):
    """Attach the queue that `log` publishes to.

    Args:
        queue: Queue-like object exposing ``put``. `_LoggerConnector.run`
            passes the server's ``logger_queue`` here.

    """
    self._logger_queue = queue


def log(self, key, value):
    """Publish a ``(key, value)`` log entry to the configured logger queue.

    When no queue has been attached, nothing is published and a
    ``UserWarning`` is emitted instead.

    Args:
        key: Name of the logged entry.
        value: Payload of the logged entry.

    """
    if self._logger_queue is None:
        # stacklevel=2 attributes the warning to the code that called log(),
        # so the user sees which call site is logging without a logger.
        warnings.warn(
            f"Logging event ('{key}', '{value}') attempted without a configured logger. "
            "To track and visualize metrics, please initialize and attach a logger. "
            "If this is intentional, you can safely ignore this message.",
            stacklevel=2,
        )
        return
    self._logger_queue.put((key, value))
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools +import multiprocessing as mp +from abc import ABC, abstractmethod +from typing import List, Optional, Union, TYPE_CHECKING + +from starlette.types import ASGIApp +import logging + +module_logger = logging.getLogger(__name__) + +if TYPE_CHECKING: # pragma: no cover + from litserve import LitServer + + +class Logger(ABC): + def __init__(self): + self._config = {} + + def mount(self, path: str, app: ASGIApp) -> None: + """Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the + server such as /metrics endpoint for prometheus metrics. + + Args: + path (str): The path to mount the app to. + app (ASGIApp): The ASGI app to mount. + + """ + self._config.update({"mount": {"path": path, "app": app}}) + + @abstractmethod + def process(self, key, value): + """Process a log entry from the log queue. + + This method should be implemented to define the specific logic for processing + log entries. + + Args: + key (str): The key associated with the log entry, typically indicating the type or category of the log. + value (Any): The value associated with the log entry, containing the actual log data. + + Raises: + NotImplementedError: This method must be overridden by subclasses. If not, calling this method will raise + a NotImplementedError. 
+ + Example: + Here is an example of a Logger that logs monitoring metrics using Prometheus: + + from prometheus_client import Counter + + class PrometheusLogger(Logger): + def __init__(self): + super().__init__() + self._metric_counter = Counter('log_entries', 'Count of log entries') + + def process(self, key, value): + # Increment the Prometheus counter for each log entry + self._metric_counter.inc() + print(f"Logged {key}: {value}") + + """ + raise NotImplementedError # pragma: no cover + + +class _LoggerConnector: + """_LoggerConnector is responsible for connecting Logger instances with the LitServer and managing their lifecycle. + + This class handles the following tasks: + - Manages a queue (multiprocessing.Queue) where log data is placed using the LitAPI.log method. + - Initiates a separate process to consume the log queue and process the log data using the associated + Logger instances. + + """ + + def __init__(self, lit_server: "LitServer", loggers: Optional[Union[List[Logger], Logger]] = None): + self._loggers = [] + self._lit_server = lit_server + if loggers is None: + return # No loggers to add + if isinstance(loggers, list): + for logger in loggers: + if not isinstance(logger, Logger): + raise ValueError("Logger must be an instance of litserve.Logger") + self.add_logger(logger) + elif isinstance(loggers, Logger): + self.add_logger(loggers) + else: + raise ValueError("loggers must be a list or an instance of litserve.Logger") + + def _mount(self, path: str, app: ASGIApp) -> None: + self._lit_server.app.mount(path, app) + + def add_logger(self, logger: Logger): + self._loggers.append(logger) + if "mount" in logger._config: + self._mount(logger._config["mount"]["path"], logger._config["mount"]["app"]) + + @staticmethod + def _process_logger_queue(loggers: List[Logger], queue): + while True: + key, value = queue.get() + for logger in loggers: + logger.process(key, value) + + @functools.cache # Run once per LitServer instance + def run(self, lit_server: 
"LitServer"): + queue = lit_server.logger_queue + lit_server.lit_api.set_logger_queue(queue) + + # Disconnect the logger connector from the LitServer to avoid pickling issues + self._lit_server = None + + if not self._loggers: + return + + module_logger.debug(f"Starting logger process with {len(self._loggers)} loggers") + ctx = mp.get_context("spawn") + process = ctx.Process( + target=_LoggerConnector._process_logger_queue, + args=( + self._loggers, + queue, + ), + ) + process.start() diff --git a/src/litserve/server.py b/src/litserve/server.py index 289b720a..70bda914 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -39,6 +39,7 @@ from litserve import LitAPI from litserve.callbacks.base import CallbackRunner, Callback, EventTypes from litserve.connector import _Connector +from litserve.loggers import Logger, _LoggerConnector from litserve.loops import inference_worker from litserve.specs import OpenAISpec from litserve.specs.base import LitSpec @@ -116,6 +117,7 @@ def __init__( max_payload_size=None, callbacks: Optional[Union[List[Callback], Callback]] = None, middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None, + loggers: Optional[Union[Logger, List[Logger]]] = None, ): if batch_timeout > timeout and timeout not in (False, -1): raise ValueError("batch_timeout must be less than timeout") @@ -165,6 +167,8 @@ def __init__( if max_payload_size is not None: middlewares.append((MaxSizeMiddleware, {"max_size": max_payload_size})) self.middlewares = middlewares + self._logger_connector = _LoggerConnector(self, loggers) + self.logger_queue = None self.lit_api = lit_api self.lit_spec = spec self.workers_per_device = workers_per_device @@ -206,6 +210,10 @@ def launch_inference_worker(self, num_uvicorn_servers: int): manager = mp.Manager() self.workers_setup_status = manager.dict() self.request_queue = manager.Queue() + if self._logger_connector._loggers: + self.logger_queue = manager.Queue() + + self._logger_connector.run(self) 
class TestLogger(ls.Logger):
    """Minimal logger that remembers the last entry it was asked to process."""

    # Helper class, not a test case: keep pytest from trying to collect it.
    __test__ = False

    def process(self, key, value):
        self.processed_data = (key, value)


def test_log():
    """LitAPI.log warns without a logger and enqueues entries once one is attached."""
    api = ls.test_examples.SimpleLitAPI()
    assert api._logger_queue is None, "Logger queue should be None"
    # Every log() call without a queue warns, so the return-value check runs
    # inside pytest.warns too instead of leaking an uncaptured warning.
    with pytest.warns(UserWarning, match="attempted without a configured logger"):
        assert api.log("time", 0.1) is None, "Log should return None"
    with pytest.warns(UserWarning, match="attempted without a configured logger"):
        api.log("time", 0.1)

    api = ls.test_examples.SimpleLitAPI()
    assert api._logger_queue is None, "Logger queue should be None"
    server = ls.LitServer(api, loggers=TestLogger())
    server.launch_inference_worker(1)
    api.log("time", 0.1)
    # NOTE(review): the spawned logger process reads from this same queue, so
    # this read races with it. The timeout turns a lost race into a failure
    # instead of a hang.
    assert server.logger_queue.get(timeout=10) == ("time", 0.1)
# File that FileLogger appends to; the logger runs in a spawned process, so the
# tests observe it through the filesystem.
LOG_FILE = "test_logger_temp.txt"
EXPECTED_LOG_LINES = [
    "time: 0.1\n",
    "time: 0.2\n",
    "time: 0.3\n",
    "time: 0.4\n",
]


class TestLogger(Logger):
    """Minimal logger that remembers the last entry it was asked to process."""

    # Helper class, not a test case: keep pytest from trying to collect it.
    __test__ = False

    def process(self, key, value):
        self.processed_data = (key, value)


@pytest.fixture
def mock_lit_server():
    mock_server = MagicMock()
    # The server attribute read by _LoggerConnector.run is `logger_queue`.
    mock_server.logger_queue.get = MagicMock(return_value=("test_key", "test_value"))
    return mock_server


@pytest.fixture
def test_logger():
    return TestLogger()


@pytest.fixture
def logger_connector(mock_lit_server, test_logger):
    return _LoggerConnector(mock_lit_server, [test_logger])


def test_logger_mount(test_logger):
    mock_app = MagicMock()
    test_logger.mount("/test", mock_app)
    assert test_logger._config["mount"]["path"] == "/test"
    assert test_logger._config["mount"]["app"] == mock_app


def test_connector_add_logger(logger_connector):
    new_logger = TestLogger()
    logger_connector.add_logger(new_logger)
    assert new_logger in logger_connector._loggers


def test_connector_mount(mock_lit_server, test_logger, logger_connector):
    mock_app = MagicMock()
    test_logger.mount("/test", mock_app)
    logger_connector.add_logger(test_logger)
    mock_lit_server.app.mount.assert_called_with("/test", mock_app)


def test_invalid_loggers():
    _LoggerConnector(None, TestLogger())
    with pytest.raises(ValueError, match="Logger must be an instance of litserve.Logger"):
        _ = _LoggerConnector(None, [MagicMock()])

    with pytest.raises(ValueError, match="loggers must be a list or an instance of litserve.Logger"):
        _ = _LoggerConnector(None, MagicMock())


class LoggerAPI(ls.test_examples.SimpleLitAPI):
    """SimpleLitAPI variant that logs four "time" entries on every predict call."""

    def predict(self, input):
        result = super().predict(input)
        for i in range(1, 5):
            self.log("time", i * 0.1)
        return result


def test_server_wo_logger():
    api = LoggerAPI()
    server = ls.LitServer(api)

    with wrap_litserve_start(server) as server, TestClient(server.app) as client:
        response = client.post("/predict", json={"input": 4.0})
        assert response.json() == {"output": 16.0}


class FileLogger(ls.Logger):
    """Logger that appends every entry to LOG_FILE as ``key: value``."""

    def process(self, key, value):
        with open(LOG_FILE, "a+") as f:
            f.write(f"{key}: {value:.1f}\n")


def _remove_log_file():
    """Delete LOG_FILE, ignoring only the case where it does not exist."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(LOG_FILE)


def _wait_for_log_lines(expected_count, timeout=5.0, interval=0.05):
    """Poll LOG_FILE until it holds ``expected_count`` lines or ``timeout`` elapses.

    The logger process writes asynchronously, so a fixed sleep is either too
    short on a slow machine or needlessly long on a fast one. Returns whatever
    lines were read last, so the caller's assertion reports the actual content.
    """
    deadline = time.monotonic() + timeout
    lines = []
    while True:
        with contextlib.suppress(FileNotFoundError), open(LOG_FILE) as f:
            lines = f.readlines()
        if len(lines) >= expected_count or time.monotonic() >= deadline:
            return lines
        time.sleep(interval)


def _predict_and_check_log(server):
    """Serve one prediction and assert that the four "time" entries reach LOG_FILE."""
    _remove_log_file()
    try:
        with wrap_litserve_start(server) as running_server, TestClient(running_server.app) as client:
            response = client.post("/predict", json={"input": 4.0})
            assert response.json() == {"output": 16.0}
            data = _wait_for_log_lines(len(EXPECTED_LOG_LINES))
            assert data == EXPECTED_LOG_LINES, f"Expected metric not found in logger file {data}"
    finally:
        # Clean up even when an assertion fails so later tests start fresh.
        _remove_log_file()


def test_logger_with_api():
    api = LoggerAPI()
    server = ls.LitServer(api, loggers=[FileLogger()])
    _predict_and_check_log(server)


class PredictionTimeLogger(ls.Callback):
    """Callback that logs four "time" entries after each predict call."""

    def on_after_predict(self, lit_api):
        for i in range(1, 5):
            lit_api.log("time", i * 0.1)


def test_logger_with_callback():
    api = ls.test_examples.SimpleLitAPI()
    server = ls.LitServer(api, loggers=[FileLogger()], callbacks=[PredictionTimeLogger()])
    _predict_and_check_log(server)