diff --git a/src/llama_stack_client/_utils/_sync.py b/src/llama_stack_client/_utils/_sync.py index 8b3aaf2b..ad7ec71b 100644 --- a/src/llama_stack_client/_utils/_sync.py +++ b/src/llama_stack_client/_utils/_sync.py @@ -7,16 +7,20 @@ from typing import Any, TypeVar, Callable, Awaitable from typing_extensions import ParamSpec +import anyio +import sniffio +import anyio.to_thread + T_Retval = TypeVar("T_Retval") T_ParamSpec = ParamSpec("T_ParamSpec") if sys.version_info >= (3, 9): - to_thread = asyncio.to_thread + _asyncio_to_thread = asyncio.to_thread else: # backport of https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread # for Python 3.8 support - async def to_thread( + async def _asyncio_to_thread( func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs ) -> Any: """Asynchronously run function *func* in a separate thread. @@ -34,6 +38,17 @@ async def to_thread( return await loop.run_in_executor(None, func_call) +async def to_thread( + func: Callable[T_ParamSpec, T_Retval], /, *args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs +) -> T_Retval: + if sniffio.current_async_library() == "asyncio": + return await _asyncio_to_thread(func, *args, **kwargs) + + return await anyio.to_thread.run_sync( + functools.partial(func, *args, **kwargs), + ) + + # inspired by `asyncer`, https://github.com/tiangolo/asyncer def asyncify(function: Callable[T_ParamSpec, T_Retval]) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: """ diff --git a/src/llama_stack_client/types/shared/agent_config.py b/src/llama_stack_client/types/shared/agent_config.py index 273a98db..29cd5de5 100644 --- a/src/llama_stack_client/types/shared/agent_config.py +++ b/src/llama_stack_client/types/shared/agent_config.py @@ -49,14 +49,14 @@ class ToolgroupUnionMember1(BaseModel): class AgentConfig(BaseModel): - enable_session_persistence: bool - instructions: str model: str client_tools: Optional[List[ToolDef]] = None + enable_session_persistence: Optional[bool] = None + input_shields: Optional[List[str]] = None max_infer_iters: Optional[int] = None diff --git a/src/llama_stack_client/types/shared_params/agent_config.py b/src/llama_stack_client/types/shared_params/agent_config.py index fe62bc24..cda476c4 100644 --- a/src/llama_stack_client/types/shared_params/agent_config.py +++ b/src/llama_stack_client/types/shared_params/agent_config.py @@ -50,14 +50,14 @@ class ToolgroupUnionMember1(TypedDict, total=False): class AgentConfig(TypedDict, total=False): - enable_session_persistence: Required[bool] - instructions: Required[str] model: Required[str] client_tools: Iterable[ToolDefParam] + enable_session_persistence: bool + input_shields: List[str] max_infer_iters: int diff --git a/tests/api_resources/test_agents.py b/tests/api_resources/test_agents.py index 54006114..03a15837 100644 --- a/tests/api_resources/test_agents.py +++ b/tests/api_resources/test_agents.py @@ -21,7 +21,6 @@ class TestAgents: def test_method_create(self, client: LlamaStackClient) -> None: agent = client.agents.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", }, @@ -32,7 +31,6 @@ def test_method_create(self, client: LlamaStackClient) -> None: def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: agent = client.agents.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", "client_tools": [ @@ -51,6 +49,7 @@ def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: ], } ], + "enable_session_persistence": True, "input_shields": ["string"], "max_infer_iters": 0, "output_shields": ["string"], @@ -79,7 +78,6 @@ def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: def test_raw_response_create(self, client: LlamaStackClient) -> None: response = client.agents.with_raw_response.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", }, @@ -94,7 +92,6 @@ def test_raw_response_create(self, client: LlamaStackClient) -> None: def test_streaming_response_create(self, client: LlamaStackClient) -> None: with client.agents.with_streaming_response.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", }, @@ -153,7 +150,6 @@ class TestAsyncAgents: async def test_method_create(self, async_client: AsyncLlamaStackClient) -> None: agent = await async_client.agents.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", }, @@ -164,7 +160,6 @@ async def test_method_create(self, async_client: AsyncLlamaStackClient) -> None: async def test_method_create_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: agent = await async_client.agents.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", "client_tools": [ @@ -183,6 +178,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack ], } ], + "enable_session_persistence": True, "input_shields": ["string"], "max_infer_iters": 0, "output_shields": ["string"], @@ -211,7 +207,6 @@ async def test_method_create_with_all_params(self, async_client: AsyncLlamaStack async def test_raw_response_create(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.agents.with_raw_response.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", }, @@ -226,7 +221,6 @@ async def test_raw_response_create(self, async_client: AsyncLlamaStackClient) -> async def test_streaming_response_create(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.agents.with_streaming_response.create( agent_config={ - "enable_session_persistence": True, "instructions": "instructions", "model": "model", },