diff --git a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/__init__.py b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/__init__.py index f31e7b1c0b72..d1c29a9c82f4 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/__init__.py @@ -1,6 +1,7 @@ from ._anthropic_client import ( AnthropicBedrockChatCompletionClient, AnthropicChatCompletionClient, + AnthropicVertexChatCompletionClient, BaseAnthropicChatCompletionClient, ) from .config import ( @@ -8,18 +9,25 @@ AnthropicBedrockClientConfigurationConfigModel, AnthropicClientConfiguration, AnthropicClientConfigurationConfigModel, + AnthropicVertexClientConfiguration, + AnthropicVertexClientConfigurationConfigModel, BedrockInfo, CreateArgumentsConfigModel, + VertexInfo, ) __all__ = [ "AnthropicChatCompletionClient", "AnthropicBedrockChatCompletionClient", + "AnthropicVertexChatCompletionClient", "BaseAnthropicChatCompletionClient", "AnthropicClientConfiguration", "AnthropicBedrockClientConfiguration", + "AnthropicVertexClientConfiguration", "AnthropicClientConfigurationConfigModel", "AnthropicBedrockClientConfigurationConfigModel", + "AnthropicVertexClientConfigurationConfigModel", "CreateArgumentsConfigModel", "BedrockInfo", + "VertexInfo", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py index 6f68cf7b8cbc..f9eb41f94b41 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/anthropic/_anthropic_client.py @@ -23,7 +23,7 @@ ) import tiktoken -from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncStream +from anthropic import AsyncAnthropic, AsyncAnthropicBedrock, AsyncAnthropicVertex, AsyncStream from anthropic.types import ( Base64ImageSourceParam, 
class AnthropicVertexChatCompletionClient(
    BaseAnthropicChatCompletionClient, Component[AnthropicVertexClientConfigurationConfigModel]
):
    """
    Chat completion client for Anthropic's Claude models via Google Vertex AI.

    Args:
        model (str): The Claude model to use (e.g., "claude-3-sonnet-20240229", "claude-3-opus-20240229")
        vertex_info (VertexInfo): Configuration for Vertex AI including project_id and region
        max_tokens (int, optional): Maximum tokens in the response. Default is 4096.
        temperature (float, optional): Controls randomness. Lower is more deterministic. Default is 1.0.
        top_p (float, optional): Controls diversity via nucleus sampling. Default is 1.0.
        top_k (int, optional): Controls diversity via top-k sampling. Default is -1 (disabled).
        model_info (ModelInfo, optional): The capabilities of the model. Required if using a custom model.

    Raises:
        ValueError: If ``model`` or ``vertex_info`` is not supplied, or if ``vertex_info``
            does not contain both ``project_id`` and ``region``.

    To use this client, you must install the Anthropic extension:

    .. code-block:: bash

        pip install "autogen-ext[anthropic]"

    Example:

    .. code-block:: python

        import asyncio
        from autogen_ext.models.anthropic import AnthropicVertexChatCompletionClient, VertexInfo
        from autogen_core.models import UserMessage


        async def main():
            vertex_client = AnthropicVertexChatCompletionClient(
                model="claude-3-sonnet-20240229",
                vertex_info=VertexInfo(project_id="your-gcp-project-id", region="us-east5"),
            )

            result = await vertex_client.create([UserMessage(content="What is the capital of France?", source="user")])  # type: ignore
            print(result)


        if __name__ == "__main__":
            asyncio.run(main())

    To load the client from a configuration:

    .. code-block:: python

        from autogen_core.models import ChatCompletionClient

        config = {
            "provider": "AnthropicVertexChatCompletionClient",
            "config": {
                "model": "claude-3-sonnet-20240229",
                "vertex_info": {"project_id": "your-gcp-project-id", "region": "us-east5"},
            },
        }

        client = ChatCompletionClient.load_component(config)
    """

    # NOTE(review): the examples above use first-party Anthropic model names. Vertex AI
    # appears to publish Claude models under "@"-versioned IDs (e.g. "claude-3-sonnet@20240229")
    # -- TODO confirm against the Vertex AI model list and adjust the examples if needed.

    component_type = "model"
    component_config_schema = AnthropicVertexClientConfigurationConfigModel
    component_provider_override = "autogen_ext.models.anthropic.AnthropicVertexChatCompletionClient"

    def __init__(self, **kwargs: Unpack[AnthropicVertexClientConfiguration]):
        """Validate the required options, build the Vertex SDK client and initialise the base client.

        Raises:
            ValueError: If ``model`` or ``vertex_info`` is missing, or ``vertex_info`` is incomplete.
        """
        if "model" not in kwargs:
            raise ValueError("model is required for AnthropicVertexChatCompletionClient")

        # Untouched snapshot of the caller's options: _to_config() serialises it and
        # __setstate__ reads vertex_info back from it.
        self._raw_config: Dict[str, Any] = dict(kwargs)
        copied_args: Dict[str, Any] = dict(kwargs)

        # model_info is handed to the base class separately, so it is removed from the
        # arguments passed to _create_args_from_config.
        model_info: Optional[ModelInfo] = copied_args.pop("model_info", None)

        vertex_info: Optional[VertexInfo] = kwargs.get("vertex_info")
        if vertex_info is None:
            raise ValueError("vertex_info is required for AnthropicVertexChatCompletionClient")

        client = self._build_vertex_client(vertex_info)
        create_args = _create_args_from_config(copied_args)

        super().__init__(
            client=client,
            create_args=create_args,
            model_info=model_info,
        )

    @staticmethod
    def _build_vertex_client(vertex_info: VertexInfo) -> AsyncAnthropicVertex:
        """Create the Vertex SDK client from ``vertex_info``.

        Shared by ``__init__`` and ``__setstate__`` so both construct the client the same way.

        Raises:
            ValueError: If ``project_id`` or ``region`` is absent from ``vertex_info``.
        """
        missing = [key for key in ("project_id", "region") if key not in vertex_info]
        if missing:
            # Report an incomplete vertex_info the same way as the other required
            # inputs instead of letting a bare KeyError escape.
            raise ValueError(
                "vertex_info is missing required key(s) for AnthropicVertexChatCompletionClient: "
                + ", ".join(missing)
            )
        return AsyncAnthropicVertex(
            project_id=vertex_info["project_id"],
            region=vertex_info["region"],
        )

    def __getstate__(self) -> Dict[str, Any]:
        """Return the instance state for pickling with the SDK client replaced by ``None``."""
        state = self.__dict__.copy()
        # The client is not pickled; __setstate__ rebuilds it from _raw_config.
        state["_client"] = None
        return state

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Restore the pickled state and recreate the SDK client from the stored ``vertex_info``."""
        self.__dict__.update(state)
        self._client = self._build_vertex_client(state["_raw_config"]["vertex_info"])

    def _to_config(self) -> AnthropicVertexClientConfigurationConfigModel:
        """Build the component config model from the options this client was created with."""
        # ** unpacking does not mutate _raw_config, so no defensive copy is needed.
        return AnthropicVertexClientConfigurationConfigModel(**self._raw_config)

    @classmethod
    def _from_config(cls, config: AnthropicVertexClientConfigurationConfigModel) -> Self:
        """Create a client from a component config model, dropping fields that are ``None``."""
        # Per the original author's note, vertex_info only carries project_id and region,
        # so (unlike the Bedrock variant) there are no secret values to unwrap here.
        # model_dump() returns a new dict, so the config object itself is left untouched.
        return cls(**config.model_dump(exclude_none=True))
# Keyword arguments accepted by the Vertex AI chat completion client: everything the
# plain Anthropic client configuration declares, plus the Vertex AI connection details.
class AnthropicVertexClientConfiguration(AnthropicClientConfiguration, total=False):
    vertex_info: VertexInfo


# Shape of the ``vertex_info`` entry as used by the pydantic config model below.
# Declares the same two required keys as ``VertexInfo``.
class VertexInfoConfigModel(TypedDict):
    project_id: Required[str]
    """GCP project ID for Vertex AI access"""
    region: Required[str]
    """GCP region for Vertex AI access"""


# Pydantic counterpart of ``AnthropicVertexClientConfiguration``. ``vertex_info``
# defaults to ``None`` so the field may be omitted from a config document.
class AnthropicVertexClientConfigurationConfigModel(AnthropicClientConfigurationConfigModel):
    vertex_info: Optional[VertexInfoConfigModel] = None
def test_vertex_client_initialization_missing_model() -> None:
    """Constructing the Vertex client without ``model`` raises ``ValueError``."""
    expected_message = "model is required for AnthropicVertexChatCompletionClient"
    info = VertexInfo(project_id="test-project-123", region="us-east5")

    with pytest.raises(ValueError, match=expected_message):
        AnthropicVertexChatCompletionClient(vertex_info=info)


def test_vertex_client_initialization_missing_vertex_info() -> None:
    """Constructing the Vertex client without ``vertex_info`` raises ``ValueError``."""
    expected_message = "vertex_info is required for AnthropicVertexChatCompletionClient"

    with pytest.raises(ValueError, match=expected_message):
        AnthropicVertexChatCompletionClient(model="claude-3-sonnet-20240229")


def test_vertex_client_with_model_info() -> None:
    """A caller-supplied ``model_info`` is exposed unchanged through ``client.model_info``."""
    capabilities = ModelInfo(
        vision=True,
        function_calling=True,
        json_output=True,
        family="test-family",
        structured_output=False,
    )

    client = AnthropicVertexChatCompletionClient(
        model="custom-claude-model",
        vertex_info=VertexInfo(project_id="test-project-123", region="us-east5"),
        model_info=capabilities,
    )

    assert client.model_info == capabilities
@pytest.mark.asyncio
async def test_vertex_client_mock_tool_usage() -> None:
    """Test that a mocked ``tool_use`` response is surfaced as a single ``FunctionCall``."""
    from anthropic.types import ToolUseBlock

    # Mocked SDK client whose messages.create() returns one tool_use block.
    mock_client = AsyncMock()
    mock_message = MagicMock()
    mock_message.content = [ToolUseBlock(type="tool_use", name="add_numbers", input={"a": 1, "b": 2}, id="call_123")]
    mock_message.usage.input_tokens = 15
    mock_message.usage.output_tokens = 8
    mock_message.stop_reason = "tool_use"

    mock_client.messages.create.return_value = mock_message

    # Real Vertex client; only its underlying SDK client is swapped out below.
    vertex_info = VertexInfo(project_id="test-project", region="us-east5")
    client = AnthropicVertexChatCompletionClient(
        model="claude-3-sonnet-20240229",
        vertex_info=vertex_info,
    )

    # Annotated like the sibling mock test so the list is typed as LLMMessage items.
    messages: List[LLMMessage] = [
        UserMessage(content="Calculate 1 + 2", source="user"),
    ]

    with patch.object(client, "_client", mock_client):
        result = await client.create(messages=messages)

    # The request must have gone through the patched client exactly once.
    mock_client.messages.create.assert_called_once()

    # Exactly one tool call is returned, carrying the mocked name and arguments.
    assert isinstance(result.content, list)
    assert len(result.content) == 1
    tool_call = result.content[0]
    assert isinstance(tool_call, FunctionCall)
    assert tool_call.name == "add_numbers"
    # Compare parsed JSON so the check does not depend on serialisation whitespace.
    assert json.loads(tool_call.arguments) == {"a": 1, "b": 2}

    # The usage values configured on the mock are reported on the result.
    assert result.usage.prompt_tokens == 15
    assert result.usage.completion_tokens == 8
def test_vertex_client_config_conversion() -> None:
    """Round-trip a Vertex client through ``_to_config`` and ``_from_config``."""
    from autogen_ext.models.anthropic import AnthropicVertexClientConfigurationConfigModel

    source_client = AnthropicVertexChatCompletionClient(
        model="claude-3-opus-20240229",
        vertex_info=VertexInfo(project_id="config-test-project", region="europe-west1"),
        temperature=0.5,
    )

    # Dump the client to its component config model.
    dumped = source_client._to_config()  # pyright: ignore[reportPrivateUsage]
    assert isinstance(dumped, AnthropicVertexClientConfigurationConfigModel)

    # Rebuild a second client from that config.
    restored = AnthropicVertexChatCompletionClient._from_config(dumped)  # pyright: ignore[reportPrivateUsage]

    # The rebuilt client carries the same create arguments and Vertex settings.
    restored_args = restored._create_args  # pyright: ignore[reportPrivateUsage]
    restored_vertex = restored._raw_config["vertex_info"]  # pyright: ignore[reportPrivateUsage]
    assert restored_args["model"] == "claude-3-opus-20240229"
    assert restored_args["temperature"] == 0.5
    assert restored_vertex["project_id"] == "config-test-project"
    assert restored_vertex["region"] == "europe-west1"
"component_provider_override") + assert client.component_provider_override == "autogen_ext.models.anthropic.AnthropicVertexChatCompletionClient"