From a2e08d9ba9cc03b3328b6fe4944f50d931581783 Mon Sep 17 00:00:00 2001 From: Emmett McFaralne Date: Fri, 3 Jan 2025 13:38:35 -0500 Subject: [PATCH 1/3] Bug fix for anthropic clients running on AWS bedrock --- instructor/utils.py | 50 +++++++++++++++++++++++++----------- tests/llm/test_new_client.py | 2 -- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/instructor/utils.py b/instructor/utils.py index 088cebbaf..0792a7a84 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -167,20 +167,36 @@ def update_total_usage( if isinstance(response_usage, AnthropicUsage) and isinstance( total_usage, AnthropicUsage ): - if not total_usage.cache_creation_input_tokens: - total_usage.cache_creation_input_tokens = 0 + # update input_tokens / output_tokens + if hasattr(total_usage, "input_tokens") and hasattr( + response_usage, "input_tokens" + ): + total_usage.input_tokens += response_usage.input_tokens or 0 + if hasattr(total_usage, "output_tokens") and hasattr( + response_usage, "output_tokens" + ): + total_usage.output_tokens += response_usage.output_tokens or 0 + + # Update cache_creation_input_tokens if both have that field + if hasattr(total_usage, "cache_creation_input_tokens") and hasattr( + response_usage, "cache_creation_input_tokens" + ): + if not total_usage.cache_creation_input_tokens: + total_usage.cache_creation_input_tokens = 0 + total_usage.cache_creation_input_tokens += ( + response_usage.cache_creation_input_tokens or 0 + ) + + # Update cache_read_input_tokens if both have that field + if hasattr(total_usage, "cache_read_input_tokens") and hasattr( + response_usage, "cache_read_input_tokens" + ): + if not total_usage.cache_read_input_tokens: + total_usage.cache_read_input_tokens = 0 + total_usage.cache_read_input_tokens += ( + response_usage.cache_read_input_tokens or 0 + ) - if not total_usage.cache_read_input_tokens: - total_usage.cache_read_input_tokens = 0 - - total_usage.input_tokens += response_usage.input_tokens or 0 - total_usage.output_tokens += response_usage.output_tokens or 0 - total_usage.cache_creation_input_tokens += ( - response_usage.cache_creation_input_tokens or 0 - ) - total_usage.cache_read_input_tokens += ( - response_usage.cache_read_input_tokens or 0 - ) response.usage = total_usage return response except ImportError: @@ -429,7 +445,9 @@ def combine_system_messages( def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]: - def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: # noqa: UP007 + def convert_message( + content: Union[str, dict[str, Any]] + ) -> SystemMessage: # noqa: UP007 if isinstance(content, str): return SystemMessage(type="text", text=content) elif isinstance(content, dict): @@ -441,7 +459,9 @@ def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: # no for m in messages: if m["role"] == "system": # System message must always be a string or list of dictionaries - content = cast(Union[str, list[dict[str, Any]]], m["content"]) # noqa: UP007 + content = cast( + Union[str, list[dict[str, Any]]], m["content"] + ) # noqa: UP007 if isinstance(content, list): result.extend(convert_message(item) for item in content) else: diff --git a/tests/llm/test_new_client.py b/tests/llm/test_new_client.py index 5acf4e6a7..d20eb64d4 100644 --- a/tests/llm/test_new_client.py +++ b/tests/llm/test_new_client.py @@ -180,7 +180,6 @@ def test_client_anthropic_response(): assert user.age == 10 -@pytest.mark.skip(reason="Skip for now") def test_client_anthropic_bedrock_response(): client = anthropic.AnthropicBedrock( aws_access_key=os.getenv("AWS_ACCESS_KEY_ID"), @@ -222,7 +221,6 @@ async def test_async_client_anthropic_response(): assert user.age == 10 -@pytest.mark.skip(reason="Skip for now") @pytest.mark.asyncio async def test_async_client_anthropic_bedrock_response(): client = anthropic.AsyncAnthropicBedrock( From eb79153dd579a40b231fee674f588c05c1d894fd Mon Sep 17 00:00:00 2001 From: Emmett McFaralne Date: Fri, 3 Jan 2025 13:50:56 -0500 Subject: [PATCH 2/3] Undo formatting changes in utils --- instructor/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/instructor/utils.py b/instructor/utils.py index 0792a7a84..b01ed3cfc 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -445,9 +445,7 @@ def combine_system_messages( def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]: - def convert_message( - content: Union[str, dict[str, Any]] - ) -> SystemMessage: # noqa: UP007 + def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: if isinstance(content, str): return SystemMessage(type="text", text=content) elif isinstance(content, dict): From 12135f6cd0608b4c8f2fecbd78e5b4d490d6d161 Mon Sep 17 00:00:00 2001 From: Emmett McFarlane Date: Fri, 3 Jan 2025 13:57:24 -0500 Subject: [PATCH 3/3] Added noqa back into utils.py --- instructor/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/instructor/utils.py b/instructor/utils.py index b01ed3cfc..29effb9df 100644 --- a/instructor/utils.py +++ b/instructor/utils.py @@ -445,7 +445,7 @@ def combine_system_messages( def extract_system_messages(messages: list[dict[str, Any]]) -> list[SystemMessage]: - def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: + def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: # noqa: UP007 if isinstance(content, str): return SystemMessage(type="text", text=content) elif isinstance(content, dict): @@ -457,9 +457,7 @@ def convert_message(content: Union[str, dict[str, Any]]) -> SystemMessage: for m in messages: if m["role"] == "system": # System message must always be a string or list of dictionaries - content = cast( - Union[str, list[dict[str, Any]]], m["content"] - ) # noqa: UP007 + content = cast(Union[str, list[dict[str, Any]]], m["content"]) # noqa: UP007 if isinstance(content, list): result.extend(convert_message(item) for item in content) else: