Skip to content

Commit

Permalink
update streaming
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Jan 20, 2025
1 parent deb7fe8 commit 6c44408
Showing 1 changed file with 18 additions and 17 deletions.
35 changes: 18 additions & 17 deletions src/litserve/loops/streaming_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import logging
import time
from queue import Empty, Queue
from typing import Dict, List, Optional
from typing import Dict, Optional

from fastapi import HTTPException

from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.loops.base import DefaultLoop, _inject_context, collate_requests
from litserve.specs.base import LitSpec
from litserve.transport.base import MessageTransport
from litserve.utils import LitAPIStatus, PickleableHTTPException

logger = logging.getLogger(__name__)
Expand All @@ -33,7 +34,7 @@ def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
callback_runner: CallbackRunner,
):
while True:
Expand All @@ -52,7 +53,7 @@ def run_streaming_loop(
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
self.put_response(
response_queues, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
transport, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
)
continue

Expand Down Expand Up @@ -82,15 +83,15 @@ def run_streaming_loop(
)
for y_enc in y_enc_gen:
y_enc = lit_api.format_encoded_response(y_enc)
self.put_response(response_queues, response_queue_id, uid, y_enc, LitAPIStatus.OK)
self.put_response(response_queues, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)
self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)

callback_runner.trigger_event(EventTypes.AFTER_PREDICT, lit_api=lit_api)
callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)

except HTTPException as e:
self.put_response(
response_queues,
transport,
response_queue_id,
uid,
PickleableHTTPException.from_exception(e),
Expand All @@ -102,7 +103,7 @@ def run_streaming_loop(
"Please check the error trace for more details.",
uid,
)
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
self.put_response(transport, response_queue_id, uid, e, LitAPIStatus.ERROR)

def __call__(
self,
Expand All @@ -111,14 +112,14 @@ def __call__(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[int, str],
callback_runner: CallbackRunner,
):
self.run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, callback_runner)
self.run_streaming_loop(lit_api, lit_spec, request_queue, transport, callback_runner)


class BatchedStreamingLoop(DefaultLoop):
Expand All @@ -127,7 +128,7 @@ def run_batched_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
callback_runner: CallbackRunner,
Expand All @@ -146,7 +147,7 @@ def run_batched_streaming_loop(
"You can adjust the timeout by providing the `timeout` argument to LitServe(..., timeout=30)."
)
self.put_response(
response_queues, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
transport, response_queue_id, uid, HTTPException(504, "Request timed out"), LitAPIStatus.ERROR
)

if not batches:
Expand Down Expand Up @@ -186,15 +187,15 @@ def run_batched_streaming_loop(
for y_batch in y_enc_iter:
for response_queue_id, y_enc, uid in zip(response_queue_ids, y_batch, uids):
y_enc = lit_api.format_encoded_response(y_enc)
self.put_response(response_queues, response_queue_id, uid, y_enc, LitAPIStatus.OK)
self.put_response(transport, response_queue_id, uid, y_enc, LitAPIStatus.OK)

for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(response_queues, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)
self.put_response(transport, response_queue_id, uid, "", LitAPIStatus.FINISH_STREAMING)

except HTTPException as e:
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(
response_queues,
transport,
response_queue_id,
uid,
PickleableHTTPException.from_exception(e),
Expand All @@ -207,7 +208,7 @@ def run_batched_streaming_loop(
"Please check the error trace for more details."
)
for response_queue_id, uid in zip(response_queue_ids, uids):
self.put_response(response_queues, response_queue_id, uid, e, LitAPIStatus.ERROR)
self.put_response(transport, response_queue_id, uid, e, LitAPIStatus.ERROR)

def __call__(
self,
Expand All @@ -216,7 +217,7 @@ def __call__(
device: str,
worker_id: int,
request_queue: Queue,
response_queues: List[Queue],
transport: MessageTransport,
max_batch_size: int,
batch_timeout: float,
stream: bool,
Expand All @@ -227,7 +228,7 @@ def __call__(
lit_api,
lit_spec,
request_queue,
response_queues,
transport,
max_batch_size,
batch_timeout,
callback_runner,
Expand Down

0 comments on commit 6c44408

Please sign in to comment.