Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/llama_stack_client/lib/agents/client_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@
import inspect
import json
from abc import abstractmethod
from typing import Any, Callable, Dict, get_args, get_origin, get_type_hints, List, TypeVar, Union
from typing import (
Any,
Callable,
Dict,
get_args,
get_origin,
get_type_hints,
List,
TypeVar,
Union,
)

from llama_stack_client.types import CompletionMessage, Message, ToolResponse
from llama_stack_client.types.tool_def_param import Parameter, ToolDefParam
Expand Down Expand Up @@ -72,7 +82,14 @@ def run(

metadata = {}
try:
response = self.run_impl(**tool_call.arguments)
if tool_call.arguments_json is not None:
params = json.loads(tool_call.arguments_json)
elif isinstance(tool_call.arguments, str):
params = json.loads(tool_call.arguments)
else:
params = tool_call.arguments

response = self.run_impl(**params)
if isinstance(response, dict) and "content" in response:
content = json.dumps(response["content"], ensure_ascii=False)
metadata = response.get("metadata", {})
Expand Down
16 changes: 13 additions & 3 deletions src/llama_stack_client/lib/agents/react/tool_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ValidationError

from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared.tool_call import ToolCall

from pydantic import BaseModel, ValidationError

from ..tool_parser import ToolParser


Expand All @@ -31,6 +33,7 @@ class ReActOutput(BaseModel):


class ReActToolParser(ToolParser):
@override
def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
tool_calls = []
response_text = str(output_message.content)
Expand All @@ -49,6 +52,13 @@ def get_tool_calls(self, output_message: CompletionMessage) -> List[ToolCall]:
params = {param.name: param.value for param in tool_params}
if tool_name and tool_params:
call_id = str(uuid.uuid4())
tool_calls = [ToolCall(call_id=call_id, tool_name=tool_name, arguments=params)]
tool_calls = [
ToolCall(
call_id=call_id,
tool_name=tool_name,
arguments=params,
arguments_json=json.dumps(params),
)
]

return tool_calls
12 changes: 6 additions & 6 deletions src/llama_stack_client/resources/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def iterrows(
Uses cursor-based pagination.

Args:
limit: The number of rows to get per page.
limit: The number of rows to get.

start_index: Index into dataset for the first row to get. Get all rows if None.

Expand Down Expand Up @@ -185,8 +185,8 @@ def register(
"Hello, John Doe. How can I help you today?"}, {"role": "user", "content":
"What's my name?"}, ], "answer": "John Doe" }

source:
The data source of the dataset. Examples: - { "type": "uri", "uri":
source: The data source of the dataset. Ensure that the data source schema is compatible
with the purpose of the dataset. Examples: - { "type": "uri", "uri":
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
Expand Down Expand Up @@ -347,7 +347,7 @@ async def iterrows(
Uses cursor-based pagination.

Args:
limit: The number of rows to get per page.
limit: The number of rows to get.

start_index: Index into dataset for the first row to get. Get all rows if None.

Expand Down Expand Up @@ -410,8 +410,8 @@ async def register(
"Hello, John Doe. How can I help you today?"}, {"role": "user", "content":
"What's my name?"}, ], "answer": "John Doe" }

source:
The data source of the dataset. Examples: - { "type": "uri", "uri":
source: The data source of the dataset. Ensure that the data source schema is compatible
with the purpose of the dataset. Examples: - { "type": "uri", "uri":
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
Expand Down
2 changes: 1 addition & 1 deletion src/llama_stack_client/types/dataset_iterrows_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class DatasetIterrowsParams(TypedDict, total=False):
limit: int
"""The number of rows to get per page."""
"""The number of rows to get."""

start_index: int
"""Index into dataset for the first row to get. Get all rows if None."""
2 changes: 1 addition & 1 deletion src/llama_stack_client/types/dataset_iterrows_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DatasetIterrowsResponse(BaseModel):
data: List[Dict[str, Union[bool, float, str, List[object], object, None]]]
"""The rows in the current page."""

next_index: Optional[int] = None
next_start_index: Optional[int] = None
"""Index into dataset for the first row in the next page.

None if there are no more rows.
Expand Down
6 changes: 4 additions & 2 deletions src/llama_stack_client/types/dataset_register_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ class DatasetRegisterParams(TypedDict, total=False):
source: Required[Source]
"""The data source of the dataset.

Examples: - { "type": "uri", "uri": "https://mywebsite.com/mydata.jsonl" } - {
"type": "uri", "uri": "lsfs://mydata.jsonl" } - { "type": "uri", "uri":
Ensure that the data source schema is compatible with the purpose of the
dataset. Examples: - { "type": "uri", "uri":
"https://mywebsite.com/mydata.jsonl" } - { "type": "uri", "uri":
"lsfs://mydata.jsonl" } - { "type": "uri", "uri":
"data:csv;base64,{base64_content}" } - { "type": "uri", "uri":
"huggingface://llamastack/simpleqa?split=train" } - { "type": "rows", "rows": [
{ "messages": [ {"role": "user", "content": "Hello, world!"}, {"role":
Expand Down
13 changes: 10 additions & 3 deletions src/llama_stack_client/types/shared/tool_call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from typing import Dict, List, Union
from typing import Dict, List, Union, Optional
from typing_extensions import Literal

from ..._models import BaseModel
Expand All @@ -9,11 +9,18 @@


class ToolCall(BaseModel):
arguments: Dict[
arguments: Union[
str,
Union[str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None],
Dict[
str,
Union[
str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None
],
],
]

call_id: str

tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]

arguments_json: Optional[str] = None
16 changes: 13 additions & 3 deletions src/llama_stack_client/types/shared_params/tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,24 @@

class ToolCall(TypedDict, total=False):
arguments: Required[
Dict[
Union[
str,
Union[
str, float, bool, List[Union[str, float, bool, None]], Dict[str, Union[str, float, bool, None]], None
Dict[
str,
Union[
str,
float,
bool,
List[Union[str, float, bool, None]],
Dict[str, Union[str, float, bool, None]],
None,
],
],
]
]

call_id: Required[str]

tool_name: Required[Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str]]

arguments_json: str