-
Notifications
You must be signed in to change notification settings - Fork 158
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add logger * process queue * add test * fix test * add test * update docs * update docs * fix pickling * fix tests * add test * 100% coverage * clean API * add license info * apply feedback
- Loading branch information
1 parent
44e0fe9
commit 92b0dd5
Showing
6 changed files
with
333 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
# Copyright The Lightning AI team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import functools | ||
import multiprocessing as mp | ||
from abc import ABC, abstractmethod | ||
from typing import List, Optional, Union, TYPE_CHECKING | ||
|
||
from starlette.types import ASGIApp | ||
import logging | ||
|
||
module_logger = logging.getLogger(__name__) | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from litserve import LitServer | ||
|
||
|
||
class Logger(ABC): | ||
def __init__(self): | ||
self._config = {} | ||
|
||
def mount(self, path: str, app: ASGIApp) -> None: | ||
"""Mount an ASGI app endpoint to LitServer. Use this method when you want to add an additional endpoint to the | ||
server such as /metrics endpoint for prometheus metrics. | ||
Args: | ||
path (str): The path to mount the app to. | ||
app (ASGIApp): The ASGI app to mount. | ||
""" | ||
self._config.update({"mount": {"path": path, "app": app}}) | ||
|
||
@abstractmethod | ||
def process(self, key, value): | ||
"""Process a log entry from the log queue. | ||
This method should be implemented to define the specific logic for processing | ||
log entries. | ||
Args: | ||
key (str): The key associated with the log entry, typically indicating the type or category of the log. | ||
value (Any): The value associated with the log entry, containing the actual log data. | ||
Raises: | ||
NotImplementedError: This method must be overridden by subclasses. If not, calling this method will raise | ||
a NotImplementedError. | ||
Example: | ||
Here is an example of a Logger that logs monitoring metrics using Prometheus: | ||
from prometheus_client import Counter | ||
class PrometheusLogger(Logger): | ||
def __init__(self): | ||
super().__init__() | ||
self._metric_counter = Counter('log_entries', 'Count of log entries') | ||
def process(self, key, value): | ||
# Increment the Prometheus counter for each log entry | ||
self._metric_counter.inc() | ||
print(f"Logged {key}: {value}") | ||
""" | ||
raise NotImplementedError # pragma: no cover | ||
|
||
|
||
class _LoggerConnector: | ||
"""_LoggerConnector is responsible for connecting Logger instances with the LitServer and managing their lifecycle. | ||
This class handles the following tasks: | ||
- Manages a queue (multiprocessing.Queue) where log data is placed using the LitAPI.log method. | ||
- Initiates a separate process to consume the log queue and process the log data using the associated | ||
Logger instances. | ||
""" | ||
|
||
def __init__(self, lit_server: "LitServer", loggers: Optional[Union[List[Logger], Logger]] = None): | ||
self._loggers = [] | ||
self._lit_server = lit_server | ||
if loggers is None: | ||
return # No loggers to add | ||
if isinstance(loggers, list): | ||
for logger in loggers: | ||
if not isinstance(logger, Logger): | ||
raise ValueError("Logger must be an instance of litserve.Logger") | ||
self.add_logger(logger) | ||
elif isinstance(loggers, Logger): | ||
self.add_logger(loggers) | ||
else: | ||
raise ValueError("loggers must be a list or an instance of litserve.Logger") | ||
|
||
def _mount(self, path: str, app: ASGIApp) -> None: | ||
self._lit_server.app.mount(path, app) | ||
|
||
def add_logger(self, logger: Logger): | ||
self._loggers.append(logger) | ||
if "mount" in logger._config: | ||
self._mount(logger._config["mount"]["path"], logger._config["mount"]["app"]) | ||
|
||
@staticmethod | ||
def _process_logger_queue(loggers: List[Logger], queue): | ||
while True: | ||
key, value = queue.get() | ||
for logger in loggers: | ||
logger.process(key, value) | ||
|
||
@functools.cache # Run once per LitServer instance | ||
def run(self, lit_server: "LitServer"): | ||
queue = lit_server.logger_queue | ||
lit_server.lit_api.set_logger_queue(queue) | ||
|
||
# Disconnect the logger connector from the LitServer to avoid pickling issues | ||
self._lit_server = None | ||
|
||
if not self._loggers: | ||
return | ||
|
||
module_logger.debug(f"Starting logger process with {len(self._loggers)} loggers") | ||
ctx = mp.get_context("spawn") | ||
process = ctx.Process( | ||
target=_LoggerConnector._process_logger_queue, | ||
args=( | ||
self._loggers, | ||
queue, | ||
), | ||
) | ||
process.start() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# Copyright The Lightning AI team. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import contextlib | ||
import os | ||
import time | ||
|
||
import pytest | ||
from fastapi.testclient import TestClient | ||
|
||
from unittest.mock import MagicMock | ||
from litserve.loggers import Logger, _LoggerConnector | ||
|
||
import litserve as ls | ||
from litserve.utils import wrap_litserve_start | ||
|
||
|
||
class TestLogger(Logger): | ||
def process(self, key, value): | ||
self.processed_data = (key, value) | ||
|
||
|
||
@pytest.fixture | ||
def mock_lit_server(): | ||
mock_server = MagicMock() | ||
mock_server.log_queue.get = MagicMock(return_value=("test_key", "test_value")) | ||
return mock_server | ||
|
||
|
||
@pytest.fixture | ||
def test_logger(): | ||
return TestLogger() | ||
|
||
|
||
@pytest.fixture | ||
def logger_connector(mock_lit_server, test_logger): | ||
return _LoggerConnector(mock_lit_server, [test_logger]) | ||
|
||
|
||
def test_logger_mount(test_logger): | ||
mock_app = MagicMock() | ||
test_logger.mount("/test", mock_app) | ||
assert test_logger._config["mount"]["path"] == "/test" | ||
assert test_logger._config["mount"]["app"] == mock_app | ||
|
||
|
||
def test_connector_add_logger(logger_connector): | ||
new_logger = TestLogger() | ||
logger_connector.add_logger(new_logger) | ||
assert new_logger in logger_connector._loggers | ||
|
||
|
||
def test_connector_mount(mock_lit_server, test_logger, logger_connector): | ||
mock_app = MagicMock() | ||
test_logger.mount("/test", mock_app) | ||
logger_connector.add_logger(test_logger) | ||
mock_lit_server.app.mount.assert_called_with("/test", mock_app) | ||
|
||
|
||
def test_invalid_loggers(): | ||
_LoggerConnector(None, TestLogger()) | ||
with pytest.raises(ValueError, match="Logger must be an instance of litserve.Logger"): | ||
_ = _LoggerConnector(None, [MagicMock()]) | ||
|
||
with pytest.raises(ValueError, match="loggers must be a list or an instance of litserve.Logger"): | ||
_ = _LoggerConnector(None, MagicMock()) | ||
|
||
|
||
class LoggerAPI(ls.test_examples.SimpleLitAPI): | ||
def predict(self, input): | ||
result = super().predict(input) | ||
for i in range(1, 5): | ||
self.log("time", i * 0.1) | ||
return result | ||
|
||
|
||
def test_server_wo_logger(): | ||
api = LoggerAPI() | ||
server = ls.LitServer(api) | ||
|
||
with wrap_litserve_start(server) as server, TestClient(server.app) as client: | ||
response = client.post("/predict", json={"input": 4.0}) | ||
assert response.json() == {"output": 16.0} | ||
|
||
|
||
class FileLogger(ls.Logger): | ||
def process(self, key, value): | ||
with open("test_logger_temp.txt", "a+") as f: | ||
f.write(f"{key}: {value:.1f}\n") | ||
|
||
|
||
def test_logger_with_api(): | ||
api = LoggerAPI() | ||
server = ls.LitServer(api, loggers=[FileLogger()]) | ||
with contextlib.suppress(Exception): | ||
os.remove("test_logger_temp.txt") | ||
with wrap_litserve_start(server) as server, TestClient(server.app) as client: | ||
response = client.post("/predict", json={"input": 4.0}) | ||
assert response.json() == {"output": 16.0} | ||
# Wait for FileLogger to write to file | ||
time.sleep(0.1) | ||
with open("test_logger_temp.txt") as f: | ||
data = f.readlines() | ||
assert data == [ | ||
"time: 0.1\n", | ||
"time: 0.2\n", | ||
"time: 0.3\n", | ||
"time: 0.4\n", | ||
], f"Expected metric not found in logger file {data}" | ||
os.remove("test_logger_temp.txt") | ||
|
||
|
||
class PredictionTimeLogger(ls.Callback): | ||
def on_after_predict(self, lit_api): | ||
for i in range(1, 5): | ||
lit_api.log("time", i * 0.1) | ||
|
||
|
||
def test_logger_with_callback(): | ||
api = ls.test_examples.SimpleLitAPI() | ||
server = ls.LitServer(api, loggers=[FileLogger()], callbacks=[PredictionTimeLogger()]) | ||
with contextlib.suppress(Exception): | ||
os.remove("test_logger_temp.txt") | ||
with wrap_litserve_start(server) as server, TestClient(server.app) as client: | ||
response = client.post("/predict", json={"input": 4.0}) | ||
assert response.json() == {"output": 16.0} | ||
# Wait for FileLogger to write to file | ||
time.sleep(1) | ||
with open("test_logger_temp.txt") as f: | ||
data = f.readlines() | ||
assert data == [ | ||
"time: 0.1\n", | ||
"time: 0.2\n", | ||
"time: 0.3\n", | ||
"time: 0.4\n", | ||
], f"Expected metric not found in logger file {data}" | ||
os.remove("test_logger_temp.txt") |