Skip to content

Commit 7fbf8ab

Browse files
authored
Fix loading streaming Bedrock response with tool usage with empty argument (#6979)
1 parent 82df9dd commit 7fbf8ab

File tree

2 files changed

+116
-2
lines changed

2 files changed

+116
-2
lines changed

python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,8 @@ async def create_stream(
881881
tool_calls[current_tool_id] = {
882882
"id": chunk.content_block.id,
883883
"name": chunk.content_block.name,
884-
"input": "", # Will be populated from deltas
884+
"input": json.dumps(chunk.content_block.input),
885+
"partial_json": "", # May be populated from deltas
885886
}
886887

887888
elif chunk.type == "content_block_delta":
@@ -896,10 +897,15 @@ async def create_stream(
896897
elif hasattr(chunk.delta, "type") and chunk.delta.type == "input_json_delta":
897898
if current_tool_id is not None and hasattr(chunk.delta, "partial_json"):
898899
# Accumulate partial JSON for the current tool
899-
tool_calls[current_tool_id]["input"] += chunk.delta.partial_json
900+
tool_calls[current_tool_id]["partial_json"] += chunk.delta.partial_json
900901

901902
elif chunk.type == "content_block_stop":
902903
# End of a content block (could be text or tool)
904+
if current_tool_id is not None:
905+
# If there was partial JSON accumulated, use it as the input
906+
if len(tool_calls[current_tool_id]["partial_json"]) > 0:
907+
tool_calls[current_tool_id]["input"] = tool_calls[current_tool_id]["partial_json"]
908+
del tool_calls[current_tool_id]["partial_json"]
903909
current_tool_id = None
904910

905911
elif chunk.type == "message_delta":

python/packages/autogen-ext/tests/models/test_anthropic_model_client.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import json
23
import logging
34
import os
45
from typing import List, Sequence
@@ -20,6 +21,7 @@
2021
from autogen_ext.models.anthropic import (
2122
AnthropicBedrockChatCompletionClient,
2223
AnthropicChatCompletionClient,
24+
BaseAnthropicChatCompletionClient,
2325
BedrockInfo,
2426
)
2527

@@ -34,6 +36,11 @@ def _add_numbers(a: int, b: int) -> int:
3436
return a + b
3537

3638

39+
def _ask_for_input() -> str:
40+
"""Function that asks for user input. Used to test empty input handling, such as in `pass_to_user` tool."""
41+
return "Further input from user"
42+
43+
3744
@pytest.mark.asyncio
3845
async def test_mock_tool_choice_specific_tool() -> None:
3946
"""Test tool_choice parameter with a specific tool using mocks."""
@@ -999,3 +1006,104 @@ async def test_anthropic_tool_choice_none_value_with_actual_api() -> None:
9991006

10001007
# Should get a text response, not tool calls
10011008
assert isinstance(result.content, str)
1009+
1010+
1011+
def get_client_or_skip(provider: str) -> BaseAnthropicChatCompletionClient:
1012+
if provider == "anthropic":
1013+
api_key = os.getenv("ANTHROPIC_API_KEY")
1014+
if not api_key:
1015+
pytest.skip("ANTHROPIC_API_KEY not found in environment variables")
1016+
1017+
return AnthropicChatCompletionClient(
1018+
model="claude-3-haiku-20240307",
1019+
api_key=api_key,
1020+
)
1021+
else:
1022+
access_key = os.getenv("AWS_ACCESS_KEY_ID")
1023+
secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
1024+
region = os.getenv("AWS_REGION")
1025+
if not access_key or not secret_key or not region:
1026+
pytest.skip("AWS credentials not found in environment variables")
1027+
1028+
model = os.getenv("ANTHROPIC_BEDROCK_MODEL", "us.anthropic.claude-3-haiku-20240307-v1:0")
1029+
return AnthropicBedrockChatCompletionClient(
1030+
model=model,
1031+
bedrock_info=BedrockInfo(
1032+
aws_access_key=access_key,
1033+
aws_secret_key=secret_key,
1034+
aws_region=region,
1035+
aws_session_token=os.getenv("AWS_SESSION_TOKEN", ""),
1036+
),
1037+
model_info=ModelInfo(
1038+
vision=False, function_calling=True, json_output=False, family="unknown", structured_output=True
1039+
),
1040+
)
1041+
1042+
1043+
@pytest.mark.asyncio
1044+
@pytest.mark.parametrize("provider", ["anthropic", "bedrock"])
1045+
async def test_streaming_tool_usage_with_no_arguments(provider: str) -> None:
1046+
"""
1047+
Test reading streaming tool usage response with no arguments.
1048+
In that case `input` in initial `tool_use` chunk is `{}` and subsequent `partial_json` chunks are empty.
1049+
"""
1050+
client = get_client_or_skip(provider)
1051+
1052+
# Define tools
1053+
ask_for_input_tool = FunctionTool(
1054+
_ask_for_input, description="Ask user for more input", name="ask_for_input", strict=True
1055+
)
1056+
1057+
chunks: List[str | CreateResult] = []
1058+
async for chunk in client.create_stream(
1059+
messages=[
1060+
SystemMessage(content="When user intent is unclear, ask for more input"),
1061+
UserMessage(content="Erm...", source="user"),
1062+
],
1063+
tools=[ask_for_input_tool],
1064+
tool_choice="required",
1065+
):
1066+
chunks.append(chunk)
1067+
1068+
assert len(chunks) > 0
1069+
assert isinstance(chunks[-1], CreateResult)
1070+
result: CreateResult = chunks[-1]
1071+
assert len(result.content) == 1
1072+
content = result.content[-1]
1073+
assert isinstance(content, FunctionCall)
1074+
assert content.name == "ask_for_input"
1075+
assert json.loads(content.arguments) is not None
1076+
1077+
1078+
@pytest.mark.parametrize("provider", ["anthropic", "bedrock"])
1079+
@pytest.mark.asyncio
1080+
async def test_streaming_tool_usage_with_arguments(provider: str) -> None:
1081+
"""
1082+
Test reading streaming tool usage response with arguments.
1083+
In that case `input` in initial `tool_use` chunk is `{}` but subsequent `partial_json` chunks make up the actual
1084+
complete input value.
1085+
"""
1086+
client = get_client_or_skip(provider)
1087+
1088+
# Define tools
1089+
add_numbers = FunctionTool(_add_numbers, description="Add two numbers together", name="add_numbers")
1090+
1091+
chunks: List[str | CreateResult] = []
1092+
async for chunk in client.create_stream(
1093+
messages=[
1094+
SystemMessage(content="Use the tools to evaluate calculations"),
1095+
UserMessage(content="2 + 2", source="user"),
1096+
],
1097+
tools=[add_numbers],
1098+
tool_choice="required",
1099+
):
1100+
chunks.append(chunk)
1101+
1102+
assert len(chunks) > 0
1103+
assert isinstance(chunks[-1], CreateResult)
1104+
result: CreateResult = chunks[-1]
1105+
assert len(result.content) == 1
1106+
content = result.content[-1]
1107+
assert isinstance(content, FunctionCall)
1108+
assert content.name == "add_numbers"
1109+
assert json.loads(content.arguments) is not None

0 commit comments

Comments
 (0)