Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 5 additions & 73 deletions llama_stack/distribution/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import inspect
import json
import os
import re
import signal
import sys
import traceback
Expand Down Expand Up @@ -41,7 +40,11 @@
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import construct_stack
from llama_stack.distribution.stack import (
construct_stack,
replace_env_vars,
validate_env_pair,
)

from .endpoints import get_all_api_endpoints

Expand Down Expand Up @@ -271,77 +274,6 @@ async def endpoint(request: Request, **kwargs):
return endpoint


class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)


def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result

elif isinstance(config, list):
result = []
for i, v in enumerate(config):
try:
result.append(replace_env_vars(v, f"{path}[{i}]"))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result

elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"

def get_env_var(match):
env_var = match.group(1)
default_val = match.group(2)

value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
else:
value = default_val

# expand "~" from the values
return os.path.expanduser(value)

try:
return re.sub(pattern, get_env_var, config)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None

return config


def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
key, value = env_pair.split("=", 1)
key = key.strip()
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
return key, value
except ValueError as e:
raise ValueError(
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
) from e


def main():
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
Expand Down
90 changes: 90 additions & 0 deletions llama_stack/distribution/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from pathlib import Path
from typing import Any, Dict

import pkg_resources
import yaml

from termcolor import colored

from llama_models.llama3.api.datatypes import * # noqa: F403
Expand Down Expand Up @@ -92,6 +97,77 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):
print("")


class EnvVarError(Exception):
def __init__(self, var_name: str, path: str = ""):
self.var_name = var_name
self.path = path
super().__init__(
f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}"
)


def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict):
result = {}
for k, v in config.items():
try:
result[k] = replace_env_vars(v, f"{path}.{k}" if path else k)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result

elif isinstance(config, list):
result = []
for i, v in enumerate(config):
try:
result.append(replace_env_vars(v, f"{path}[{i}]"))
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None
return result

elif isinstance(config, str):
pattern = r"\${env\.([A-Z0-9_]+)(?::([^}]*))?}"

def get_env_var(match):
env_var = match.group(1)
default_val = match.group(2)

value = os.environ.get(env_var)
if not value:
if default_val is None:
raise EnvVarError(env_var, path)
else:
value = default_val

# expand "~" from the values
return os.path.expanduser(value)

try:
return re.sub(pattern, get_env_var, config)
except EnvVarError as e:
raise EnvVarError(e.var_name, e.path) from None

return config


def validate_env_pair(env_pair: str) -> tuple[str, str]:
"""Validate and split an environment variable key-value pair."""
try:
key, value = env_pair.split("=", 1)
key = key.strip()
if not key:
raise ValueError(f"Empty key in environment variable pair: {env_pair}")
if not all(c.isalnum() or c == "_" for c in key):
raise ValueError(
f"Key must contain only alphanumeric characters and underscores: {key}"
)
return key, value
except ValueError as e:
raise ValueError(
f"Invalid environment variable format '{env_pair}': {str(e)}. Expected format: KEY=value"
) from e


# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
Expand All @@ -105,3 +181,17 @@ async def construct_stack(
)
await register_resources(run_config, impls)
return impls


def get_stack_run_config_from_template(template: str) -> StackRunConfig:
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template}/run.yaml"
)

if not Path(template_path).exists():
raise ValueError(f"Template '{template}' not found at {template_path}")

with open(template_path) as f:
run_config = yaml.safe_load(f)

return StackRunConfig(**replace_env_vars(run_config))