Feat: Evict requests if the client has disconnected #208

Closed
Changes from 17 commits

Commits (40)
cdff37b
chore: Add request_evicted_status to streaming loop to cancel requests
bhimrazy Aug 21, 2024
7564479
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
99da82b
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
15fe905
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 21, 2024
e5565a8
fix failing test
bhimrazy Aug 21, 2024
36429a5
fixed: cannot access local variable 'uid'
bhimrazy Aug 21, 2024
9c08744
feat: adds test for `test_stream_client_disconnection`
bhimrazy Aug 22, 2024
f5522fa
ref: format imports using ruff
bhimrazy Aug 22, 2024
2f46532
fix lint warning for `@pytest.mark.asyncio`
bhimrazy Aug 22, 2024
7aacee6
adds a todo in the test for reminder
bhimrazy Aug 22, 2024
4327e49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 22, 2024
c054af3
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 22, 2024
6eeb90d
adds cleanup for the dict to prevent leakage
bhimrazy Aug 22, 2024
dc041d2
chore: fix typo in test_lit_server.py
bhimrazy Aug 22, 2024
18419f1
updates the sleep time
bhimrazy Aug 22, 2024
f6763e5
updated some time
bhimrazy Aug 22, 2024
6dc6454
updated prompt len
bhimrazy Aug 22, 2024
e7b3059
chore: Remove print statement in stream_predict method
bhimrazy Aug 22, 2024
b0be9ce
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
a9d86ce
Merge branch 'feat/evict-req-on-client-disconnect' of github.com:bhim…
bhimrazy Aug 23, 2024
34453e9
chore: Add delayed prediction support in LitAPI subclasses
bhimrazy Aug 23, 2024
0069b98
updated stream test and added test for nonstream case
bhimrazy Aug 23, 2024
f3d6bd2
added logic to handle the client disconnection in predict
bhimrazy Aug 23, 2024
6029165
update sleep duration
bhimrazy Aug 23, 2024
6e95b30
Update sleep duration
bhimrazy Aug 23, 2024
f6f3e4c
update sleep time
bhimrazy Aug 23, 2024
9d47245
removed sleep
bhimrazy Aug 23, 2024
86ca3ce
check if `is_disconnected` exists
bhimrazy Aug 23, 2024
154cc6c
adds sleep
bhimrazy Aug 23, 2024
39986bf
chore: Update sleep duration
bhimrazy Aug 23, 2024
2c7633a
chore: Update sleep duration in LitServer
bhimrazy Aug 23, 2024
ccaeee9
tried another approach to check & handle disconnection
bhimrazy Aug 23, 2024
f0b19af
wrap in try catch
bhimrazy Aug 23, 2024
dcab100
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
b810a66
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
4edab2c
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 23, 2024
3d9a9a7
Merge branch 'main' into feat/evict-req-on-client-disconnect
aniketmaurya Aug 24, 2024
5c0d7fc
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 25, 2024
919b304
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 26, 2024
6ffe51c
Merge branch 'main' into feat/evict-req-on-client-disconnect
bhimrazy Aug 26, 2024
64 changes: 44 additions & 20 deletions src/litserve/server.py
@@ -217,7 +217,13 @@ def run_batched_loop(
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -256,6 +262,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
y_gen,
)
for y_enc in y_enc_gen:
if request_evicted_status.get(uid):
request_evicted_status.pop(uid)
break
y_enc = lit_api.format_encoded_response(y_enc)
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
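
Review note: the heart of this hunk is the per-`uid` flag that the worker polls between streamed chunks, popping it once honoured so the shared dict does not grow. A minimal standalone sketch of the pattern (illustrative names, not the actual litserve internals):

```python
# Sketch of the eviction check above (illustrative names, not litserve internals):
# the worker polls a shared flag between chunks, stops generating early when the
# request has been evicted, and still signals end-of-stream afterwards.
from queue import Queue
from typing import Dict, Iterator


def stream_worker(uid: str, chunks: Iterator[str], evicted: Dict[str, bool], out: Queue) -> None:
    for chunk in chunks:
        if evicted.get(uid):
            evicted.pop(uid, None)  # clear the flag so the shared dict does not leak
            break
        out.put((uid, chunk))
    # Mirrors the real loop: the finish marker is sent whether or not we broke early.
    out.put((uid, "FINISH_STREAMING"))
```

Because the flag is only checked between chunks, a request that is mid-way through producing a single long chunk still finishes that chunk before stopping.
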
@@ -338,6 +347,7 @@ def inference_worker(
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
max_batch_size: int,
batch_timeout: float,
stream: bool,
@@ -357,7 +367,7 @@ def inference_worker(
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status)
return

if max_batch_size > 1:
@@ -397,7 +407,7 @@ async def response_queue_to_buffer(
await asyncio.sleep(0.0001)
continue
q, event = buffer[uid]
q.append(payload)
q.append((uid, payload))
event.set()

else:
@@ -499,6 +509,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
manager = mp.Manager()
self.workers_setup_status = manager.dict()
self.request_queue = manager.Queue()
self.request_evicted_status = manager.dict()

self.response_queues = []
for _ in range(num_uvicorn_servers):
@@ -532,6 +543,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
worker_id,
self.request_queue,
self.response_queues,
self.request_evicted_status,
self.max_batch_size,
self.batch_timeout,
self.stream,
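
The eviction flags have to be visible both to the uvicorn/FastAPI process (which sets them in `data_streamer`) and to the inference workers (which poll them in `run_streaming_loop`), hence `manager.dict()` above rather than a plain dict. A tiny self-contained sketch of that cross-process sharing pattern, independent of litserve:

```python
# Standalone sketch of sharing an eviction map across processes with a
# multiprocessing.Manager dict (names are illustrative, not litserve's).
import multiprocessing as mp
import time


def worker(evicted) -> None:
    # Poll the shared flag, much like the streaming loop does between chunks.
    while not evicted.get("req-1"):
        time.sleep(0.05)
    print("req-1 was evicted, stopping work")


if __name__ == "__main__":
    manager = mp.Manager()
    evicted = manager.dict()  # proxy dict, visible to child processes
    p = mp.Process(target=worker, args=(evicted,))
    p.start()
    time.sleep(0.2)
    evicted["req-1"] = True   # set from the parent (server) process
    p.join()
```
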
@@ -569,26 +581,37 @@ def device_identifiers(self, accelerator, device):
return [f"{accelerator}:{device}"]

async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
uid = None
while True:
await data_available.wait()
while len(q) > 0:
data, status = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
try:
await data_available.wait()
while len(q) > 0:
uid, (data, status) = q.popleft()
if status == LitAPIStatus.FINISH_STREAMING:
return

if status == LitAPIStatus.ERROR:
logger.error(
"Error occurred while streaming outputs from the inference worker. "
"Please check the above traceback."
)
if send_status:
yield data, status
return
if send_status:
yield data, status
return
if send_status:
yield data, status
else:
yield data
data_available.clear()
else:
yield data
data_available.clear()
except asyncio.CancelledError:
if uid is not None:
self.request_evicted_status[uid] = True
logger.error("Request evicted for the uid=%s", uid)
break
except Exception as e:
# Handle other exceptions that might occur
logger.error(f"Exception occurred during streaming: {e}")
break

def setup_server(self):
workers_ready = False
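
The try/except added to `data_streamer` is what actually detects the disconnect: when the client drops, the streaming coroutine is cancelled, the `asyncio.CancelledError` is caught, and the last seen `uid` is flagged in `request_evicted_status` for the worker to pick up. A stripped-down sketch of the same pattern (illustrative names, under the assumption that cancellation is delivered while the generator is awaiting new data):

```python
# Stripped-down sketch of the cancellation handling above (illustrative names).
# When the consuming task is cancelled, the wait() call raises CancelledError
# inside the generator and the request is flagged for eviction.
import asyncio
from collections import deque
from typing import AsyncIterator, Dict


async def drain_buffer(
    q: deque, data_available: asyncio.Event, evicted: Dict[str, bool]
) -> AsyncIterator[str]:
    uid = None
    try:
        while True:
            await data_available.wait()
            while q:
                uid, chunk = q.popleft()
                if chunk is None:  # sentinel for end-of-stream
                    return
                yield chunk
            data_available.clear()
    except asyncio.CancelledError:
        if uid is not None:
            evicted[uid] = True  # tell the worker to stop producing for this uid
```

One caveat worth flagging in review: `uid` is only known after the first chunk has been buffered, so a disconnect before any output arrives leaves nothing to evict; the same holds for the real implementation above.
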
@@ -635,6 +658,7 @@ async def predict(request: self.request_type, background_tasks: BackgroundTasks)

async def stream_predict(request: self.request_type, background_tasks: BackgroundTasks) -> self.response_type:
response_queue_id = self.app.response_queue_id
print("response_queue_id=", response_queue_id)
uid = uuid.uuid4()
event = asyncio.Event()
q = deque()
53 changes: 37 additions & 16 deletions tests/test_lit_server.py
@@ -13,32 +13,33 @@
# limitations under the License.
import asyncio
import inspect
import logging
import pickle
import re
from asgi_lifespan import LifespanManager
from litserve import LitAPI
from fastapi import Request, Response, HTTPException
import time
import torch
import torch.nn as nn
from queue import Queue
from httpx import AsyncClient
from litserve.utils import wrap_litserve_start
from unittest.mock import MagicMock, patch

from unittest.mock import patch, MagicMock
import pytest
import torch
import torch.nn as nn
from asgi_lifespan import LifespanManager
from fastapi import HTTPException, Request, Response
from fastapi.testclient import TestClient
from httpx import AsyncClient

import litserve as ls
from litserve import LitAPI
from litserve.connector import _Connector
from litserve.server import (
LitAPIStatus,
LitServer,
inference_worker,
run_batched_streaming_loop,
run_single_loop,
run_streaming_loop,
LitAPIStatus,
run_batched_streaming_loop,
)
from litserve.server import LitServer
import litserve as ls
from fastapi.testclient import TestClient
from litserve.utils import wrap_litserve_start


def test_index(sync_testclient):
@@ -66,10 +67,10 @@ def test_device_identifiers(lifespan_mock, simple_litapi):
@patch("litserve.server.run_batched_loop")
@patch("litserve.server.run_single_loop")
def test_inference_worker(mock_single_loop, mock_batched_loop):
inference_worker(*[MagicMock()] * 6, max_batch_size=2, batch_timeout=0, stream=False)
inference_worker(*[MagicMock()] * 7, max_batch_size=2, batch_timeout=0, stream=False)
mock_batched_loop.assert_called_once()

inference_worker(*[MagicMock()] * 6, max_batch_size=1, batch_timeout=0, stream=False)
inference_worker(*[MagicMock()] * 7, max_batch_size=1, batch_timeout=0, stream=False)
mock_single_loop.assert_called_once()


@@ -120,6 +121,25 @@ async def test_stream(simple_stream_api):
), "Server returns input prompt and generated output which didn't match."


@pytest.mark.asyncio()
async def test_stream_client_disconnection(simple_stream_api, caplog):
server = LitServer(simple_stream_api, stream=True, timeout=10)

with wrap_litserve_start(server) as server, caplog.at_level(logging.DEBUG):
async with LifespanManager(server.app) as manager, AsyncClient(app=manager.app, base_url="http://test") as ac:
task = asyncio.create_task(ac.post("/predict", json={"prompt": "Hey, How are you doing?" * 20}, timeout=10))
await asyncio.sleep(1)

# Simulate client disconnection by canceling the request
task.cancel()

# Allow some time for the server to handle the cancellation
await asyncio.sleep(1)
assert "Request evicted for the uid=" in caplog.text, "Server should log client disconnection"

# TODO: also check if the task actually stopped in the server
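
For completeness, this is roughly what a real client-side disconnect looks like outside the test harness. A hypothetical sketch, assuming a LitServe streaming server is already running locally; the URL and port are assumptions, not part of this PR:

```python
# Hypothetical client-side sketch: abandon a streaming response mid-flight.
# Assumes a streaming server is already running on localhost:8000; the
# endpoint path and payload shape follow the tests above.
import asyncio

import httpx


async def disconnect_mid_stream() -> None:
    async with httpx.AsyncClient(timeout=10) as client:
        async with client.stream(
            "POST", "http://localhost:8000/predict", json={"prompt": "Hey, how are you doing?"}
        ) as response:
            async for _chunk in response.aiter_text():
                break  # stop reading; leaving the block closes the connection


if __name__ == "__main__":
    asyncio.run(disconnect_mid_stream())
```
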


@pytest.mark.asyncio()
async def test_batched_stream_server(simple_batched_stream_api):
server = LitServer(simple_batched_stream_api, stream=True, max_batch_size=4, batch_timeout=2, timeout=30)
@@ -175,11 +195,12 @@ def fake_encode(output):
fake_stream_api.format_encoded_response = MagicMock(side_effect=lambda x: x)

requests_queue = Queue()
request_evicted_status = {}
requests_queue.put((0, "UUID-1234", time.monotonic(), {"prompt": "Hello"}))
response_queues = [FakeStreamResponseQueue(num_streamed_outputs)]

with pytest.raises(StopIteration, match="exit loop"):
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues)
run_streaming_loop(fake_stream_api, fake_stream_api, requests_queue, response_queues, request_evicted_status)

fake_stream_api.predict.assert_called_once_with("Hello")
fake_stream_api.encode_response.assert_called_once()