11import asyncio
2+ import json
23import logging
34import os
45from typing import List , Sequence
2021from autogen_ext .models .anthropic import (
2122 AnthropicBedrockChatCompletionClient ,
2223 AnthropicChatCompletionClient ,
24+ BaseAnthropicChatCompletionClient ,
2325 BedrockInfo ,
2426)
2527
@@ -34,6 +36,11 @@ def _add_numbers(a: int, b: int) -> int:
3436 return a + b
3537
3638
39+ def _ask_for_input () -> str :
40+ """Function that asks for user input. Used to test empty input handling, such as in `pass_to_user` tool."""
41+ return "Further input from user"
42+
43+
3744@pytest .mark .asyncio
3845async def test_mock_tool_choice_specific_tool () -> None :
3946 """Test tool_choice parameter with a specific tool using mocks."""
@@ -999,3 +1006,104 @@ async def test_anthropic_tool_choice_none_value_with_actual_api() -> None:
9991006
10001007 # Should get a text response, not tool calls
10011008 assert isinstance (result .content , str )
1009+
1010+
1011+ def get_client_or_skip (provider : str ) -> BaseAnthropicChatCompletionClient :
1012+ if provider == "anthropic" :
1013+ api_key = os .getenv ("ANTHROPIC_API_KEY" )
1014+ if not api_key :
1015+ pytest .skip ("ANTHROPIC_API_KEY not found in environment variables" )
1016+
1017+ return AnthropicChatCompletionClient (
1018+ model = "claude-3-haiku-20240307" ,
1019+ api_key = api_key ,
1020+ )
1021+ else :
1022+ access_key = os .getenv ("AWS_ACCESS_KEY_ID" )
1023+ secret_key = os .getenv ("AWS_SECRET_ACCESS_KEY" )
1024+ region = os .getenv ("AWS_REGION" )
1025+ if not access_key or not secret_key or not region :
1026+ pytest .skip ("AWS credentials not found in environment variables" )
1027+
1028+ model = os .getenv ("ANTHROPIC_BEDROCK_MODEL" , "us.anthropic.claude-3-haiku-20240307-v1:0" )
1029+ return AnthropicBedrockChatCompletionClient (
1030+ model = model ,
1031+ bedrock_info = BedrockInfo (
1032+ aws_access_key = access_key ,
1033+ aws_secret_key = secret_key ,
1034+ aws_region = region ,
1035+ aws_session_token = os .getenv ("AWS_SESSION_TOKEN" , "" ),
1036+ ),
1037+ model_info = ModelInfo (
1038+ vision = False , function_calling = True , json_output = False , family = "unknown" , structured_output = True
1039+ ),
1040+ )
1041+
1042+
1043+ @pytest .mark .asyncio
1044+ @pytest .mark .parametrize ("provider" , ["anthropic" , "bedrock" ])
1045+ async def test_streaming_tool_usage_with_no_arguments (provider : str ) -> None :
1046+ """
1047+ Test reading streaming tool usage response with no arguments.
1048+ In that case `input` in initial `tool_use` chunk is `{}` and subsequent `partial_json` chunks are empty.
1049+ """
1050+ client = get_client_or_skip (provider )
1051+
1052+ # Define tools
1053+ ask_for_input_tool = FunctionTool (
1054+ _ask_for_input , description = "Ask user for more input" , name = "ask_for_input" , strict = True
1055+ )
1056+
1057+ chunks : List [str | CreateResult ] = []
1058+ async for chunk in client .create_stream (
1059+ messages = [
1060+ SystemMessage (content = "When user intent is unclear, ask for more input" ),
1061+ UserMessage (content = "Erm..." , source = "user" ),
1062+ ],
1063+ tools = [ask_for_input_tool ],
1064+ tool_choice = "required" ,
1065+ ):
1066+ chunks .append (chunk )
1067+
1068+ assert len (chunks ) > 0
1069+ assert isinstance (chunks [- 1 ], CreateResult )
1070+ result : CreateResult = chunks [- 1 ]
1071+ assert len (result .content ) == 1
1072+ content = result .content [- 1 ]
1073+ assert isinstance (content , FunctionCall )
1074+ assert content .name == "ask_for_input"
1075+ assert json .loads (content .arguments ) is not None
1076+
1077+
1078+ @pytest .mark .parametrize ("provider" , ["anthropic" , "bedrock" ])
1079+ @pytest .mark .asyncio
1080+ async def test_streaming_tool_usage_with_arguments (provider : str ) -> None :
1081+ """
1082+ Test reading streaming tool usage response with arguments.
1083+ In that case `input` in initial `tool_use` chunk is `{}` but subsequent `partial_json` chunks make up the actual
1084+ complete input value.
1085+ """
1086+ client = get_client_or_skip (provider )
1087+
1088+ # Define tools
1089+ add_numbers = FunctionTool (_add_numbers , description = "Add two numbers together" , name = "add_numbers" )
1090+
1091+ chunks : List [str | CreateResult ] = []
1092+ async for chunk in client .create_stream (
1093+ messages = [
1094+ SystemMessage (content = "Use the tools to evaluate calculations" ),
1095+ UserMessage (content = "2 + 2" , source = "user" ),
1096+ ],
1097+ tools = [add_numbers ],
1098+ tool_choice = "required" ,
1099+ ):
1100+ chunks .append (chunk )
1101+
1102+ assert len (chunks ) > 0
1103+ assert isinstance (chunks [- 1 ], CreateResult )
1104+ result : CreateResult = chunks [- 1 ]
1105+ assert len (result .content ) == 1
1106+ content = result .content [- 1 ]
1107+ assert isinstance (content , FunctionCall )
1108+ assert content .name == "add_numbers"
1109+ assert json .loads (content .arguments ) is not None
0 commit comments