Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fail fast when LitAPI.setup has error #356

Merged
merged 7 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, PickleableHTTPException
from litserve.utils import LitAPIStatus, PickleableHTTPException, WorkerSetupStatus

mp.allow_connection_pickling()

Expand Down Expand Up @@ -399,18 +399,23 @@ def inference_worker(
max_batch_size: int,
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool],
workers_setup_status: Dict[int, str],
aniketmaurya marked this conversation as resolved.
Show resolved Hide resolved
callback_runner: CallbackRunner,
):
callback_runner.trigger_event(EventTypes.BEFORE_SETUP, lit_api=lit_api)
lit_api.setup(device)
try:
lit_api.setup(device)
except Exception:
logger.exception(f"Error setting up worker {worker_id}.")
workers_setup_status[worker_id] = WorkerSetupStatus.ERROR
return
lit_api.device = device
callback_runner.trigger_event(EventTypes.AFTER_SETUP, lit_api=lit_api)

print(f"Setup complete for worker {worker_id}.")
logger.info(f"Setup complete for worker {worker_id}.")

if workers_setup_status:
workers_setup_status[worker_id] = True
workers_setup_status[worker_id] = WorkerSetupStatus.READY

if lit_spec:
logging.info(f"LitServe will use {lit_spec.__class__.__name__} spec")
Expand Down
22 changes: 11 additions & 11 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,10 @@
from litserve.python_client import client_template
from litserve.specs import OpenAISpec
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus, call_after_stream
from litserve.utils import LitAPIStatus, WorkerSetupStatus, call_after_stream

mp.allow_connection_pickling()

try:
import uvloop

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

except ImportError:
print("uvloop is not installed. Falling back to the default asyncio event loop.")

logger = logging.getLogger(__name__)

# if defined, it will require clients to auth with X-API-Key in the header
Expand Down Expand Up @@ -233,7 +225,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
if len(device) == 1:
device = device[0]

self.workers_setup_status[worker_id] = False
self.workers_setup_status[worker_id] = WorkerSetupStatus.STARTING

ctx = mp.get_context("spawn")
process = ctx.Process(
Expand Down Expand Up @@ -328,7 +320,7 @@ async def index(request: Request) -> Response:
async def health(request: Request) -> Response:
nonlocal workers_ready
if not workers_ready:
workers_ready = all(self.workers_setup_status.values())
workers_ready = all(v == WorkerSetupStatus.READY for v in self.workers_setup_status.values())

if workers_ready:
return Response(content="ok", status_code=200)
Expand Down Expand Up @@ -436,6 +428,13 @@ def generate_client_file(port: Union[str, int] = 8000):
except Exception as e:
logger.exception(f"Error copying file: {e}")

def verify_worker_status(self):
    """Block until at least one inference worker has finished setup.

    Polls ``self.workers_setup_status`` every 50 ms and returns as soon as
    any worker reports ``WorkerSetupStatus.READY``.

    Raises:
        RuntimeError: If, before any worker is ready, at least one worker
            reports ``WorkerSetupStatus.ERROR``.
    """
    # NOTE(review): there is no timeout here; if no worker ever reports READY
    # or ERROR (or the mapping is empty) this loop never exits.
    while True:
        # Take one snapshot per poll so the READY and ERROR checks look at the
        # same state. The mapping is presumably a proxy shared with the worker
        # processes, in which case this also halves the round trips -- confirm.
        statuses = list(self.workers_setup_status.values())
        if WorkerSetupStatus.READY in statuses:
            break
        if WorkerSetupStatus.ERROR in statuses:
            raise RuntimeError("One or more workers failed to start. Shutting down LitServe")
        time.sleep(0.05)
    logger.debug("One or more workers are ready to serve requests")

def run(
self,
host: str = "0.0.0.0",
Expand Down Expand Up @@ -481,6 +480,7 @@ def run(

manager, litserve_workers = self.launch_inference_worker(num_api_servers)

self.verify_worker_status()
try:
servers = self._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs)
print(f"Swagger UI is available at http://0.0.0.0:{port}/docs")
Expand Down
9 changes: 9 additions & 0 deletions src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import dataclasses
import logging
import pickle
from contextlib import contextmanager
Expand Down Expand Up @@ -78,3 +79,11 @@ async def call_after_stream(streamer: AsyncIterator, callback, *args, **kwargs):
logger.exception(f"Error in streamer: {e}")
finally:
callback(*args, **kwargs)


@dataclasses.dataclass
class WorkerSetupStatus:
    """String constants for the setup state of an inference worker.

    The values are plain ``str`` objects, so they can be stored in a status
    mapping and compared with ``==``.
    """

    # Worker has been registered but has not finished setup yet.
    STARTING: str = "starting"
    # Worker completed setup successfully.
    READY: str = "ready"
    # Worker setup raised an exception.
    ERROR: str = "error"
    # NOTE(review): presumably marks a worker that has exited; confirm against callers.
    FINISHED: str = "finished"
15 changes: 15 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def test_start_server(mock_uvicon):
def test_server_run_with_api_server_worker_type(mock_uvicorn):
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api, devices=1)
server.verify_worker_status = MagicMock()
with pytest.raises(ValueError, match=r"Must be 'process' or 'thread'"):
server.run(api_server_worker_type="invalid")

Expand Down Expand Up @@ -247,6 +248,7 @@ def test_server_run_with_api_server_worker_type(mock_uvicorn):
def test_server_run_windows(mock_uvicorn):
api = ls.test_examples.SimpleLitAPI()
server = ls.LitServer(api)
server.verify_worker_status = MagicMock()
server.launch_inference_worker = MagicMock(return_value=[MagicMock(), [MagicMock()]])
server._start_server = MagicMock()

Expand All @@ -258,6 +260,7 @@ def test_server_run_windows(mock_uvicorn):

def test_server_terminate():
server = LitServer(SimpleLitAPI())
server.verify_worker_status = MagicMock()
mock_manager = MagicMock()

with patch("litserve.server.LitServer._start_server", side_effect=Exception("mocked error")) as mock_start, patch(
Expand Down Expand Up @@ -392,3 +395,15 @@ def test_generate_client_file(tmp_path, monkeypatch):
LitServer.generate_client_file(8000)
with open(tmp_path / "client.py") as fr:
assert expected in fr.read(), "Shouldn't replace existing client.py"


class FailFastAPI(ls.test_examples.SimpleLitAPI):
    """Test API whose ``setup`` always fails, to exercise the fail-fast path."""

    def setup(self, device):
        """Raise ``ValueError`` unconditionally instead of setting anything up."""
        failure_message = "setup failed"
        raise ValueError(failure_message)


def test_workers_setup_status():
    """``LitServer.run`` raises ``RuntimeError`` when the API's ``setup`` fails."""
    failing_server = LitServer(FailFastAPI(), devices=1)
    expected_message = "One or more workers failed to start. Shutting down LitServe"
    with pytest.raises(RuntimeError, match=expected_message):
        failing_server.run()
Loading