Merge pull request #493 from basetenlabs/sshanker/async-cleanup
Fix Streaming Issues & Cleanup Async
squidarth authored Aug 2, 2023
2 parents 3dccd8f + 03d725b commit c93c9ac
Showing 6 changed files with 185 additions and 19 deletions.
8 changes: 1 addition & 7 deletions truss/templates/server/common/truss_server.py
@@ -1,5 +1,4 @@
import asyncio
import concurrent.futures
import json
import logging
import multiprocessing
@@ -119,7 +118,7 @@ async def predict(
self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
) -> Response:
"""
This method is called by FastAPI, which introspects that it's not async, and schedules it on a thread
This method calls the user-provided predict method
"""
model: ModelWrapper = self._safe_lookup_model(model_name)

@@ -294,13 +293,8 @@ def start(self):
},
)

max_asyncio_workers = min(32, utils.cpu_count() + 4)
logging.info(f"Setting max asyncio worker threads as {max_asyncio_workers}")
# Call this so uvloop gets used
cfg.setup_event_loop()
asyncio.get_event_loop().set_default_executor(
concurrent.futures.ThreadPoolExecutor(max_workers=max_asyncio_workers)
)

async def serve():
serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
60 changes: 55 additions & 5 deletions truss/templates/server/model_wrapper.py
@@ -7,6 +7,7 @@
import time
import traceback
from collections.abc import Generator
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from threading import Lock, Thread
@@ -24,6 +25,44 @@
DEFAULT_PREDICT_CONCURRENCY = 1


class DeferredSemaphoreManager:
"""
Helper class for supporting deferred semaphore release.
"""

def __init__(self, semaphore: Semaphore):
self.semaphore = semaphore
self.deferred = False

def defer(self):
"""
Track that this semaphore is to be deferred, and return
a release method that the context block can use to release
the semaphore.
"""
self.deferred = True

return self.semaphore.release


@asynccontextmanager
async def deferred_semaphore(semaphore: Semaphore):
"""
Context manager that allows deferring the release of a semaphore.
It yields a DeferredSemaphoreManager. If you call DeferredSemaphoreManager.defer()
inside the context block, you get back a release function that you must call
yourself; otherwise the semaphore is released automatically on exit.
"""
semaphore_manager = DeferredSemaphoreManager(semaphore)
await semaphore.acquire()

try:
yield semaphore_manager
finally:
if not semaphore_manager.deferred:
semaphore.release()


class ModelWrapper:
class Status(Enum):
NOT_READY = 0
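
For illustration, here is a minimal usage sketch of the deferred_semaphore context manager defined above. The do_predict, is_streaming, and drain_stream helpers are hypothetical stand-ins, not part of this diff:

import asyncio
from asyncio import Semaphore

async def handle_request(semaphore: Semaphore):
    async with deferred_semaphore(semaphore) as semaphore_manager:
        result = await do_predict()  # hypothetical predict call

        if not is_streaming(result):  # hypothetical check
            # Non-streaming path: the context manager releases the
            # semaphore on exit, just like a plain `async with semaphore:`.
            return result

        # Streaming path: hold the semaphore until the background task
        # that drains the stream has finished, then release it explicitly.
        release = semaphore_manager.defer()
        task = asyncio.create_task(drain_stream(result))  # hypothetical
        task.add_done_callback(lambda _: release())
        return result
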
@@ -206,10 +245,13 @@ async def postprocess(
async def write_response_to_queue(
self, queue: asyncio.Queue, generator: AsyncGenerator
):
async for chunk in generator:
await queue.put(ResponseChunk(chunk))

await queue.put(None)
try:
async for chunk in generator:
await queue.put(ResponseChunk(chunk))
except Exception as e:
self._logger.exception("Exception while reading stream response: " + str(e))
finally:
await queue.put(None)

async def __call__(
self, body: Any, headers: Optional[Dict[str, str]] = None
@@ -227,7 +269,7 @@ async def __call__(

payload = await self.preprocess(body, headers)

async with self._predict_semaphore:
async with deferred_semaphore(self._predict_semaphore) as semaphore_manager:
response = await self.predict(payload, headers)

processed_response = await self.postprocess(response)
@@ -251,8 +293,16 @@ async def __call__(
task = asyncio.create_task(
self.write_response_to_queue(response_queue, async_generator)
)

# We add the task to the ModelWrapper instance to ensure it does not
# get garbage collected after the predict method completes and keeps
# running until the stream is exhausted.
self._background_tasks.add(task)

# Defer the release of the semaphore until the write_response_to_queue
# task completes.
semaphore_release_function = semaphore_manager.defer()
task.add_done_callback(lambda _: semaphore_release_function())
task.add_done_callback(self._background_tasks.discard)

async def _response_generator():
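The comment about garbage collection in __call__ refers to a general asyncio pitfall: the event loop holds only a weak reference to tasks, so a task created with asyncio.create_task can be collected mid-flight if nothing else references it. A standalone sketch of the pattern used above, with a hypothetical worker coroutine standing in for write_response_to_queue:

import asyncio

background_tasks: "set[asyncio.Task]" = set()

async def stream_worker(queue: asyncio.Queue):
    # Hypothetical worker: feed chunks into the queue and signal
    # completion with None, like write_response_to_queue in the diff.
    for i in range(5):
        await queue.put(str(i))
    await queue.put(None)

def start_worker(queue: asyncio.Queue) -> asyncio.Task:
    task = asyncio.create_task(stream_worker(queue))
    # Keep a strong reference so the task cannot be garbage collected
    # before it finishes.
    background_tasks.add(task)
    # Drop the reference once the task is done.
    task.add_done_callback(background_tasks.discard)
    return task
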
2 changes: 2 additions & 0 deletions truss/test_data/test_streaming_truss/model/model.py
@@ -1,3 +1,4 @@
import time
from typing import Any, Dict, List


@@ -15,6 +16,7 @@ def load(self):
def predict(self, model_input: Any) -> Dict[str, List]:
# Invoke model on model_input and calculate predictions here.
def inner():
time.sleep(2)
for i in range(5):
yield str(i)

38 changes: 38 additions & 0 deletions truss/test_data/test_streaming_truss_with_error/config.yaml
@@ -0,0 +1,38 @@
apply_library_patches: true
bundled_packages_dir: packages
data_dir: data
description: null
environment_variables: {}
examples_filename: examples.yaml
external_package_dirs: []
input_type: Any
live_reload: false
model_class_filename: model.py
model_class_name: Model
model_framework: custom
model_metadata: {}
model_module_dir: model
model_name: null
model_type: custom
python_version: py39
requirements: []
resources:
accelerator: null
cpu: 500m
memory: 512Mi
use_gpu: false
runtime:
predict_concurrency: 1
secrets: {}
spec_version: '2.0'
system_packages: []
train:
resources:
accelerator: null
cpu: 500m
memory: 512Mi
use_gpu: false
training_class_filename: train.py
training_class_name: Train
training_module_dir: train
variables: {}
23 changes: 23 additions & 0 deletions truss/test_data/test_streaming_truss_with_error/model/model.py
@@ -0,0 +1,23 @@
from typing import Any, Dict, List


class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._secrets = kwargs["secrets"]
self._model = None

def load(self):
# Load model here and assign to self._model.
pass

def predict(self, model_input: Any) -> Dict[str, List]:
def inner():
for i in range(5):
# Raise error partway through if throw_error is set
if i == 3 and model_input.get("throw_error"):
raise Exception("error")
yield str(i)

return inner()
73 changes: 66 additions & 7 deletions truss/tests/test_model_inference.py
@@ -1,6 +1,8 @@
import concurrent
import logging
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Thread

@@ -244,6 +246,42 @@ def test_async_streaming():
assert predict_non_stream_response.json() == "01234"


@pytest.mark.integration
def test_streaming_with_error():
with ensure_kill_all():
truss_root = Path(__file__).parent.parent.parent.resolve() / "truss"

truss_dir = truss_root / "test_data" / "test_streaming_truss_with_error"

tr = TrussHandle(truss_dir)

_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
truss_server_addr = "http://localhost:8090"
predict_url = f"{truss_server_addr}/v1/models/model:predict"

predict_error_response = requests.post(
predict_url, json={"throw_error": True}, stream=True, timeout=2
)

# In error cases, the response contains whatever the stream yielded before
# failing, in this case the first 3 items. We time out after 2 seconds to ensure
# that the stream finishes reading and releases the predict semaphore.
assert [
byte_string.decode()
for byte_string in predict_error_response.iter_content()
] == ["0", "1", "2"]

# Test that we are able to continue to make requests successfully
predict_non_error_response = requests.post(
predict_url, json={"throw_error": False}, stream=True, timeout=2
)

assert [
byte_string.decode()
for byte_string in predict_non_error_response.iter_content()
] == ["0", "1", "2", "3", "4"]


@pytest.mark.integration
def test_streaming_truss():
with ensure_kill_all():
@@ -254,17 +292,14 @@ def test_streaming_truss():
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)

truss_server_addr = "http://localhost:8090"
predict_url = f"{truss_server_addr}/v1/models/model:predict"

# A request for which response is not completely read
predict_response = requests.post(
f"{truss_server_addr}/v1/models/model:predict", json={}, stream=True
)
predict_response = requests.post(predict_url, json={}, stream=True)
# We just read the first part and leave it hanging here
next(predict_response.iter_content())

predict_response = requests.post(
f"{truss_server_addr}/v1/models/model:predict", json={}, stream=True
)
predict_response = requests.post(predict_url, json={}, stream=True)

assert predict_response.headers.get("transfer-encoding") == "chunked"
assert [
@@ -274,14 +309,38 @@

# When accept is set to application/json, the response is not streamed.
predict_non_stream_response = requests.post(
f"{truss_server_addr}/v1/models/model:predict",
predict_url,
json={},
stream=True,
headers={"accept": "application/json"},
)
assert "transfer-encoding" not in predict_non_stream_response.headers
assert predict_non_stream_response.json() == "01234"

# Test that concurrency works correctly. The streaming Truss has a configured
# concurrency of 1, so only one request can be in flight at a time. Each request
# takes 2 seconds, so with a timeout of 3 seconds we expect the first request to
# succeed and the second to time out.
#
# Note that with streamed requests, requests.post raises a ReadTimeout exception
# if `timeout` seconds have passed since the last data was received from the server.
def make_request(delay: int):
# For streamed responses, requests does not start receiving content from the
# server until `iter_content` is called, so we must call it here to trigger an
# actual timeout.
time.sleep(delay)
list(requests.post(predict_url, json={}, stream=True).iter_content())

with ThreadPoolExecutor() as e:
# We use concurrent.futures.wait instead of the timeout parameter on
# requests, since the requests timeout has a complex interaction with
# streaming.
first_request = e.submit(make_request, 0)
second_request = e.submit(make_request, 0.2)
futures = [first_request, second_request]
done, not_done = concurrent.futures.wait(futures, timeout=3)
assert first_request in done
assert second_request in not_done


@pytest.mark.integration
def test_slow_truss():
Expand Down
