diff --git a/src/litserve/server.py b/src/litserve/server.py index 4210df22..798d5636 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -110,7 +110,13 @@ def collate_requests( return payloads, timed_out_uids -def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): +def run_single_loop( + lit_api: LitAPI, + lit_spec: LitSpec, + request_queue: Queue, + response_queues: List[Queue], + request_evicted_status: Dict[str, bool], +): while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -146,6 +152,8 @@ def run_single_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, re lit_api.encode_response, y, ) + # TODO: Cancel the task if the client disconnects + response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) except Exception as e: logger.exception( @@ -217,7 +225,13 @@ def run_batched_loop( response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) -def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): +def run_streaming_loop( + lit_api: LitAPI, + lit_spec: LitSpec, + request_queue: Queue, + response_queues: List[Queue], + request_evicted_status: Dict[str, bool], +): while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -256,6 +270,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, y_gen, ) for y_enc in y_enc_gen: + if request_evicted_status.get(uid): + request_evicted_status.pop(uid) + break y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) @@ -338,6 +355,7 @@ def inference_worker( worker_id: int, request_queue: Queue, response_queues: List[Queue], + request_evicted_status: Dict[str, bool], max_batch_size: int, batch_timeout: float, stream: bool, @@ -357,7 +375,7 @@ def inference_worker( if max_batch_size > 1: run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: - run_streaming_loop(lit_api, lit_spec, request_queue, response_queues) + run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status) return if max_batch_size > 1: @@ -368,6 +386,7 @@ def inference_worker( lit_spec, request_queue, response_queues, + request_evicted_status, ) @@ -397,7 +416,7 @@ async def response_queue_to_buffer( await asyncio.sleep(0.0001) continue q, event = buffer[uid] - q.append(payload) + q.append((uid, payload)) event.set() else: @@ -498,6 +517,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): manager = mp.Manager() self.workers_setup_status = manager.dict() self.request_queue = manager.Queue() + self.request_evicted_status = manager.dict() self.response_queues = [] for _ in range(num_uvicorn_servers): @@ -531,6 +551,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): worker_id, self.request_queue, self.response_queues, + self.request_evicted_status, self.max_batch_size, self.batch_timeout, self.stream, @@ -568,26 +589,37 @@ def device_identifiers(self, accelerator, device): return [f"{accelerator}:{device}"] async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False): + uid = None while True: - await data_available.wait() - while len(q) > 0: - data, status = q.popleft() - if status == LitAPIStatus.FINISH_STREAMING: - 
return - - if status == LitAPIStatus.ERROR: - logger.error( - "Error occurred while streaming outputs from the inference worker. " - "Please check the above traceback." - ) + try: + await data_available.wait() + while len(q) > 0: + uid, (data, status) = q.popleft() + if status == LitAPIStatus.FINISH_STREAMING: + return + + if status == LitAPIStatus.ERROR: + logger.error( + "Error occurred while streaming outputs from the inference worker. " + "Please check the above traceback." + ) + if send_status: + yield data, status + return if send_status: yield data, status - return - if send_status: - yield data, status - else: - yield data - data_available.clear() + else: + yield data + data_available.clear() + except asyncio.CancelledError: + if uid is not None: + self.request_evicted_status[uid] = True + logger.error("Request evicted for the uid=%s", uid) + break + except Exception as e: + # Handle other exceptions that might occur + logger.error(f"Exception occurred during streaming: {e}") + break def setup_server(self): workers_ready = False @@ -625,8 +657,37 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks) self.request_queue.put_nowait((response_queue_id, uid, time.monotonic(), payload)) - await event.wait() - response, status = self.response_buffer.pop(uid) + async def wait_for_response(): + await event.wait() + return self.response_buffer.pop(uid) + + async def check_disconnection(): + while True: + if hasattr(request, "is_disconnected") and await request.is_disconnected(): + return True + await asyncio.sleep(1) # Check every second + + response_task = asyncio.create_task(wait_for_response()) + disconnection_task = asyncio.create_task(check_disconnection()) + + try: + # Use asyncio.wait to handle both response and disconnection checks + done, pending = await asyncio.wait( + [response_task, disconnection_task], return_when=asyncio.FIRST_COMPLETED + ) + if response_task in done: + response, status = await response_task + disconnection_task.cancel() + else: + response_task.cancel() + logger.error(f"Client disconnected for the request uid={uid}") + self.request_evicted_status[uid] = True + raise HTTPException(status_code=499, detail="Client closed request") + except asyncio.CancelledError: + response_task.cancel() + disconnection_task.cancel() + logger.error(f"Client disconnected for the request uid={uid}") + raise HTTPException(status_code=499, detail="Client closed request") if status == LitAPIStatus.ERROR: load_and_raise(response) diff --git a/tests/conftest.py b/tests/conftest.py index f7d1c84d..d92cea54 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,6 +37,12 @@ def encode_response(self, output) -> Response: return {"output": output} +class SimpleDelayedLitAPI(SimpleLitAPI): + def predict(self, x): + time.sleep(0.5) + return self.model(x) + + class SimpleStreamAPI(LitAPI): def setup(self, device) -> None: self.sentence = "LitServe is streaming output" @@ -55,6 +61,14 @@ def encode_response(self, output: Generator) -> Generator: yield out.lower() +class SimpleDelayedStreamAPI(SimpleStreamAPI): + def encode_response(self, output: Generator) -> Generator: + delay = 0.2 + for out in output: + time.sleep(delay) + yield out.lower() + + class SimpleBatchedStreamAPI(LitAPI): def setup(self, device) -> None: self.sentence = "LitServe is streaming output" @@ -88,11 +102,21 @@ def simple_litapi(): return SimpleLitAPI() +@pytest.fixture() +def simple_delayed_litapi(): + return SimpleDelayedLitAPI() + + @pytest.fixture() def simple_stream_api(): return 
SimpleStreamAPI() +@pytest.fixture() +def simple_delayed_stream_api(): + return SimpleDelayedStreamAPI() + + @pytest.fixture() def simple_batched_stream_api(): return SimpleBatchedStreamAPI() diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 1e5b4f45..0f97e67f 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -13,32 +13,33 @@ # limitations under the License. import asyncio import inspect +import logging import pickle import re -from asgi_lifespan import LifespanManager -from litserve import LitAPI -from fastapi import Request, Response, HTTPException import time -import torch -import torch.nn as nn from queue import Queue -from httpx import AsyncClient -from litserve.utils import wrap_litserve_start +from unittest.mock import MagicMock, patch -from unittest.mock import patch, MagicMock import pytest +import torch +import torch.nn as nn +from asgi_lifespan import LifespanManager +from fastapi import HTTPException, Request, Response +from fastapi.testclient import TestClient +from httpx import AsyncClient +import litserve as ls +from litserve import LitAPI from litserve.connector import _Connector from litserve.server import ( + LitAPIStatus, + LitServer, inference_worker, + run_batched_streaming_loop, run_single_loop, run_streaming_loop, - LitAPIStatus, - run_batched_streaming_loop, ) -from litserve.server import LitServer -import litserve as ls -from fastapi.testclient import TestClient +from litserve.utils import wrap_litserve_start def test_index(sync_testclient): @@ -66,10 +67,10 @@ def test_device_identifiers(lifespan_mock, simple_litapi): @patch("litserve.server.run_batched_loop") @patch("litserve.server.run_single_loop") def test_inference_worker(mock_single_loop, mock_batched_loop): - inference_worker(*[MagicMock()] * 6, max_batch_size=2, batch_timeout=0, stream=False) + inference_worker(*[MagicMock()] * 7, max_batch_size=2, batch_timeout=0, stream=False) mock_batched_loop.assert_called_once() - inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False) + inference_worker(*[MagicMock()] * 7, max_batch_size=1, batch_timeout=0, stream=False) mock_single_loop.assert_called_once() @@ -94,9 +95,9 @@ def test_single_loop(loop_args): lit_api_mock, requests_queue = loop_args lit_api_mock.unbatch.side_effect = None response_queues = [FakeResponseQueue()] - + request_evicted_status = {} with pytest.raises(StopIteration, match="exit loop"): - run_single_loop(lit_api_mock, None, requests_queue, response_queues) + run_single_loop(lit_api_mock, None, requests_queue, response_queues, request_evicted_status) @pytest.mark.asyncio() @@ -120,6 +121,44 @@ async def test_stream(simple_stream_api): ), "Server returns input prompt and generated output which didn't match." 
+@pytest.mark.asyncio() +async def test_client_disconnection(simple_delayed_litapi, caplog): + server = LitServer(simple_delayed_litapi, timeout=10) + + with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10)) + await asyncio.sleep(0.2) + task.cancel() + await asyncio.sleep(1) + assert "Client disconnected for the request uid" in caplog.text + # TODO: also check if the task actually stopped in the server + + caplog.clear() + task = asyncio.create_task(ac.post("/predict", json={"input": 1.0}, timeout=10)) + await task + assert "Client disconnected for the request uid" not in caplog.text + + +@pytest.mark.asyncio() +async def test_stream_client_disconnection(simple_delayed_stream_api, caplog): + server = LitServer(simple_delayed_stream_api, stream=True, timeout=10) + + with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG): + async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac: + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 5}, timeout=10)) + await asyncio.sleep(2) + task.cancel() # simulate client disconnection + await asyncio.sleep(1) # wait for the task to stop + assert "Request evicted for the uid=" in caplog.text + # TODO: also check if the task actually stopped in the server + + caplog.clear() + task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?"}, timeout=10)) + await task + assert "Request evicted for the uid=" not in caplog.text + + @pytest.mark.asyncio() async def test_batched_stream_server(simple_batched_stream_api): server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30) @@ -175,11 +214,12 @@ def fake_encode(output): fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x) requests_queue = Queue() + request_evicted_status = {} requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"})) response_queues = [FakeStreamResponseQueue(num_streamed_outputs)] with pytest.raises(StopIteration, match="exit loop"): - run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues) + run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues, request_evicted_status) fake_stream_api.predict.assert_called_once_with("Hello") fake_stream_api.encode_response.assert_called_once()
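
Note (illustrative, not part of the patch): for the non-streaming path, the /predict change above races the worker response against a periodic client-disconnection poll using asyncio.wait(..., return_when=asyncio.FIRST_COMPLETED), then cancels whichever task lost the race. The standalone sketch below reproduces just that pattern; fake_worker, is_disconnected, the 1.0s delay, and the returned payload are invented stand-ins for the response buffer event, starlette's Request.is_disconnected(), and real inference latency, and a plain RuntimeError stands in for the HTTPException(499) the endpoint raises.

# sketch_predict_race.py -- minimal model of the response-vs-disconnection race
import asyncio


async def fake_worker(result_event: asyncio.Event):
    # Stand-in for the inference worker eventually filling the response buffer.
    await asyncio.sleep(1.0)
    result_event.set()


async def is_disconnected() -> bool:
    # Stand-in for starlette's Request.is_disconnected(); always connected here.
    return False


async def predict_with_eviction():
    result_event = asyncio.Event()
    worker_task = asyncio.create_task(fake_worker(result_event))  # keep a reference

    async def wait_for_response():
        await result_event.wait()
        return {"output": 42.0}

    async def check_disconnection():
        while True:
            if await is_disconnected():
                return True
            await asyncio.sleep(1)  # poll once per second, as in the diff

    response_task = asyncio.create_task(wait_for_response())
    disconnection_task = asyncio.create_task(check_disconnection())
    done, _ = await asyncio.wait(
        {response_task, disconnection_task}, return_when=asyncio.FIRST_COMPLETED
    )
    if response_task in done:
        # Normal completion: drop the poller and return the worker's result.
        disconnection_task.cancel()
        await worker_task
        return await response_task
    # Client went away first: stop waiting on the worker and signal eviction.
    response_task.cancel()
    raise RuntimeError("client closed request")  # endpoint raises HTTPException(499) instead


if __name__ == "__main__":
    print(asyncio.run(predict_with_eviction()))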
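
Note (illustrative, not part of the patch): for streaming, the patch relies on a shared-dict handshake. data_streamer sets request_evicted_status[uid] when its task is cancelled, and run_streaming_loop checks that flag between yielded chunks and stops generating. The single-process sketch below mimics the handshake with an asyncio.Queue; the real code crosses a process boundary through a multiprocessing.Manager dict, and all names, chunk contents, and timings here are made up.

# sketch_stream_eviction.py -- producer stops when the consumer flags eviction
import asyncio


async def producer(uid: str, out: asyncio.Queue, evicted: dict):
    for chunk in "LitServe is streaming output".split():
        if evicted.get(uid):          # same per-chunk check run_streaming_loop does
            evicted.pop(uid)
            print("producer: stopped early for", uid)
            return
        await asyncio.sleep(0.1)      # pretend each chunk is slow to generate
        await out.put(chunk)
    await out.put(None)               # stand-in for FINISH_STREAMING


async def consumer(uid: str, out: asyncio.Queue, evicted: dict):
    try:
        while (chunk := await out.get()) is not None:
            print("consumer:", chunk)
    except asyncio.CancelledError:
        evicted[uid] = True           # what data_streamer does on cancellation
        raise


async def main():
    evicted, out = {}, asyncio.Queue()
    uid = "UUID-1234"
    prod = asyncio.create_task(producer(uid, out, evicted))
    cons = asyncio.create_task(consumer(uid, out, evicted))
    await asyncio.sleep(0.25)
    cons.cancel()                     # simulate the client disconnecting mid-stream
    await prod                        # producer notices the flag and returns early


if __name__ == "__main__":
    asyncio.run(main())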