diff --git a/truss/templates/shared/secrets_resolver.py b/truss/templates/shared/secrets_resolver.py index 8245acf98..875d73aa4 100644 --- a/truss/templates/shared/secrets_resolver.py +++ b/truss/templates/shared/secrets_resolver.py @@ -4,6 +4,10 @@ from typing import Dict +class SecretNotFound(Exception): + pass + + class SecretsResolver: SECRETS_MOUNT_DIR = "/secrets" SECRET_ENV_VAR_PREFIX = "TRUSS_SECRET_" @@ -34,7 +38,18 @@ def __init__(self, base_secrets: Dict[str, str]): self._base_secrets = base_secrets def __getitem__(self, key: str) -> str: - return SecretsResolver._resolve_secret(key, self._base_secrets[key]) + if key not in self._base_secrets: + # Note this is the case where the secrets are not specified in + # config.yaml + raise SecretNotFound(f"Secret '{key}' not specified in the config.") + + found_secret = SecretsResolver._resolve_secret(key, self._base_secrets[key]) + if not found_secret: + raise SecretNotFound( + f"Secret '{key} not found. Please check available secrets." + ) + + return found_secret def __iter__(self): raise NotImplementedError( diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index 90c974474..5b5c3e570 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1,6 +1,8 @@ import concurrent +import inspect import logging import tempfile +import textwrap import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path @@ -11,6 +13,7 @@ import requests from requests.exceptions import RequestException from truss.constants import PYTORCH +from truss.local.local_config_handler import LocalConfigHandler from truss.model_frameworks import SKLearn from truss.model_inference import ( infer_model_information, @@ -23,6 +26,19 @@ logger = logging.getLogger(__name__) +def _create_truss(truss_dir: Path, config_contents: str, model_contents: str): + truss_dir.mkdir(exist_ok=True) # Ensure the 'truss' directory exists + truss_model_dir = truss_dir / "model" + truss_model_dir.mkdir(parents=True, exist_ok=True) + + config_file = truss_dir / "config.yaml" + model_file = truss_model_dir / "model.py" + with open(config_file, "w", encoding="utf-8") as file: + file.write(config_contents) + with open(model_file, "w", encoding="utf-8") as file: + file.write(model_contents) + + class PropagatingThread(Thread): """ PropagatingThread allows us to run threads and keep track of exceptions @@ -342,6 +358,85 @@ def make_request(delay: int): assert second_request in not_done +@pytest.mark.integration +def test_secrets_truss(): + class Model: + def __init__(self, **kwargs): + self._secrets = kwargs["secrets"] + + def predict(self, request): + return self._secrets["secret"] + + config = """model_name: secrets-truss +cpu: "3" +memory: 14Gi +use_gpu: true +accelerator: A10G +secrets: + secret: null + """ + + config_with_no_secret = """model_name: secrets-truss +cpu: "3" +memory: 14Gi +use_gpu: true +accelerator: A10G + """ + + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model))) + + tr = TrussHandle(truss_dir) + LocalConfigHandler.set_secret("secret", "secret_value") + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + response = requests.post(full_url, json={}) + assert response.json() == "secret_value" + + _create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model))) + + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + # Case where the secret is not specified in the config + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss( + truss_dir, config_with_no_secret, textwrap.dedent(inspect.getsource(Model)) + ) + tr = TrussHandle(truss_dir) + LocalConfigHandler.set_secret("secret", "secret_value") + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + response = requests.post(full_url, json={}) + + assert "error" in response.json() + assert "not specified in the config" in response.json()["error"]["traceback"] + + with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir: + # Case where the secret is not specified in the config + truss_dir = Path(tmp_work_dir, "truss") + + _create_truss(truss_dir, config, textwrap.dedent(inspect.getsource(Model))) + tr = TrussHandle(truss_dir) + LocalConfigHandler.remove_secret("secret") + _ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True) + truss_server_addr = "http://localhost:8090" + full_url = f"{truss_server_addr}/v1/models/model:predict" + + response = requests.post(full_url, json={}) + + assert "error" in response.json() + assert ( + "not found. Please check available secrets." + in response.json()["error"]["traceback"] + ) + + @pytest.mark.integration def test_slow_truss(): with ensure_kill_all():