Skip to content

Commit

Permalink
[Feat]: add Logger API (#284)
Browse files Browse the repository at this point in the history
* add logger

* process queue

* add test

* fix test

* add test

* update docs

* update docs

* fix pickling

* fix tests

* add test

* 100% coverage

* clean API

* add license info

* apply feedback
  • Loading branch information
aniketmaurya authored Sep 23, 2024
1 parent 44e0fe9 commit 92b0dd5
Show file tree
Hide file tree
Showing 6 changed files with 333 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/litserve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
from litserve import test_examples
from litserve.specs.openai import OpenAISpec
from litserve.callbacks import Callback
from litserve.loggers import Logger

__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec", "Callback"]
__all__ = ["LitAPI", "LitServer", "Request", "Response", "test_examples", "OpenAISpec", "Callback", "Logger"]
19 changes: 19 additions & 0 deletions src/litserve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.
import inspect
import json
import warnings
from abc import ABC, abstractmethod
from typing import Optional
from queue import Queue

from pydantic import BaseModel

Expand All @@ -26,6 +28,7 @@ class LitAPI(ABC):
_default_unbatch: callable = None
_spec: LitSpec = None
_device: Optional[str] = None
_logger_queue: Optional[Queue] = None
request_timeout: Optional[float] = None

@abstractmethod
Expand Down Expand Up @@ -172,3 +175,19 @@ def encode_response(self, outputs):
yield encoded_output
"""
)

def set_logger_queue(self, queue: Queue):
"""Set the queue for logging events."""

self._logger_queue = queue

def log(self, key, value):
"""Log a key-value pair to the server."""
if self._logger_queue is None:
warnings.warn(
f"Logging event ('{key}', '{value}') attempted without a configured logger. "
"To track and visualize metrics, please initialize and attach a logger. "
"If this is intentional, you can safely ignore this message."
)
return
self._logger_queue.put((key, value))
137 changes: 137 additions & 0 deletions src/litserve/loggers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import multiprocessing as mp
from abc import ABC, abstractmethod
from typing import List, Optional, Union, TYPE_CHECKING

from starlette.types import ASGIApp
import logging

module_logger = logging.getLogger(__name__)

if TYPE_CHECKING: # pragma: no cover
from litserve import LitServer


class Logger(ABC):
def __init__(self):
self._config = {}

def mount(self, path: str, app: ASGIApp) -> None:
"""Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the
server such as /metrics endpoint for prometheus metrics.
Args:
path (str): The path to mount the app to.
app (ASGIApp): The ASGI app to mount.
"""
self._config.update({"mount": {"path": path, "app": app}})

@abstractmethod
def process(self, key, value):
"""Process a log entry from the log queue.
This method should be implemented to define the specific logic for processing
log entries.
Args:
key (str): The key associated with the log entry, typically indicating the type or category of the log.
value (Any): The value associated with the log entry, containing the actual log data.
Raises:
NotImplementedError: This method must be overridden by subclasses. If not, calling this method will raise
a NotImplementedError.
Example:
Here is an example of a Logger that logs monitoring metrics using Prometheus:
from prometheus_client import Counter
class PrometheusLogger(Logger):
def __init__(self):
super().__init__()
self._metric_counter = Counter('log_entries', 'Count of log entries')
def process(self, key, value):
# Increment the Prometheus counter for each log entry
self._metric_counter.inc()
print(f"Logged {key}: {value}")
"""
raise NotImplementedError # pragma: no cover


class _LoggerConnector:
"""_LoggerConnector is responsible for connecting Logger instances with the LitServer and managing their lifecycle.
This class handles the following tasks:
- Manages a queue (multiprocessing.Queue) where log data is placed using the LitAPI.log method.
- Initiates a separate process to consume the log queue and process the log data using the associated
Logger instances.
"""

def __init__(self, lit_server: "LitServer", loggers: Optional[Union[List[Logger], Logger]] = None):
self._loggers = []
self._lit_server = lit_server
if loggers is None:
return # No loggers to add
if isinstance(loggers, list):
for logger in loggers:
if not isinstance(logger, Logger):
raise ValueError("Logger must be an instance of litserve.Logger")
self.add_logger(logger)
elif isinstance(loggers, Logger):
self.add_logger(loggers)
else:
raise ValueError("loggers must be a list or an instance of litserve.Logger")

def _mount(self, path: str, app: ASGIApp) -> None:
self._lit_server.app.mount(path, app)

def add_logger(self, logger: Logger):
self._loggers.append(logger)
if "mount" in logger._config:
self._mount(logger._config["mount"]["path"], logger._config["mount"]["app"])

@staticmethod
def _process_logger_queue(loggers: List[Logger], queue):
while True:
key, value = queue.get()
for logger in loggers:
logger.process(key, value)

@functools.cache # Run once per LitServer instance
def run(self, lit_server: "LitServer"):
queue = lit_server.logger_queue
lit_server.lit_api.set_logger_queue(queue)

# Disconnect the logger connector from the LitServer to avoid pickling issues
self._lit_server = None

if not self._loggers:
return

module_logger.debug(f"Starting logger process with {len(self._loggers)} loggers")
ctx = mp.get_context("spawn")
process = ctx.Process(
target=_LoggerConnector._process_logger_queue,
args=(
self._loggers,
queue,
),
)
process.start()
8 changes: 8 additions & 0 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from litserve import LitAPI
from litserve.callbacks.base import CallbackRunner, Callback, EventTypes
from litserve.connector import _Connector
from litserve.loggers import Logger, _LoggerConnector
from litserve.loops import inference_worker
from litserve.specs import OpenAISpec
from litserve.specs.base import LitSpec
Expand Down Expand Up @@ -116,6 +117,7 @@ def __init__(
max_payload_size=None,
callbacks: Optional[Union[List[Callback], Callback]] = None,
middlewares: Optional[list[Union[Callable, tuple[Callable, dict]]]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
):
if batch_timeout > timeout and timeout not in (False, -1):
raise ValueError("batch_timeout must be less than timeout")
Expand Down Expand Up @@ -165,6 +167,8 @@ def __init__(
if max_payload_size is not None:
middlewares.append((MaxSizeMiddleware, {"max_size": max_payload_size}))
self.middlewares = middlewares
self._logger_connector = _LoggerConnector(self, loggers)
self.logger_queue = None
self.lit_api = lit_api
self.lit_spec = spec
self.workers_per_device = workers_per_device
Expand Down Expand Up @@ -206,6 +210,10 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()
if self._logger_connector._loggers:
self.logger_queue = manager.Queue()

self._logger_connector.run(self)

self.response_queues = [manager.Queue() for _ in range(num_uvicorn_servers)]

Expand Down
20 changes: 20 additions & 0 deletions tests/test_litapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,23 @@ def test_device_property():
api = ls.test_examples.SimpleLitAPI()
api.device = "cpu"
assert api.device == "cpu"


class TestLogger(ls.Logger):
def process(self, key, value):
self.processed_data = (key, value)


def test_log():
api = ls.test_examples.SimpleLitAPI()
assert api._logger_queue is None, "Logger queue should be None"
assert api.log("time", 0.1) is None, "Log should return None"
with pytest.warns(UserWarning, match="attempted without a configured logger"):
api.log("time", 0.1)

api = ls.test_examples.SimpleLitAPI()
assert api._logger_queue is None, "Logger queue should be None"
server = ls.LitServer(api, loggers=TestLogger())
server.launch_inference_worker(1)
api.log("time", 0.1)
assert server.logger_queue.get() == ("time", 0.1)
147 changes: 147 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import time

import pytest
from fastapi.testclient import TestClient

from unittest.mock import MagicMock
from litserve.loggers import Logger, _LoggerConnector

import litserve as ls
from litserve.utils import wrap_litserve_start


class TestLogger(Logger):
def process(self, key, value):
self.processed_data = (key, value)


@pytest.fixture
def mock_lit_server():
mock_server = MagicMock()
mock_server.log_queue.get = MagicMock(return_value=("test_key", "test_value"))
return mock_server


@pytest.fixture
def test_logger():
return TestLogger()


@pytest.fixture
def logger_connector(mock_lit_server, test_logger):
return _LoggerConnector(mock_lit_server, [test_logger])


def test_logger_mount(test_logger):
mock_app = MagicMock()
test_logger.mount("/test", mock_app)
assert test_logger._config["mount"]["path"] == "/test"
assert test_logger._config["mount"]["app"] == mock_app


def test_connector_add_logger(logger_connector):
new_logger = TestLogger()
logger_connector.add_logger(new_logger)
assert new_logger in logger_connector._loggers


def test_connector_mount(mock_lit_server, test_logger, logger_connector):
mock_app = MagicMock()
test_logger.mount("/test", mock_app)
logger_connector.add_logger(test_logger)
mock_lit_server.app.mount.assert_called_with("/test", mock_app)


def test_invalid_loggers():
_LoggerConnector(None, TestLogger())
with pytest.raises(ValueError, match="Logger must be an instance of litserve.Logger"):
_ = _LoggerConnector(None, [MagicMock()])

with pytest.raises(ValueError, match="loggers must be a list or an instance of litserve.Logger"):
_ = _LoggerConnector(None, MagicMock())


class LoggerAPI(ls.test_examples.SimpleLitAPI):
def predict(self, input):
result = super().predict(input)
for i in range(1, 5):
self.log("time", i * 0.1)
return result


def test_server_wo_logger():
api = LoggerAPI()
server = ls.LitServer(api)

with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}


class FileLogger(ls.Logger):
def process(self, key, value):
with open("test_logger_temp.txt", "a+") as f:
f.write(f"{key}: {value:.1f}\n")


def test_logger_with_api():
api = LoggerAPI()
server = ls.LitServer(api, loggers=[FileLogger()])
with contextlib.suppress(Exception):
os.remove("test_logger_temp.txt")
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
# Wait for FileLogger to write to file
time.sleep(0.1)
with open("test_logger_temp.txt") as f:
data = f.readlines()
assert data == [
"time: 0.1\n",
"time: 0.2\n",
"time: 0.3\n",
"time: 0.4\n",
], f"Expected metric not found in logger file {data}"
os.remove("test_logger_temp.txt")


class PredictionTimeLogger(ls.Callback):
def on_after_predict(self, lit_api):
for i in range(1, 5):
lit_api.log("time", i * 0.1)


def test_logger_with_callback():
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, loggers=[FileLogger()], callbacks=[PredictionTimeLogger()])
with contextlib.suppress(Exception):
os.remove("test_logger_temp.txt")
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.json() == {"output": 16.0}
# Wait for FileLogger to write to file
time.sleep(1)
with open("test_logger_temp.txt") as f:
data = f.readlines()
assert data == [
"time: 0.1\n",
"time: 0.2\n",
"time: 0.3\n",
"time: 0.4\n",
], f"Expected metric not found in logger file {data}"
os.remove("test_logger_temp.txt")

0 comments on commit 92b0dd5

Please sign in to comment.