Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/llama_stack_client/lib/cli/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@
from typing import Optional

import click
import yaml
from rich.console import Console
from rich.table import Table

from ..common.utils import handle_client_errors

Expand Down
25 changes: 18 additions & 7 deletions src/llama_stack_client/lib/direct/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
from typing import Any, Type, cast, get_args, get_origin

import yaml
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.server.server import is_streaming_request
from llama_stack.distribution.stack import (construct_stack,
Expand All @@ -20,7 +19,9 @@

class LlamaStackDirectClient(LlamaStackClient):
def __init__(self, config: StackRunConfig, **kwargs):
raise TypeError("Use from_config() or from_template() instead of direct initialization")
raise TypeError(
"Use from_config() or from_template() instead of direct initialization"
)

@classmethod
async def from_config(cls, config: StackRunConfig, **kwargs):
Expand All @@ -32,8 +33,12 @@ async def from_config(cls, config: StackRunConfig, **kwargs):
async def from_template(cls, template_name: str, **kwargs):
config = get_stack_run_config_from_template(template_name)
console = Console()
console.print(f"[green]Using template[/green] [blue]{template_name}[/blue] with config:")
console.print(yaml.dump(config.model_dump(), indent=2, default_flow_style=False))
console.print(
f"[green]Using template[/green] [blue]{template_name}[/blue] with config:"
)
console.print(
yaml.dump(config.model_dump(), indent=2, default_flow_style=False)
)
instance = object.__new__(cls)
await instance._initialize(config, **kwargs)
return instance
Expand All @@ -46,7 +51,11 @@ async def _initialize(self, config: StackRunConfig, **kwargs) -> None:
await self.initialize()

async def initialize(self) -> None:
self.impls = await construct_stack(self.config)
try:
self.impls = await construct_stack(self.config)
except ModuleNotFoundError as e:
print_pip_install_help(self.config.providers)
raise e

def _convert_param(self, param_type: Any, value: Any) -> Any:
origin = get_origin(param_type)
Expand Down Expand Up @@ -85,7 +94,9 @@ async def _call_endpoint(self, path: str, method: str, body: dict = None) -> Any
for param_name, param in sig.parameters.items():
if param_name in body:
value = body.get(param_name)
converted_body[param_name] = self._convert_param(param.annotation, value)
converted_body[param_name] = self._convert_param(
param.annotation, value
)
body = converted_body

if is_streaming_request(endpoint.name, body):
Expand Down
Empty file.