Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 47 additions & 21 deletions llama_stack/distribution/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ class ApiInput(BaseModel):
provider: str


def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)

# extend package dependencies based on providers spec
def get_provider_dependencies(
config_providers: Dict[str, List[Provider]]
) -> tuple[list[str], list[str]]:
"""Get normal and special dependencies from provider configuration."""
all_providers = get_provider_registry()
for (
api_str,
provider_or_providers,
) in build_config.distribution_spec.providers.items():
deps = []

for api_str, provider_or_providers in config_providers.items():
providers_for_api = all_providers[Api(api_str)]

providers = (
Expand All @@ -69,25 +65,55 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
)

for provider in providers:
if provider not in providers_for_api:
# Providers from BuildConfig and RunConfig are subtly different – not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)

if provider_type not in providers_for_api:
raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`"
)

provider_spec = providers_for_api[provider]
package_deps.pip_packages.extend(provider_spec.pip_packages)
provider_spec = providers_for_api[provider_type]
deps.extend(provider_spec.pip_packages)
if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image")

normal_deps = []
special_deps = []
deps = []
for package in package_deps.pip_packages:
for package in deps:
if "--no-deps" in package or "--index-url" in package:
special_deps.append(package)
else:
deps.append(package)
deps = list(set(deps))
special_deps = list(set(special_deps))
normal_deps.append(package)

return list(set(normal_deps)), list(set(special_deps))


def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)

print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
print()


def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps = Dependencies(
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim",
pip_packages=SERVER_DEPENDENCIES,
)

# extend package dependencies based on providers spec
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
package_deps.pip_packages.extend(normal_deps)
package_deps.pip_packages.extend(special_deps)

if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename(
Expand All @@ -99,7 +125,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
package_deps.docker_image,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps),
" ".join(normal_deps),
]
else:
script = pkg_resources.resource_filename(
Expand All @@ -109,7 +135,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
script,
build_config.name,
str(build_file_path),
" ".join(deps),
" ".join(normal_deps),
]

if special_deps:
Expand Down
13 changes: 11 additions & 2 deletions llama_stack/providers/tests/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import yaml

from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.request_headers import set_request_provider_data
Expand All @@ -37,7 +38,11 @@ async def resolve_impls_for_test_v2(
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
try:
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e

if provider_data:
set_request_provider_data(
Expand Down Expand Up @@ -66,7 +71,11 @@ async def resolve_impls_for_test(api: Api, deps: List[Api] = None):
providers=chosen,
)
run_config = parse_and_maybe_upgrade_config(run_config)
impls = await resolve_impls(run_config, get_provider_registry())
try:
impls = await resolve_impls(run_config, get_provider_registry())
except ModuleNotFoundError as e:
print_pip_install_help(providers)
raise e

if "provider_data" in config_dict:
provider_id = chosen[api.value][0].provider_id
Expand Down