2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "llama_stack_client"
version = "0.0.39"
version = "0.0.40"
description = "The official Python library for the llama-stack-client API"
dynamic = ["readme"]
license = "Apache-2.0"
2 changes: 0 additions & 2 deletions src/llama_stack_client/__init__.py
@@ -4,7 +4,6 @@
from ._types import NOT_GIVEN, NoneType, NotGiven, Transport, ProxiesTypes
from ._utils import file_from_path
from ._client import (
ENVIRONMENTS,
Client,
Stream,
Timeout,
@@ -69,7 +68,6 @@
"AsyncStream",
"LlamaStackClient",
"AsyncLlamaStackClient",
"ENVIRONMENTS",
"file_from_path",
"BaseModel",
"DEFAULT_TIMEOUT",
14 changes: 12 additions & 2 deletions src/llama_stack_client/_base_client.py
@@ -143,6 +143,12 @@ def __init__(
self.url = url
self.params = params

@override
def __repr__(self) -> str:
if self.url:
return f"{self.__class__.__name__}(url={self.url})"
return f"{self.__class__.__name__}(params={self.params})"


class BasePage(GenericModel, Generic[_T]):
"""
@@ -412,7 +418,10 @@ def _build_headers(self, options: FinalRequestOptions, *, retries_taken: int = 0
if idempotency_header and options.method.lower() != "get" and idempotency_header not in headers:
headers[idempotency_header] = options.idempotency_key or self._idempotency_key()

headers.setdefault("x-stainless-retry-count", str(retries_taken))
# Don't set the retry count header if it was already set or removed by the caller. We check
# `custom_headers`, which can contain `Omit()`, instead of `headers` to account for the removal case.
if "x-stainless-retry-count" not in (header.lower() for header in custom_headers):
headers["x-stainless-retry-count"] = str(retries_taken)

return headers

@@ -686,7 +695,8 @@ def _calculate_retry_timeout(
if retry_after is not None and 0 < retry_after <= 60:
return retry_after

nb_retries = max_retries - remaining_retries
# Also cap retry count to 1000 to avoid any potential overflows with `pow`
nb_retries = min(max_retries - remaining_retries, 1000)

# Apply exponential backoff, but not more than the max.
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)
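For context, a standalone sketch of the capped backoff behavior this hunk implements; the INITIAL_RETRY_DELAY and MAX_RETRY_DELAY values below are assumed for illustration and are not taken from this diff.

from typing import Optional

INITIAL_RETRY_DELAY = 0.5  # assumed for illustration
MAX_RETRY_DELAY = 8.0      # assumed for illustration

def calculate_retry_timeout(remaining_retries: int, max_retries: int, retry_after: Optional[float] = None) -> float:
    # A server-provided Retry-After wins when it is within a sane range.
    if retry_after is not None and 0 < retry_after <= 60:
        return retry_after
    # Cap the retry count at 1000 so pow() cannot blow up.
    nb_retries = min(max_retries - remaining_retries, 1000)
    # Exponential backoff, bounded by the maximum delay.
    return min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)

With these assumed constants the delays grow 0.5s, 1s, 2s, 4s, then stay pinned at 8s.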
202 changes: 84 additions & 118 deletions src/llama_stack_client/_client.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions src/llama_stack_client/_response.py
@@ -192,6 +192,9 @@ def _parse(self, *, to: type[_T] | None = None) -> R | _T:
if cast_to == float:
return cast(R, float(response.text))

if cast_to == bool:
return cast(R, response.text.lower() == "true")

origin = get_origin(cast_to) or cast_to

if origin == APIResponse:
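A quick standalone sketch of how the new bool branch behaves; parse_bool_body is an illustrative helper, not part of the diff.

def parse_bool_body(text: str) -> bool:
    # Mirrors the new branch above: only a case-insensitive "true" maps to True;
    # anything else, including "1" or "yes", maps to False.
    return text.lower() == "true"

assert parse_bool_body("True") is True
assert parse_bool_body("false") is False
assert parse_bool_body("1") is False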
6 changes: 0 additions & 6 deletions src/llama_stack_client/lib/agents/event_logger.py
@@ -7,8 +7,6 @@
from typing import List, Optional, Union

from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agents import AgentsTurnStreamChunk

from termcolor import cprint


@@ -66,10 +64,6 @@ async def log(self, event_generator):
)
continue

if not isinstance(chunk, AgentsTurnStreamChunk):
yield LogEvent(chunk, color="yellow")
continue

event = chunk.event
event_type = event.payload.event_type

5 changes: 5 additions & 0 deletions src/llama_stack_client/lib/cli/common/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
20 changes: 20 additions & 0 deletions src/llama_stack_client/lib/cli/common/utils.py
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from tabulate import tabulate

def print_table_from_response(response, headers=[]):
if not headers:
headers = sorted(response[0].__dict__.keys())

rows = []
for spec in response:
rows.append(
[
spec.__dict__[headers[i]] for i in range(len(headers))
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
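A hypothetical usage sketch of the new helper; the SimpleNamespace items below stand in for objects returned by a client list call, and their values are made up.

from types import SimpleNamespace

from llama_stack_client.lib.cli.common.utils import print_table_from_response

# Illustrative stand-ins for the objects a client list call returns.
banks = [
    SimpleNamespace(identifier="bank-1", provider_id="provider-a", type="vector"),
    SimpleNamespace(identifier="bank-2", provider_id="provider-b", type="vector"),
]

# With explicit headers the columns follow the given order; without them the
# helper falls back to the sorted attribute names of the first item.
print_table_from_response(banks, headers=["identifier", "provider_id", "type"])
print_table_from_response(banks)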
2 changes: 2 additions & 0 deletions src/llama_stack_client/lib/cli/llama_stack_client.py
@@ -7,6 +7,7 @@
import argparse

from .configure import ConfigureParser
from .providers import ProvidersParser
from .memory_banks import MemoryBanksParser

from .models import ModelsParser
@@ -31,6 +32,7 @@ def __init__(self):
MemoryBanksParser.create(subparsers)
ShieldsParser.create(subparsers)
ConfigureParser.create(subparsers)
ProvidersParser.create(subparsers)

def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args()
24 changes: 9 additions & 15 deletions src/llama_stack_client/lib/cli/memory_banks/list.py
@@ -12,6 +12,7 @@
from llama_stack_client.lib.cli.subcommand import Subcommand

from tabulate import tabulate
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class MemoryBanksList(Subcommand):
@@ -41,21 +42,14 @@ def _run_memory_banks_list_cmd(self, args: argparse.Namespace):
)

headers = [
"Memory Bank Type",
"Provider Type",
"Provider Config",
"identifier",
"provider_id",
"type",
"embedding_model",
"chunk_size_in_tokens",
"overlap_size_in_tokens",
]

memory_banks_list_response = client.memory_banks.list()
rows = []

for bank_spec in memory_banks_list_response:
rows.append(
[
bank_spec.bank_type,
bank_spec.provider_config.provider_type,
json.dumps(bank_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
if memory_banks_list_response:
print_table_from_response(memory_banks_list_response, headers)
16 changes: 4 additions & 12 deletions src/llama_stack_client/lib/cli/models/get.py
@@ -46,28 +46,20 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
base_url=args.endpoint,
)

headers = [
"Model ID (model)",
"Model Metadata",
"Provider Type",
"Provider Config",
]

models_get_response = client.models.get(core_model_id=args.model_id)
models_get_response = client.models.retrieve(identifier=args.model_id)

if not models_get_response:
print(
f"Model {args.model_id} is not found at distribution endpoint {args.endpoint}. Please ensure endpoint is serving specified model. "
)
return

headers = sorted(models_get_response.__dict__.keys())

rows = []
rows.append(
[
models_get_response.llama_model["core_model_id"],
json.dumps(models_get_response.llama_model, indent=4),
models_get_response.provider_config.provider_type,
json.dumps(models_get_response.provider_config.config, indent=4),
models_get_response.__dict__[headers[i]] for i in range(len(headers))
]
)

29 changes: 8 additions & 21 deletions src/llama_stack_client/lib/cli/models/list.py
@@ -7,11 +7,10 @@
import json
import argparse

from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class ModelsList(Subcommand):
@@ -41,23 +40,11 @@ def _run_models_list_cmd(self, args: argparse.Namespace):
)

headers = [
"Model ID (model)",
"Model Metadata",
"Provider Type",
"Provider Config",
"identifier",
"llama_model",
"provider_id",
"metadata"
]

models_list_response = client.models.list()
rows = []

for model_spec in models_list_response:
rows.append(
[
model_spec.llama_model["core_model_id"],
json.dumps(model_spec.llama_model, indent=4),
model_spec.provider_config.provider_type,
json.dumps(model_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
response = client.models.list()
if response:
print_table_from_response(response, headers)
65 changes: 65 additions & 0 deletions src/llama_stack_client/lib/cli/providers.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import argparse
import os

import yaml
from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.configure import get_config


class ProvidersParser(Subcommand):
"""Configure Llama Stack Client CLI"""

def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"providers",
prog="llama-stack-client providers",
description="List available providers Llama Stack Client CLI",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_cmd)

def _add_arguments(self):
self.endpoint = get_config().get("endpoint")
self.parser.add_argument(
"--endpoint",
type=str,
help="Llama Stack distribution endpoint",
default=self.endpoint,
)

def _run_providers_cmd(self, args: argparse.Namespace):
client = LlamaStackClient(
base_url=args.endpoint,
)

headers = [
"API",
"Provider ID",
"Provider Type",
]

providers_response = client.providers.list()
rows = []

for k, v in providers_response.items():
for provider_info in v:
rows.append(
[
k,
provider_info.provider_id,
provider_info.provider_type
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
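A sketch of how the new subcommand could be wired into an argparse CLI, mirroring the registration shown in llama_stack_client.py above; it assumes the CLI has already been configured and that a Llama Stack distribution is reachable at the placeholder endpoint.

import argparse

from llama_stack_client.lib.cli.providers import ProvidersParser

# Build a minimal argparse CLI and register the new subcommand, as done in
# llama_stack_client.py above.
parser = argparse.ArgumentParser(prog="llama-stack-client")
subparsers = parser.add_subparsers(title="subcommands")
ProvidersParser.create(subparsers)

# The endpoint URL is a placeholder; args.func dispatches to _run_providers_cmd
# and prints the API / Provider ID / Provider Type table.
args = parser.parse_args(["providers", "--endpoint", "http://localhost:5000"])
args.func(args)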
22 changes: 3 additions & 19 deletions src/llama_stack_client/lib/cli/shields/list.py
@@ -7,11 +7,10 @@
import json
import argparse

from tabulate import tabulate

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.cli.configure import get_config
from llama_stack_client.lib.cli.subcommand import Subcommand
from llama_stack_client.lib.cli.common.utils import print_table_from_response


class ShieldsList(Subcommand):
@@ -44,22 +43,7 @@ def _run_shields_list_cmd(self, args: argparse.Namespace):
base_url=args.endpoint,
)

headers = [
"Shield Type (shield_type)",
"Provider Type",
"Provider Config",
]

shields_list_response = client.shields.list()
rows = []

for shield_spec in shields_list_response:
rows.append(
[
shield_spec.shield_type,
shield_spec.provider_config.provider_type,
json.dumps(shield_spec.provider_config.config, indent=4),
]
)

print(tabulate(rows, headers=headers, tablefmt="grid"))
if shields_list_response:
print_table_from_response(shields_list_response)
25 changes: 6 additions & 19 deletions src/llama_stack_client/lib/inference/event_logger.py
@@ -3,12 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


from llama_stack_client.types import (
ChatCompletionStreamChunk,
InferenceChatCompletionResponse,
)
from termcolor import cprint


@@ -30,17 +24,10 @@ def print(self, flush=True):
class EventLogger:
async def log(self, event_generator):
for chunk in event_generator:
if isinstance(chunk, ChatCompletionStreamChunk):
event = chunk.event
if event.event_type == "start":
yield LogEvent("Assistant> ", color="cyan", end="")
elif event.event_type == "progress":
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == "complete":
yield LogEvent("")
elif isinstance(chunk, InferenceChatCompletionResponse):
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk.completion_message.content, color="yellow")
else:
event = chunk.event
if event.event_type == "start":
yield LogEvent("Assistant> ", color="cyan", end="")
yield LogEvent(chunk, color="yellow")
elif event.event_type == "progress":
yield LogEvent(event.delta, color="yellow", end="")
elif event.event_type == "complete":
yield LogEvent("")
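A hedged usage sketch of the simplified logger; chat_stream stands in for a streaming chat-completion response and is not defined in this diff.

from llama_stack_client.lib.inference.event_logger import EventLogger

async def print_chat_stream(chat_stream):
    # chat_stream is assumed to yield chunks carrying .event with an event_type
    # of "start", "progress", or "complete", the shape the logger above now
    # expects for every chunk.
    async for log_event in EventLogger().log(chat_stream):
        log_event.print()

Drive it with asyncio.run(print_chat_stream(response)) once a streaming response is in hand.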