Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better secret error handling. #531

Merged
merged 4 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion truss/templates/shared/secrets_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Dict


class SecretNotFound(Exception):
pass


class SecretsResolver:
SECRETS_MOUNT_DIR = "/secrets"
SECRET_ENV_VAR_PREFIX = "TRUSS_SECRET_"
Expand Down Expand Up @@ -34,7 +38,18 @@ def __init__(self, base_secrets: Dict[str, str]):
self._base_secrets = base_secrets

def __getitem__(self, key: str) -> str:
return SecretsResolver._resolve_secret(key, self._base_secrets[key])
if key not in self._base_secrets:
# Note this is the case where the secrets are not specified in
# config.yaml
raise SecretNotFound(f"Secret '{key}' not specified in the config.")

found_secret = SecretsResolver._resolve_secret(key, self._base_secrets[key])
if not found_secret:
raise SecretNotFound(
f"Secret '{key} not found. Please check available secrets."
)

return found_secret

def __iter__(self):
raise NotImplementedError(
Expand Down
95 changes: 95 additions & 0 deletions truss/tests/test_model_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import concurrent
import inspect
import logging
import tempfile
import textwrap
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
Expand All @@ -11,6 +13,7 @@
import requests
from requests.exceptions import RequestException
from truss.constants import PYTORCH
from truss.local.local_config_handler import LocalConfigHandler
from truss.model_frameworks import SKLearn
from truss.model_inference import (
infer_model_information,
Expand All @@ -23,6 +26,19 @@
logger = logging.getLogger(__name__)


def _create_truss(truss_dir: Path, config_contents: str, model_contents: str):
truss_dir.mkdir(exist_ok=True) # Ensure the 'truss' directory exists
truss_model_dir = truss_dir / "model"
truss_model_dir.mkdir(parents=True, exist_ok=True)

config_file = truss_dir / "config.yaml"
model_file = truss_model_dir / "model.py"
with open(config_file, "w", encoding="utf-8") as file:
file.write(config_contents)
with open(model_file, "w", encoding="utf-8") as file:
file.write(model_contents)


class PropagatingThread(Thread):
"""
PropagatingThread allows us to run threads and keep track of exceptions
Expand Down Expand Up @@ -342,6 +358,85 @@ def make_request(delay: int):
assert second_request in not_done


@pytest.mark.integration
def test_secrets_truss():
class Model:
def __init__(self, **kwargs):
self._secrets = kwargs["secrets"]

def predict(self, request):
return self._secrets["secret"]

config = """model_name: secrets-truss
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
secrets:
secret: null
"""

config_with_no_secret = """model_name: secrets-truss
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
"""

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
truss_dir = Path(tmp_work_dir, "truss")

_create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))

tr = TrussHandle(truss_dir)
LocalConfigHandler.set_secret("secret", "secret_value")
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
truss_server_addr = "http://localhost:8090"
full_url = f"{truss_server_addr}/v1/models/model:predict"

response = requests.post(full_url, json={})
assert response.json() == "secret_value"

_create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
# Case where the secret is not specified in the config
truss_dir = Path(tmp_work_dir, "truss")

_create_truss(
truss_dir, config_with_no_secret, textwrap.dedent(inspect.getsource(Model))
)
tr = TrussHandle(truss_dir)
LocalConfigHandler.set_secret("secret", "secret_value")
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
truss_server_addr = "http://localhost:8090"
full_url = f"{truss_server_addr}/v1/models/model:predict"

response = requests.post(full_url, json={})

assert "error" in response.json()
assert "not specified in the config" in response.json()["error"]["traceback"]

with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
# Case where the secret is not specified in the config
truss_dir = Path(tmp_work_dir, "truss")

_create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model)))
tr = TrussHandle(truss_dir)
LocalConfigHandler.remove_secret("secret")
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
truss_server_addr = "http://localhost:8090"
full_url = f"{truss_server_addr}/v1/models/model:predict"

response = requests.post(full_url, json={})

assert "error" in response.json()
assert (
"not found. Please check available secrets."
in response.json()["error"]["traceback"]
)


@pytest.mark.integration
def test_slow_truss():
with ensure_kill_all():
Expand Down