Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions src/llama_stack_client/lib/direct/direct.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import inspect
import yaml
from typing import Any, cast, get_args, get_origin, Type

from rich.console import Console
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.server import is_streaming_request

from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.distribution.stack import (
get_stack_run_config_from_template,
)
from llama_stack.distribution.stack import construct_stack
from pydantic import BaseModel

from ..._base_client import ResponseT
Expand All @@ -18,15 +22,33 @@

class LlamaStackDirectClient(LlamaStackClient):
def __init__(self, config: StackRunConfig, **kwargs):
raise TypeError("Use from_config() or from_template() instead of direct initialization")

@classmethod
async def from_config(cls, config: StackRunConfig, **kwargs):
instance = object.__new__(cls)
await instance._initialize(config, **kwargs)
return instance

@classmethod
async def from_template(cls, template_name: str, **kwargs):
config = get_stack_run_config_from_template(template_name)
console = Console()
console.print(f"[green]Using template[/green] [blue]{template_name}[/blue] with config:")
console.print(yaml.dump(config.model_dump(), indent=2, default_flow_style=False))
instance = object.__new__(cls)
await instance._initialize(config, **kwargs)
return instance

async def _initialize(self, config: StackRunConfig, **kwargs) -> None:
super().__init__(**kwargs)
self.endpoints = get_all_api_endpoints()
self.config = config
self.dist_registry = None
self.impls = None
await self.initialize()

async def initialize(self) -> None:
self.dist_registry, _ = await create_dist_registry(self.config)
self.impls = await resolve_impls(self.config, get_provider_registry(), self.dist_registry)
self.impls = await construct_stack(self.config)

def _convert_param(self, param_type: Any, value: Any) -> Any:
origin = get_origin(param_type)
Expand Down
2 changes: 1 addition & 1 deletion src/llama_stack_client/lib/direct/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ async def main(config_path: str):

run_config = parse_and_maybe_upgrade_config(config_dict)

client = LlamaStackDirectClient(config=run_config)
client = await LlamaStackDirectClient.from_config(run_config)
await client.initialize()

response = await client.models.list()
Expand Down