diff --git a/examples/agents/client.py b/examples/agents/client.py index 8883d3246..5ea7d75b0 100644 --- a/examples/agents/client.py +++ b/examples/agents/client.py @@ -90,7 +90,7 @@ async def run_main( ], session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() diff --git a/examples/agents/e2e_loop_with_custom_tools.py b/examples/agents/e2e_loop_with_custom_tools.py index fd914a5d1..9775cf3fe 100644 --- a/examples/agents/e2e_loop_with_custom_tools.py +++ b/examples/agents/e2e_loop_with_custom_tools.py @@ -28,7 +28,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False): print(f"Available shields found: {available_shields}") available_models = [model.identifier for model in client.models.list()] - supported_models = [x for x in available_models if "3.2" in x] + supported_models = [x for x in available_models if "3.2" in x and "Vision" not in x] if not supported_models: raise ValueError( "No supported models found. Make sure to have a Llama 3.2 model." @@ -116,7 +116,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False): session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() diff --git a/examples/agents/hello.py b/examples/agents/hello.py index b291e0046..3e63da674 100644 --- a/examples/agents/hello.py +++ b/examples/agents/hello.py @@ -3,8 +3,6 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. - -import asyncio import os import fire @@ -15,7 +13,7 @@ from llama_stack_client.types.agent_create_params import AgentConfig -async def run_main(host: str, port: int, disable_safety: bool = False): +def main(host: str, port: int): client = LlamaStackClient( base_url=f"http://{host}:{port}", ) @@ -72,13 +70,9 @@ async def run_main(host: str, port: int, disable_safety: bool = False): session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() -def main(host: str, port: int): - asyncio.run(run_main(host, port)) - - if __name__ == "__main__": fire.Fire(main) diff --git a/examples/agents/inflation.py b/examples/agents/inflation.py index b4ae702dd..573ce0d46 100644 --- a/examples/agents/inflation.py +++ b/examples/agents/inflation.py @@ -92,7 +92,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False): session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() diff --git a/examples/agents/rag_as_attachments.py b/examples/agents/rag_as_attachments.py index 3907bd16f..079a5c152 100644 --- a/examples/agents/rag_as_attachments.py +++ b/examples/agents/rag_as_attachments.py @@ -112,7 +112,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False): session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() diff --git a/examples/agents/rag_with_memory_bank.py b/examples/agents/rag_with_memory_bank.py index b79259b9c..76f82884e 100644 --- a/examples/agents/rag_with_memory_bank.py +++ b/examples/agents/rag_with_memory_bank.py @@ -112,7 +112,7 @@ async def run_main(host: str, port: int, disable_safety: bool = False): session_id=session_id, ) - async for log in EventLogger().log(response): + for log in EventLogger().log(response): log.print() diff --git a/examples/custom_tools/single_message.py b/examples/custom_tools/single_message.py index 7bb126eca..7bc70cde9 100644 --- a/examples/custom_tools/single_message.py +++ b/examples/custom_tools/single_message.py @@ -18,7 +18,7 @@ class SingleMessageCustomTool(CustomTool): allow for the tool be called by the model and the necessary plumbing. """ - async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: + def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: assert len(messages) == 1, "Expected single message" message = messages[0] @@ -26,7 +26,7 @@ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessa tool_call = message.tool_calls[0] try: - response = await self.run_impl(**tool_call.arguments) + response = self.run_impl(**tool_call.arguments) response_str = json.dumps(response, ensure_ascii=False) except Exception as e: response_str = f"Error when running tool: {e}" @@ -40,5 +40,5 @@ async def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessa return [message] @abstractmethod - async def run_impl(self, *args, **kwargs): + def run_impl(self, *args, **kwargs): raise NotImplementedError() diff --git a/examples/custom_tools/ticker_data.py b/examples/custom_tools/ticker_data.py index a1bf5ada8..02e0ef755 100644 --- a/examples/custom_tools/ticker_data.py +++ b/examples/custom_tools/ticker_data.py @@ -41,7 +41,7 @@ def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: ), } - async def run_impl(self, ticker_symbol: str, start: str, end: str): + def run_impl(self, ticker_symbol: str, start: str, end: str): data = yf.download(ticker_symbol, start=start, end=end) data["Year"] = data.index.year diff --git a/examples/custom_tools/web_search.py b/examples/custom_tools/web_search.py index 5ef427ccf..b3c783d20 100644 --- a/examples/custom_tools/web_search.py +++ b/examples/custom_tools/web_search.py @@ -4,8 +4,11 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json from typing import Dict +import requests + from llama_stack_client.types.tool_param_definition_param import ( ToolParamDefinitionParam, ) @@ -17,7 +20,7 @@ class BraveSearch: def __init__(self, api_key: str) -> None: self.api_key = api_key - async def search(self, query: str) -> str: + def search(self, query: str) -> str: url = "https://api.search.brave.com/res/v1/web/search" headers = { "X-Subscription-Token": self.api_key, @@ -148,5 +151,5 @@ def get_params_definition(self) -> Dict[str, ToolParamDefinitionParam]: ) } - async def run_impl(self, query: str): - return await self.engine.search(query) + def run_impl(self, query: str): + return self.engine.search(query) diff --git a/tests/test_agents.py b/tests/test_agents.py index 87273610d..01d6b2dae 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -7,13 +7,12 @@ import os import pytest -import pytest_asyncio from dotenv import load_dotenv from llama_stack_client import LlamaStackClient from llama_stack_client.lib.agents.agent import Agent -from llama_stack_client.lib.agents.event_logger import EventLogger, LogEvent +from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig from .example_custom_tool import GetBoilingPointTool @@ -70,9 +69,7 @@ async def test_create_agent_turn(): session_id=session_id, ) - logs = [ - str(log) async for log in EventLogger().log(simple_hello) if log is not None - ] + logs = [str(log) for log in EventLogger().log(simple_hello) if log is not None] logs_str = "".join(logs) assert "shield_call>" in logs_str @@ -89,9 +86,7 @@ async def test_create_agent_turn(): session_id=session_id, ) - logs = [ - str(log) async for log in EventLogger().log(bomb_response) if log is not None - ] + logs = [str(log) for log in EventLogger().log(bomb_response) if log is not None] logs_str = "".join(logs) assert "I can't answer that. Can I help with something else?" in logs_str @@ -140,7 +135,7 @@ async def test_builtin_tool_brave_search(): session_id=session_id, ) - logs = [str(log) async for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "tool_execution>" in logs_str @@ -195,7 +190,7 @@ async def test_builtin_tool_code_execution(): session_id=session_id, ) - logs = [str(log) async for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "541" in logs_str @@ -203,7 +198,7 @@ async def test_builtin_tool_code_execution(): @pytest.mark.asyncio -async def test_builtin_tool_code_execution(): +async def test_custom_tool(): host = os.environ.get("LOCALHOST") port = os.environ.get("PORT") @@ -266,7 +261,7 @@ async def test_builtin_tool_code_execution(): session_id=session_id, ) - logs = [str(log) async for log in EventLogger().log(response) if log is not None] + logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) assert "-100" in logs_str assert "CustomTool" in logs_str