For live reload, retry if model is not ready #213

Merged 1 commit on Feb 8, 2023

16 changes: 15 additions & 1 deletion truss/templates/control/control/endpoints.py
@@ -1,5 +1,6 @@
 import requests
 from flask import Blueprint, Response, current_app, jsonify, request
+from helpers.errors import ModelNotReady
 from requests.exceptions import ConnectionError
 from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed
 
@@ -23,7 +24,10 @@ def proxy(path):
 
     # Wait a bit for inference server to start
     for attempt in Retrying(
-        retry=retry_if_exception_type(ConnectionError),
+        retry=(
+            retry_if_exception_type(ConnectionError)
+            | retry_if_exception_type(ModelNotReady)
+        ),
         stop=stop_after_attempt(INFERENCE_SERVER_START_WAIT_SECS),
Collaborator:

Should we increase this timeout if we're also waiting for the model to be ready?

Collaborator (Author):

Good point. I think the tradeoff is that for some error cases it may increase the time to return an error. Let's start with this and then we can tune. This will not completely eliminate the issue; models can take any amount of time to load. In those cases it's fine to return the 503 error. The idea is that we retry a bit for the common simple cases.

         wait=wait_fixed(1),
     ):
@@ -36,6 +40,8 @@ def proxy(path):
                     cookies=request.cookies,
                     headers=request.headers,
                 )
+                if _is_model_not_ready(resp):
+                    raise ModelNotReady("Model has started running, but not ready yet.")
             except ConnectionError as exp:
                 # This check is a bit expensive so we don't do it before every request, we
                 # do it only if request fails with connection error. If the inference server
@@ -89,3 +95,11 @@ def has_partially_applied_patch():
 def stop_inference_server():
     current_app.config["inference_server_controller"].stop()
     return {"msg": "Inference server stopped successfully"}
+
+
+def _is_model_not_ready(resp) -> bool:
+    return (
+        resp.status_code == 503
+        and resp.content is not None
+        and "model is not ready" in resp.content.decode("utf-8")
+    )
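
To make the control flow concrete, here is a self-contained sketch of the retry pattern this PR wires up. The tenacity arguments and the `_is_model_not_ready` check mirror the diff above; the URL, the `requests.get` call shape, and the value of `INFERENCE_SERVER_START_WAIT_SECS` are illustrative assumptions, not code from this PR.

```python
# Self-contained sketch of the retry pattern above; URL, request shape, and
# the wait constant's value are assumptions, the rest mirrors the diff.
import requests
from requests.exceptions import ConnectionError
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed

INFERENCE_SERVER_START_WAIT_SECS = 60  # assumed value; defined elsewhere in the control server


class ModelNotReady(Exception):  # the PR subclasses the helpers' Error base instead
    """Model has started running, but not ready yet."""


def _is_model_not_ready(resp) -> bool:
    return (
        resp.status_code == 503
        and resp.content is not None
        and "model is not ready" in resp.content.decode("utf-8")
    )


def wait_for_inference_server(url: str) -> requests.Response:
    # `|` composes the two conditions into tenacity's retry_any: retry when the
    # connection is refused (server not up yet) OR when the server answers 503
    # "model is not ready" (server up, model still loading).
    for attempt in Retrying(
        retry=(
            retry_if_exception_type(ConnectionError)
            | retry_if_exception_type(ModelNotReady)
        ),
        stop=stop_after_attempt(INFERENCE_SERVER_START_WAIT_SECS),
        wait=wait_fixed(1),
    ):
        with attempt:
            resp = requests.get(url)
            if _is_model_not_ready(resp):
                raise ModelNotReady("Model has started running, but not ready yet.")
            return resp
```

With `wait_fixed(1)`, `stop_after_attempt(n)` bounds the total wait to roughly n seconds; once attempts are exhausted, tenacity raises `RetryError` (it would re-raise the last exception only with `reraise=True`), which matches the discussion above about still surfacing the 503 for genuinely slow models.
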
6 changes: 6 additions & 0 deletions truss/templates/control/control/helpers/errors.py
@@ -34,3 +34,9 @@ class InadmissiblePatch(PatchApplicatonError):
     """Patch does not apply to current state of Truss."""
 
     pass
+
+
+class ModelNotReady(Error):
+    """Model has started running, but not ready yet."""
+
+    pass
47 changes: 47 additions & 0 deletions truss/test_data/truss_container_fs/app/common/logging.py
@@ -0,0 +1,47 @@
+import logging
+import sys
+
+from pythonjsonlogger import jsonlogger
+
+LEVEL = logging.INFO
+
+JSON_LOG_HANDLER = logging.StreamHandler(stream=sys.stderr)
+JSON_LOG_HANDLER.set_name("json_logger_handler")
+JSON_LOG_HANDLER.setLevel(LEVEL)
+JSON_LOG_HANDLER.setFormatter(
+    jsonlogger.JsonFormatter("%(asctime)s %(levelname)s %(message)s")
+)
+
+
+class HealthCheckFilter(logging.Filter):
+    def filter(self, record: logging.LogRecord) -> bool:
+        # for any health check endpoints, let's skip logging
+        return (
+            record.getMessage().find("GET / ") == -1
+            and record.getMessage().find("GET /v1/models/model ") == -1
+        )
+
+
+def setup_logging():
+    loggers = [logging.getLogger()] + [
+        logging.getLogger(name) for name in logging.root.manager.loggerDict
+    ]
+
+    for logger in loggers:
+        logger.setLevel(LEVEL)
+        logger.propagate = False
+
+        setup = False
+
+        # let's not thrash the handlers unnecessarily
+        for handler in logger.handlers:
+            if handler.name == JSON_LOG_HANDLER.name:
+                setup = True
+
+        if not setup:
+            logger.handlers.clear()
+            logger.addHandler(JSON_LOG_HANDLER)
+
+        # some special handling for request logging
+        if logger.name == "uvicorn.access":
+            logger.addFilter(HealthCheckFilter())
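
A brief usage sketch for this module (not part of the PR); the `common.logging` import path is assumed from the file location:

```python
# Usage sketch, assuming the module is importable as common.logging
# (inferred from app/common/logging.py; not part of the PR).
import logging

from common.logging import setup_logging

setup_logging()

# Every logger now emits one JSON object per record on stderr:
logging.getLogger("model_wrapper").info("model loaded")
# e.g. {"asctime": "2023-02-08 ...", "levelname": "INFO", "message": "model loaded"}

# Access-log lines from "uvicorn.access" containing "GET / " or
# "GET /v1/models/model " are dropped by HealthCheckFilter, so
# health-check polling stays out of the logs.
```

Because `setup_logging` checks each logger for a handler named `json_logger_handler` before clearing anything, calling it more than once does not stack duplicate handlers.
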
14 changes: 10 additions & 4 deletions truss/test_data/truss_container_fs/app/common/serialization.py
@@ -3,10 +3,6 @@
 import json
 import uuid
 
-import msgpack
-import msgpack_numpy as mp_np
-import numpy as np
-
 
 # mostly cribbed from django.core.serializer.DjangoJSONEncoder
 def truss_msgpack_encoder(obj, chain=None):
@@ -61,6 +57,8 @@ def truss_msgpack_decoder(obj, chain=None):
 
 # this json object is JSONType + np.array + datetime
 def is_truss_serializable(obj):
+    import numpy as np
+
     # basic JSON types
     if isinstance(obj, (str, int, float, bool, type(None), dict, list)):
         return True
@@ -75,19 +73,27 @@
 
 
 def truss_msgpack_serialize(obj):
+    import msgpack
+    import msgpack_numpy as mp_np
+
     return msgpack.packb(
         obj, default=lambda x: truss_msgpack_encoder(x, chain=mp_np.encode)
     )
 
 
 def truss_msgpack_deserialize(obj):
+    import msgpack
+    import msgpack_numpy as mp_np
+
     return msgpack.unpackb(
         obj, object_hook=lambda x: truss_msgpack_decoder(x, chain=mp_np.decode)
    )
 
 
 class DeepNumpyEncoder(json.JSONEncoder):
     def default(self, obj):
+        import numpy as np
+
         if isinstance(obj, np.integer):
             return int(obj)
         elif isinstance(obj, np.floating):
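
The serialization change is about import cost rather than behavior: `msgpack`, `msgpack_numpy`, and `numpy` are now imported inside the functions that use them, so importing the module itself stays cheap and the heavy dependencies load only on first use. A round-trip sketch (the `common.serialization` import path is assumed, as above):

```python
# Round-trip sketch; the import path is assumed from the file location.
# numpy/msgpack are only imported once these calls actually run.
import numpy as np

from common.serialization import truss_msgpack_deserialize, truss_msgpack_serialize

payload = {"label": "cat", "scores": np.array([0.1, 0.9])}

packed = truss_msgpack_serialize(payload)      # bytes; the array is encoded via msgpack_numpy
restored = truss_msgpack_deserialize(packed)   # the array comes back as an np.ndarray

assert restored["label"] == "cat"
assert np.allclose(restored["scores"], payload["scores"])
```
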