Skip to content

Commit

Permalink
Add setup_environment to truss (#1188)
Browse files Browse the repository at this point in the history
* Bump Version for CTX builder

* skip long test

* add aiofiles to requirements

* increase ctx builder version

* use async task instead

* increase ctx builder

* pass event_loop down

* allow event_loop to be optional

* only thread self._model.load

* revert test_server changes

* add integration test for setup_environment

* add truss for setup_environment integration test

* revert model wrapper test changes

* revert load changes and check model wrapper status instead + tests

* clean up resolver

* fix resolver test

* call setup env before load

* remove print statements

* unit tests + fixes

* sid CR

* fix polling structure

* CR

* bump rc version

* support for clearing out env

* update ctx builder and revert version of types-aiofiles

* test for Nones in configmap file

* better error handling

* update ctx builder version

* test fix + better error msg

* move sleep

* add cleanup to env integration tests

* add a todo for async setup env + more logging
  • Loading branch information
spal1 authored Oct 22, 2024
1 parent 0556f2d commit 5906d60
Show file tree
Hide file tree
Showing 12 changed files with 618 additions and 233 deletions.
495 changes: 275 additions & 220 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.44"
version = "0.9.45rc009"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -27,6 +27,7 @@ packages = [
"Baseten" = "https://baseten.co"

[tool.poetry.dependencies]
aiofiles = "^24.1.0"
blake3 = "^0.3.3"
boto3 = "^1.34.85"
fastapi = ">=0.109.1"
Expand Down Expand Up @@ -96,6 +97,7 @@ pytest = "7.2.0"
pytest-cov = "^3.0.0"
types-PyYAML = "^6.0.12.12"
types-setuptools = "^69.0.0.0"
types-aiofiles = "^24.1.0.20240626"

[tool.poetry.scripts]
truss = 'truss.cli:truss_cli'
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/truss_chains/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def override_chainlet_to_service_metadata(
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
):
# Override predict_urls in chainlet_to_service ServiceDescriptors if dynamic_chainlet_config exists
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value(
dynamic_chainlet_config_str = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
if dynamic_chainlet_config_str:
Expand Down
12 changes: 12 additions & 0 deletions truss/local/local_config_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,18 @@ def bptr_data_resolution_dir_path():
bptr_data_dir.mkdir(exist_ok=True, parents=True)
return bptr_data_dir

@staticmethod
def dynamic_config_path():
    """Return the local dynamic-config directory, creating it if needed.

    The directory is ``b10_dynamic_config`` under
    ``LocalConfigHandler.TRUSS_CONFIG_DIR``; missing parents are created too.
    """
    config_dir = LocalConfigHandler.TRUSS_CONFIG_DIR / "b10_dynamic_config"
    config_dir.mkdir(parents=True, exist_ok=True)
    return config_dir

@staticmethod
def set_dynamic_config(key: str, value: str):
    """Store ``value`` in the file named ``key`` in the dynamic-config dir.

    The file is opened for writing, so previous content is replaced.
    """
    target_file = LocalConfigHandler.dynamic_config_path() / key
    target_file.write_text(value)

@staticmethod
def _signatures_dir_path():
return LocalConfigHandler.TRUSS_CONFIG_DIR / "signatures"
Expand Down
90 changes: 89 additions & 1 deletion truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib
import importlib.util
import inspect
import json
import logging
import os
import pathlib
Expand Down Expand Up @@ -35,7 +36,7 @@
from common.schema import TrussSchema
from opentelemetry import trace
from pydantic import BaseModel
from shared import serialization
from shared import dynamic_config_resolver, serialization
from shared.lazy_data_resolver import LazyDataResolver
from shared.secrets_resolver import SecretsResolver

Expand All @@ -53,6 +54,7 @@
EXTENSION_CLASS_NAME = "Extension"
EXTENSION_FILE_NAME = "extension"
TRT_LLM_EXTENSION_NAME = "trt_llm"
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30


@asynccontextmanager
Expand Down Expand Up @@ -191,6 +193,7 @@ class ModelDescriptor:
predict: MethodDescriptor
postprocess: Optional[MethodDescriptor]
truss_schema: Optional[TrussSchema]
setup_environment: Optional[MethodDescriptor]

@cached_property
def skip_input_parsing(self) -> bool:
Expand Down Expand Up @@ -243,11 +246,19 @@ def from_model(cls, model) -> "ModelDescriptor":
else:
return_annotation = inspect.signature(model.predict).return_annotation

if hasattr(model, "setup_environment"):
setup_environment = MethodDescriptor.from_method(
model.setup_environment, "setup_environment"
)
else:
setup_environment = None

return cls(
preprocess=preprocess,
predict=predict,
postprocess=postprocess,
truss_schema=TrussSchema.from_signature(parameters, return_annotation),
setup_environment=setup_environment,
)


Expand All @@ -259,6 +270,8 @@ class ModelWrapper:
_logger: logging.Logger
_status: "ModelWrapper.Status"
_predict_semaphore: Semaphore
_poll_for_environment_updates_task: Optional[asyncio.Task]
_environment: Optional[dict]

class Status(Enum):
NOT_READY = 0
Expand All @@ -280,6 +293,8 @@ def __init__(self, config: Dict, tracer: sdk_trace.Tracer):
"predict_concurrency", DEFAULT_PREDICT_CONCURRENCY
)
)
self._poll_for_environment_updates_task = None
self._environment = None

@property
def _model(self) -> Any:
Expand Down Expand Up @@ -419,6 +434,9 @@ def _load_impl(self):

self._maybe_model_descriptor = ModelDescriptor.from_model(self._model)

if self._maybe_model_descriptor.setup_environment:
self._initialize_environment_before_load()

if hasattr(self._model, "load"):
retry(
self._model.load,
Expand All @@ -428,6 +446,76 @@ def _load_impl(self):
gap_seconds=1.0,
)

def setup_polling_for_environment_updates(self):
    """Schedule ``poll_for_environment_updates`` as an asyncio task.

    The task handle is stored on ``_poll_for_environment_updates_task``.
    """
    polling_coroutine = self.poll_for_environment_updates()
    self._poll_for_environment_updates_task = asyncio.create_task(
        polling_coroutine
    )

def _initialize_environment_before_load(self):
    """Run the model's ``setup_environment`` with the current environment config.

    Reads the environment dynamic config synchronously. When a non-empty
    value is found it is parsed as JSON, passed to
    ``self._model.setup_environment`` and remembered in ``self._environment``.
    An absent or empty value leaves everything untouched.
    """
    raw_environment = dynamic_config_resolver.get_dynamic_config_value_sync(
        dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    )
    if not raw_environment:
        return

    parsed_environment = json.loads(raw_environment)
    self._logger.info(
        f"Executing model.setup_environment with environment: {parsed_environment}"
    )
    # TODO: Support calling an async setup_environment() here once we support async load()
    self._model.setup_environment(parsed_environment)
    self._environment = parsed_environment

async def setup_environment(self, environment: Optional[dict]):
    """Invoke the model's ``setup_environment`` hook with ``environment``.

    Does nothing (returns ``None``) when the model descriptor has no
    ``setup_environment`` entry. An async hook is awaited directly; a sync
    hook is run through ``to_thread.run_sync``.

    Args:
        environment: The environment payload handed to the model hook.

    Returns:
        Whatever the model's ``setup_environment`` returns.
    """
    hook_descriptor = self.model_descriptor.setup_environment
    if hook_descriptor:
        self._logger.info(
            f"Executing model.setup_environment with new environment: {environment}"
        )
        if not hook_descriptor.is_async:
            return await to_thread.run_sync(
                self._model.setup_environment, environment
            )
        return await self._model.setup_environment(environment)
    return None

async def poll_for_environment_updates(self) -> None:
    """Poll the environment dynamic-config file and apply changes.

    Each iteration sleeps ``POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS`` and
    then:

    * skips the iteration while the wrapper is not ready;
    * stops polling for good when the model has no ``setup_environment``;
    * re-reads the config file when its modification time differs from the
      last one recorded, and calls ``self.setup_environment`` only if the
      parsed JSON differs from ``self._environment``.

    Any exception raised while reading or applying the environment is
    logged and polling continues.
    """
    environment_config_filename = (
        dynamic_config_resolver.get_dynamic_config_file_path(
            dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
        )
    )
    last_modified_time = None

    while True:
        # Give control back to the event loop while waiting for environment updates
        await asyncio.sleep(POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS)

        # Wait for load to finish before checking for environment updates
        if not self.ready:
            continue

        # Nothing to poll for when the model has no setup_environment hook
        if not self.model_descriptor.setup_environment:
            break

        if not environment_config_filename.exists():
            continue

        try:
            current_mtime = os.path.getmtime(environment_config_filename)
            if last_modified_time and last_modified_time == current_mtime:
                # File untouched since the last successful read
                continue

            environment_str = (
                await dynamic_config_resolver.get_dynamic_config_value_async(
                    dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
                )
            )
            if not environment_str:
                continue

            # Recorded before parsing/applying, so a failing payload is not
            # re-processed until the file changes again.
            last_modified_time = current_mtime
            environment_json = json.loads(environment_str)
            # Avoid rerunning `setup_environment` with the same environment
            if self._environment != environment_json:
                await self.setup_environment(environment_json)
                self._environment = environment_json
        except Exception as e:
            self._logger.exception(
                "Exception while setting up environment: " + str(e),
                exc_info=errors.filter_traceback(self._model_file_name),
            )

async def preprocess(
self,
inputs: serialization.InputType,
Expand Down
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ pyyaml==6.0.0
requests==2.31.0
uvicorn==0.24.0
uvloop==0.19.0
aiofiles==24.1.0
1 change: 1 addition & 0 deletions truss/templates/server/truss_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ def on_startup(self):
if self._setup_json_logger:
setup_logging()
self._model.start_load_thread()
self._model.setup_polling_for_environment_updates()

def create_application(self):
app = FastAPI(
Expand Down
23 changes: 19 additions & 4 deletions truss/templates/shared/dynamic_config_resolver.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,28 @@
from pathlib import Path
from typing import Optional

import aiofiles

DYNAMIC_CONFIG_MOUNT_DIR = "/etc/b10_dynamic_config"
ENVIRONMENT_DYNAMIC_CONFIG_KEY = "environment"


def get_dynamic_config_value(key: str) -> Optional[str]:
def get_dynamic_config_value_sync(key: str) -> Optional[str]:
dynamic_config_path = Path(DYNAMIC_CONFIG_MOUNT_DIR) / key
if dynamic_config_path.exists() and dynamic_config_path.is_file():
if dynamic_config_path.exists():
with dynamic_config_path.open() as dynamic_config_file:
dynamic_config_value = dynamic_config_file.read()
return dynamic_config_value
return dynamic_config_file.read()
return None


def get_dynamic_config_file_path(key: str) -> Path:
    """Return the path of the dynamic-config file for ``key``.

    The path is ``DYNAMIC_CONFIG_MOUNT_DIR / key``; the file is not
    checked for existence.
    """
    return Path(DYNAMIC_CONFIG_MOUNT_DIR) / key


async def get_dynamic_config_value_async(key: str) -> Optional[str]:
    """Asynchronously read the dynamic config value stored under ``key``.

    Args:
        key: Name of the config entry; resolved to a file path through
            ``get_dynamic_config_file_path``.

    Returns:
        The full text of the file, or ``None`` when the file does not exist.
    """
    dynamic_config_path = get_dynamic_config_file_path(key)
    # Open directly rather than checking ``exists()`` first: the file can
    # disappear between the check and the open, and the extra check is a
    # blocking stat call on the event loop. ``NotADirectoryError`` is caught
    # as well because ``Path.exists()`` also reports such paths as missing.
    try:
        async with aiofiles.open(dynamic_config_path, "r") as dynamic_config_file:
            return await dynamic_config_file.read()
    except (FileNotFoundError, NotADirectoryError):
        return None
2 changes: 1 addition & 1 deletion truss/test_data/model_load_failure_test/config.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
model_name: Test Loaf Failure
model_name: Test Load Failure
python_version: py39
102 changes: 97 additions & 5 deletions truss/tests/templates/core/server/test_dynamic_config_resolver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import json

import aiofiles
import pytest
from truss.templates.shared.dynamic_config_resolver import get_dynamic_config_value
from truss.templates.shared import dynamic_config_resolver

from truss_chains import definitions

Expand All @@ -18,17 +19,108 @@
"",
],
)
def test_get_dynamic_config_value(config, tmp_path, dynamic_config_mount_dir):
def test_get_dynamic_chainlet_config_value_sync(
config, tmp_path, dynamic_config_mount_dir
):
with (tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY).open("w") as f:
f.write(json.dumps(config))
chainlet_service_config = get_dynamic_config_value(
chainlet_service_config = dynamic_config_resolver.get_dynamic_config_value_sync(
definitions.DYNAMIC_CHAINLET_CONFIG_KEY
)
assert json.loads(chainlet_service_config) == config


def test_get_missing_config_value(dynamic_config_mount_dir):
chainlet_service_config = get_dynamic_config_value(
@pytest.mark.parametrize(
    "config",
    [
        {"environment_name": "production", "foo": "bar"},
        {},
        "",
        None,
    ],
)
def test_get_dynamic_config_environment_value_sync(
    config, tmp_path, dynamic_config_mount_dir
):
    """The sync resolver returns the JSON written for the environment key.

    Presumably the ``dynamic_config_mount_dir`` fixture points the resolver
    at ``tmp_path`` (fixture is defined outside this file; confirm).
    """
    environment_key = dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    (tmp_path / environment_key).write_text(json.dumps(config))

    environment_str = dynamic_config_resolver.get_dynamic_config_value_sync(
        dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    )

    assert json.loads(environment_str) == config


def test_get_missing_config_value_sync(dynamic_config_mount_dir):
    """The sync resolver yields a falsy value when no config file was written."""
    missing_key = definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    resolved = dynamic_config_resolver.get_dynamic_config_value_sync(missing_key)
    assert not resolved


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "config",
    [
        {
            "RandInt": {
                "predict_url": "https://model-id.api.baseten.co/deployment/deployment-id/predict"
            }
        },
        {},
        "",
    ],
)
async def test_get_dynamic_chainlet_config_value_async(
    config, tmp_path, dynamic_config_mount_dir
):
    """The async resolver returns the JSON written for the chainlet config key."""
    config_file = tmp_path / definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    async with aiofiles.open(config_file, "w") as handle:
        await handle.write(json.dumps(config))

    raw_value = await dynamic_config_resolver.get_dynamic_config_value_async(
        definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    )

    assert json.loads(raw_value) == config


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "config",
    [
        {"environment_name": "production", "foo": "bar"},
        {},
        "",
        None,
    ],
)
async def test_get_dynamic_config_environment_value_async(
    config, tmp_path, dynamic_config_mount_dir
):
    """The async resolver returns the JSON written for the environment key."""
    environment_file = (
        tmp_path / dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    )
    async with aiofiles.open(environment_file, "w") as handle:
        await handle.write(json.dumps(config))

    raw_value = await dynamic_config_resolver.get_dynamic_config_value_async(
        dynamic_config_resolver.ENVIRONMENT_DYNAMIC_CONFIG_KEY
    )

    assert json.loads(raw_value) == config


@pytest.mark.asyncio
async def test_get_missing_config_value_async(dynamic_config_mount_dir):
    """The async resolver yields a falsy value when no config file was written."""
    missing_key = definitions.DYNAMIC_CHAINLET_CONFIG_KEY
    resolved = await dynamic_config_resolver.get_dynamic_config_value_async(
        missing_key
    )
    assert not resolved
Loading

0 comments on commit 5906d60

Please sign in to comment.