Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/agents/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ async def run_main(
],
session_id=session_id,
)
async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


Expand Down
4 changes: 2 additions & 2 deletions examples/agents/e2e_loop_with_custom_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
print(f"Available shields found: {available_shields}")

available_models = [model.identifier for model in client.models.list()]
supported_models = [x for x in available_models if "3.2" in x]
supported_models = [x for x in available_models if "3.2" in x and "Vision" not in x]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llama3.1 and llama3.2 multimodal models follow the same tool prompt format, and don't have ToolPromptFormat ToolPromptFormat.python_list

if not supported_models:
raise ValueError(
"No supported models found. Make sure to have a Llama 3.2 model."
Expand Down Expand Up @@ -116,7 +116,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


Expand Down
10 changes: 2 additions & 8 deletions examples/agents/hello.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os

import fire
Expand All @@ -15,7 +13,7 @@
from llama_stack_client.types.agent_create_params import AgentConfig


async def run_main(host: str, port: int, disable_safety: bool = False):
def main(host: str, port: int):
client = LlamaStackClient(
base_url=f"http://{host}:{port}",
)
Expand Down Expand Up @@ -72,13 +70,9 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


def main(host: str, port: int):
asyncio.run(run_main(host, port))


if __name__ == "__main__":
fire.Fire(main)
2 changes: 1 addition & 1 deletion examples/agents/inflation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


Expand Down
2 changes: 1 addition & 1 deletion examples/agents/rag_as_attachments.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


Expand Down
2 changes: 1 addition & 1 deletion examples/agents/rag_with_memory_bank.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False):
session_id=session_id,
)

async for log in EventLogger().log(response):
for log in EventLogger().log(response):
log.print()


Expand Down
6 changes: 3 additions & 3 deletions examples/custom_tools/single_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ class SingleMessageCustomTool(CustomTool):
allow for the tool be called by the model and the necessary plumbing.
"""

async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"

message = messages[0]

tool_call = message.tool_calls[0]

try:
response = await self.run_impl(**tool_call.arguments)
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
Expand All @@ -40,5 +40,5 @@ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessa
return [message]

@abstractmethod
async def run_impl(self, *args, **kwargs):
def run_impl(self, *args, **kwargs):
raise NotImplementedError()
2 changes: 1 addition & 1 deletion examples/custom_tools/ticker_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
),
}

async def run_impl(self, ticker_symbol: str, start: str, end: str):
def run_impl(self, ticker_symbol: str, start: str, end: str):
data = yf.download(ticker_symbol, start=start, end=end)

data["Year"] = data.index.year
Expand Down
9 changes: 6 additions & 3 deletions examples/custom_tools/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from typing import Dict

import requests

from llama_stack_client.types.tool_param_definition_param import (
ToolParamDefinitionParam,
)
Expand All @@ -17,7 +20,7 @@ class BraveSearch:
def __init__(self, api_key: str) -> None:
self.api_key = api_key

async def search(self, query: str) -> str:
def search(self, query: str) -> str:
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": self.api_key,
Expand Down Expand Up @@ -148,5 +151,5 @@ def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]:
)
}

async def run_impl(self, query: str):
return await self.engine.search(query)
def run_impl(self, query: str):
return self.engine.search(query)
19 changes: 7 additions & 12 deletions tests/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@
import os

import pytest
import pytest_asyncio

from dotenv import load_dotenv

from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger, LogEvent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig

from .example_custom_tool import GetBoilingPointTool
Expand Down Expand Up @@ -70,9 +69,7 @@ async def test_create_agent_turn():
session_id=session_id,
)

logs = [
str(log) async for log in EventLogger().log(simple_hello) if log is not None
]
logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None]
logs_str = "".join(logs)

assert "shield_call>" in logs_str
Expand All @@ -89,9 +86,7 @@ async def test_create_agent_turn():
session_id=session_id,
)

logs = [
str(log) async for log in EventLogger().log(bomb_response) if log is not None
]
logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None]
logs_str = "".join(logs)
assert "I can't answer that. Can I help with something else?" in logs_str

Expand Down Expand Up @@ -140,7 +135,7 @@ async def test_builtin_tool_brave_search():
session_id=session_id,
)

logs = [str(log) async for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)

assert "tool_execution>" in logs_str
Expand Down Expand Up @@ -195,15 +190,15 @@ async def test_builtin_tool_code_execution():
session_id=session_id,
)

logs = [str(log) async for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)

assert "541" in logs_str
assert "Tool:code_interpreter Response" in logs_str


@pytest.mark.asyncio
async def test_builtin_tool_code_execution():
async def test_custom_tool():
host = os.environ.get("LOCALHOST")
port = os.environ.get("PORT")

Expand Down Expand Up @@ -266,7 +261,7 @@ async def test_builtin_tool_code_execution():
session_id=session_id,
)

logs = [str(log) async for log in EventLogger().log(response) if log is not None]
logs = [str(log) for log in EventLogger().log(response) if log is not None]
logs_str = "".join(logs)
assert "-100" in logs_str
assert "CustomTool" in logs_str