From 1ca65086927a6cbda884ceefb7411c284a8b943a Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Tue, 18 Mar 2025 14:11:54 -0700 Subject: [PATCH] Sync updates from stainless branch: hardikjshah/dev --- .../lib/agents/client_tool.py | 21 +++++++++++++++++-- .../lib/agents/react/tool_parser.py | 16 +++++++++++--- src/llama_stack_client/resources/datasets.py | 12 +++++------ .../types/dataset_iterrows_params.py | 2 +- .../types/dataset_iterrows_response.py | 2 +- .../types/dataset_register_params.py | 6 ++++-- .../types/shared/tool_call.py | 13 +++++++++--- .../types/shared_params/tool_call.py | 16 +++++++++++--- 8 files changed, 67 insertions(+), 21 deletions(-) diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 0d15dade..c199b211 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -7,7 +7,17 @@ import inspect import json from abc import abstractmethod -from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + get_args, + get_origin, + get_type_hints, + List, + TypeVar, + Union, +) from llama_stack_client.types import CompletionMessage, Message, ToolResponse from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam @@ -72,7 +82,14 @@ def run( metadata = {} try: - response = self.run_impl(**tool_call.arguments) + if tool_call.arguments_json is not None: + params = json.loads(tool_call.arguments_json) + elif isinstance(tool_call.arguments, str): + params = json.loads(tool_call.arguments) + else: + params = tool_call.arguments + + response = self.run_impl(**params) if isinstance(response, dict) and "content" in response: content = json.dumps(response["content"], ensure_ascii=False) metadata = response.get("metadata", {}) diff --git a/src/llama_stack_client/lib/agents/react/tool_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py index 0dfcfe48..1c418ffa 100644 --- a/src/llama_stack_client/lib/agents/react/tool_parser.py +++ b/src/llama_stack_client/lib/agents/react/tool_parser.py @@ -4,13 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import uuid from typing import List, Optional, Union -from pydantic import BaseModel, ValidationError - from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared.tool_call import ToolCall + +from pydantic import BaseModel, ValidationError + from ..tool_parser import ToolParser @@ -31,6 +33,7 @@ class ReActOutput(BaseModel): class ReActToolParser(ToolParser): + @override def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]: tool_calls = [] response_text = str(output_message.content) @@ -49,6 +52,13 @@ def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]: params = {param.name: param.value for param in tool_params} if tool_name and tool_params: call_id = str(uuid.uuid4()) - tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=params)] + tool_calls = [ + ToolCall( + call_id=call_id, + tool_name=tool_name, + arguments=params, + arguments_json=json.dumps(params), + ) + ] return tool_calls diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py index c4c6dc94..1df5f9c1 100644 --- a/src/llama_stack_client/resources/datasets.py +++ b/src/llama_stack_client/resources/datasets.py @@ -122,7 +122,7 @@ def iterrows( Uses cursor-based pagination. Args: - limit: The number of rows to get per page. + limit: The number of rows to get. start_index: Index into dataset for the first row to get. Get all rows if None. @@ -185,8 +185,8 @@ def register( "Hello, John Doe. How can I help you today?"}, {"role": "user", "content": "What's my name?"}, ], "answer": "John Doe" } - source: - The data source of the dataset. Examples: - { "type": "uri", "uri": + source: The data source of the dataset. Ensure that the data source schema is compatible + with the purpose of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": @@ -347,7 +347,7 @@ async def iterrows( Uses cursor-based pagination. Args: - limit: The number of rows to get per page. + limit: The number of rows to get. start_index: Index into dataset for the first row to get. Get all rows if None. @@ -410,8 +410,8 @@ async def register( "Hello, John Doe. How can I help you today?"}, {"role": "user", "content": "What's my name?"}, ], "answer": "John Doe" } - source: - The data source of the dataset. Examples: - { "type": "uri", "uri": + source: The data source of the dataset. Ensure that the data source schema is compatible + with the purpose of the dataset. Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": diff --git a/src/llama_stack_client/types/dataset_iterrows_params.py b/src/llama_stack_client/types/dataset_iterrows_params.py index 5c38d7c1..99065312 100644 --- a/src/llama_stack_client/types/dataset_iterrows_params.py +++ b/src/llama_stack_client/types/dataset_iterrows_params.py @@ -9,7 +9,7 @@ class DatasetIterrowsParams(TypedDict, total=False): limit: int - """The number of rows to get per page.""" + """The number of rows to get.""" start_index: int """Index into dataset for the first row to get. Get all rows if None.""" diff --git a/src/llama_stack_client/types/dataset_iterrows_response.py b/src/llama_stack_client/types/dataset_iterrows_response.py index f82233b5..48593bb2 100644 --- a/src/llama_stack_client/types/dataset_iterrows_response.py +++ b/src/llama_stack_client/types/dataset_iterrows_response.py @@ -11,7 +11,7 @@ class DatasetIterrowsResponse(BaseModel): data: List[Dict[str, Union[bool, float, str, List[object], object, None]]] """The rows in the current page.""" - next_index: Optional[int] = None + next_start_index: Optional[int] = None """Index into dataset for the first row in the next page. None if there are no more rows. diff --git a/src/llama_stack_client/types/dataset_register_params.py b/src/llama_stack_client/types/dataset_register_params.py index d2ff9d3a..824dd0a9 100644 --- a/src/llama_stack_client/types/dataset_register_params.py +++ b/src/llama_stack_client/types/dataset_register_params.py @@ -27,8 +27,10 @@ class DatasetRegisterParams(TypedDict, total=False): source: Required[Source] """The data source of the dataset. - Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - { - "type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri": + Ensure that the data source schema is compatible with the purpose of the + dataset. Examples: - { "type": "uri", "uri": + "https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri": + "lsfs://mydata.jsonl" } - { "type": "uri", "uri": "data:csv;base64,{base64_content}" } - { "type": "uri", "uri": "huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [ { "messages": [ {"role": "user", "content": "Hello, world!"}, {"role": diff --git a/src/llama_stack_client/types/shared/tool_call.py b/src/llama_stack_client/types/shared/tool_call.py index f1e83ee9..b9301d75 100644 --- a/src/llama_stack_client/types/shared/tool_call.py +++ b/src/llama_stack_client/types/shared/tool_call.py @@ -1,6 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Dict, List, Union +from typing import Dict, List, Union, Optional from typing_extensions import Literal from ..._models import BaseModel @@ -9,11 +9,18 @@ class ToolCall(BaseModel): - arguments: Dict[ + arguments: Union[ str, - Union[str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None], + Dict[ + str, + Union[ + str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None + ], + ], ] call_id: str tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] + + arguments_json: Optional[str] = None diff --git a/src/llama_stack_client/types/shared_params/tool_call.py b/src/llama_stack_client/types/shared_params/tool_call.py index 2a50d041..801716e9 100644 --- a/src/llama_stack_client/types/shared_params/tool_call.py +++ b/src/llama_stack_client/types/shared_params/tool_call.py @@ -10,10 +10,18 @@ class ToolCall(TypedDict, total=False): arguments: Required[ - Dict[ + Union[ str, - Union[ - str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None + Dict[ + str, + Union[ + str, + float, + bool, + List[Union[str, float, bool, None]], + Dict[str, Union[str, float, bool, None]], + None, + ], ], ] ] @@ -21,3 +29,5 @@ class ToolCall(TypedDict, total=False): call_id: Required[str] tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]] + + arguments_json: str