Skip to content

Commit

Permalink
Make llama stack build not create a new conda by default (#788)
Browse files Browse the repository at this point in the history
## What does this PR do?

So far `llama stack build` has always created a separate conda
environment for packaging the dependencies of a distribution. The main
reason to do so is isolation -- distributions are composed of providers
which can have a variety of potentially conflicting dependencies. That
said, this has created significant annoyance for new users since it is
not at all transparent. The fact that `llama stack run` is actually
running the code in some other conda is very surprising.

This PR tries to make things better. 

- Both `llama stack build` and `llama stack run` now accept an
`--image-name` argument which represents the (conda, docker, virtualenv)
image you want to operate upon.
- For the default (conda) mode, the script checks if a current conda
environment exists. If one exists, it uses it.
- If `--image-name` is provided, that option is used. In this case, an
environment is created if needed.
- There is no automatic `llamastack-` prefixing of the environment names
done anymore.


## Test Plan

Start in a conda environment, run `llama stack build --template
fireworks`; verify that it successfully built into the current
environment and stored the build file at
`$CONDA_PREFIX/llamastack-build.yaml`. Run `llama stack run fireworks`
which started correctly in the current environment.

Ran the same build command outside of conda. It failed asking for
`--image-name`. Ran it with `llama stack build --template fireworks
--image-name foo`. This successfully created a conda environment called
`foo` and installed deps. Ran `llama stack run fireworks` outside conda
which failed. Activated a different conda, ran again, it failed saying
it did not find the `llamastack-build.yaml` file. Then used
`--image-name foo` option and it ran successfully.
  • Loading branch information
ashwinb authored Jan 16, 2025
1 parent 59eeaf7 commit cee3816
Show file tree
Hide file tree
Showing 8 changed files with 398 additions and 336 deletions.
307 changes: 307 additions & 0 deletions llama_stack/cli/stack/_build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,307 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import importlib.resources
import json
import os
import shutil
import textwrap
from functools import lru_cache
from pathlib import Path
from typing import Dict, Optional

import yaml
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint

from llama_stack.cli.table import print_table

from llama_stack.distribution.build import build_image, ImageType
from llama_stack.distribution.datatypes import (
BuildConfig,
DistributionSpec,
Provider,
StackRunConfig,
)
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import Api


TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"


@lru_cache()
def available_templates_specs() -> Dict[str, BuildConfig]:
import yaml

template_specs = {}
for p in TEMPLATES_PATH.rglob("*build.yaml"):
template_name = p.parent.name
with open(p, "r") as f:
build_config = BuildConfig(**yaml.safe_load(f))
template_specs[template_name] = build_config
return template_specs


def run_stack_build_command(
parser: argparse.ArgumentParser, args: argparse.Namespace
) -> None:
if args.list_templates:
return _run_template_list_cmd()

current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env

if args.template:
available_templates = available_templates_specs()
if args.template not in available_templates:
cprint(
f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates",
color="red",
)
return
build_config = available_templates[args.template]
if args.image_type:
build_config.image_type = args.image_type
else:
cprint(
f"Please specify a image-type (docker | conda | venv) for {args.template}",
color="red",
)
return
_run_stack_build_command_from_build_config(
build_config,
image_name=image_name,
template_name=args.template,
)
return

if not args.config and not args.template:
name = prompt(
"> Enter a name for your Llama Stack (e.g. my-local-stack): ",
validator=Validator.from_callable(
lambda x: len(x) > 0,
error_message="Name cannot be empty, please enter a name",
),
)

image_type = prompt(
"> Enter the image type you want your Llama Stack to be built as (docker or conda or venv): ",
validator=Validator.from_callable(
lambda x: x in ["docker", "conda", "venv"],
error_message="Invalid image type, please enter conda or docker or venv",
),
default="conda",
)

if image_type == "conda":
if not image_name:
cprint(
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
color="yellow",
)
image_name = f"llamastack-{name}"
else:
cprint(
f"Using conda environment {image_name}",
color="green",
)

cprint(
textwrap.dedent(
"""
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
""",
),
color="green",
)

print("Tip: use <TAB> to see options for the providers.\n")

providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [
x
for x in providers_for_api.keys()
if x not in ("remote", "remote::sample")
]
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in available_providers,
error_message="Invalid provider, use <TAB> to see options",
),
)

providers[api.value] = api_provider

description = prompt(
"\n > (Optional) Enter a short description for your Llama Stack: ",
default="",
)

distribution_spec = DistributionSpec(
providers=providers,
description=description,
)

build_config = BuildConfig(
image_type=image_type, distribution_spec=distribution_spec
)
else:
with open(args.config, "r") as f:
try:
build_config = BuildConfig(**yaml.safe_load(f))
except Exception as e:
cprint(
f"Could not parse config file {args.config}: {e}",
color="red",
)
return

_run_stack_build_command_from_build_config(build_config, image_name=image_name)


def _generate_run_config(
build_config: BuildConfig, build_dir: Path, image_name: str
) -> None:
"""
Generate a run.yaml template file for user to edit from a build.yaml file
"""
apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig(
docker_image=(
image_name if build_config.image_type == ImageType.docker.value else None
),
image_name=image_name,
apis=apis,
providers={},
)
# build providers dict
provider_registry = get_provider_registry()
for api in apis:
run_config.providers[api] = []
provider_types = build_config.distribution_spec.providers[api]
if isinstance(provider_types, str):
provider_types = [provider_types]

for i, provider_type in enumerate(provider_types):
pid = provider_type.split("::")[-1]

p = provider_registry[Api(api)][provider_type]
if p.deprecation_error:
raise InvalidProviderError(p.deprecation_error)

config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class
)
if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{image_name}"
)
else:
config = {}

p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
config=config,
)
run_config.providers[api].append(p_spec)

run_config_file = build_dir / f"{image_name}-run.yaml"

with open(run_config_file, "w") as f:
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))

cprint(
f"You can now edit {run_config_file} and run `llama stack run {image_name}`",
color="green",
)


def _run_stack_build_command_from_build_config(
build_config: BuildConfig,
image_name: Optional[str] = None,
template_name: Optional[str] = None,
) -> None:
if build_config.image_type == ImageType.docker.value:
if template_name:
image_name = f"distribution-{template_name}"
else:
if not image_name:
raise ValueError(
"Please specify an image name when building a docker image without a template"
)
elif build_config.image_type == ImageType.conda.value:
if not image_name:
raise ValueError("Please specify an image name when building a conda image")

if template_name:
build_dir = DISTRIBS_BASE_DIR / template_name
build_file_path = build_dir / f"{template_name}-build.yaml"
else:
build_dir = DISTRIBS_BASE_DIR / image_name
build_file_path = build_dir / f"{image_name}-build.yaml"

os.makedirs(build_dir, exist_ok=True)
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))

return_code = build_image(
build_config, build_file_path, image_name, template_name=template_name
)
if return_code != 0:
return

if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = (
importlib.resources.files("llama_stack")
/ f"templates/{template_name}/run.yaml"
)
with importlib.resources.as_file(template_path) as path:
run_config_file = build_dir / f"{template_name}-run.yaml"
shutil.copy(path, run_config_file)
# Find all ${env.VARIABLE} patterns
cprint("Build Successful!", color="green")
else:
_generate_run_config(build_config, build_dir, image_name)


def _run_template_list_cmd() -> None:
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"Template Name",
# "Providers",
"Description",
]

rows = []
for template_name, spec in available_templates_specs().items():
rows.append(
[
template_name,
# json.dumps(spec.distribution_spec.providers, indent=2),
spec.distribution_spec.description,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
Loading

0 comments on commit cee3816

Please sign in to comment.