diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index e39d254e8e..726a534a98 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -463,13 +463,12 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
             responses_to_yield = gemini_responses[:-1]
             for r in responses_to_yield[current_gemini_response_index:]:
                 current_gemini_response_index += 1
-                self._usage += _metadata_as_usage(r)
                 yield r
 
         # Now yield the final response, which should be complete
         if gemini_responses:  # pragma: no branch
             r = gemini_responses[-1]
-            self._usage += _metadata_as_usage(r)
+            self._usage = _metadata_as_usage(r)
             yield r
 
     @property
@@ -770,8 +769,17 @@ class _GeminiCandidates(TypedDict):
     safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
+class _GeminiModalityTokenCount(TypedDict):
+    """See ."""
+
+    modality: Annotated[
+        Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
+    ]
+    token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
+
+
 class _GeminiUsageMetaData(TypedDict, total=False):
-    """See .
+    """See .
 
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
@@ -780,6 +788,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
     total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
+    thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
+    tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
+    prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
+    ]
+    cache_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
+    ]
+    candidates_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
+    ]
+    tool_use_prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
+    ]
 
 
 def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -788,7 +810,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         return usage.Usage()  # pragma: no cover
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_token_count'] = cached_content_token_count  # pragma: no cover
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 73243f020e..9fcf6e185a 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -399,7 +399,7 @@ class GeminiStreamedResponse(StreamedResponse):
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _metadata_as_usage(chunk)
+            self._usage = _metadata_as_usage(chunk)
 
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
@@ -490,17 +490,28 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
     metadata = response.usage_metadata
     if metadata is None:
         return usage.Usage()  # pragma: no cover
-    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidate_token_details` fields because on
-    # `usage.Usage.incr``, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
+    metadata = metadata.model_dump(exclude_defaults=True)
+
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
+        total_tokens=metadata.get('total_token_count', 0),
         details=details,
     )
 
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index 17b63c316f..13d831e4bb 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -739,12 +739,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
                 'Hello world',
             ]
         )
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
         assert chunks == snapshot(['Hello ', 'world'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
 
 async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
@@ -776,7 +776,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream(debounce_by=None)]
         assert chunks == snapshot(['abc', 'abc€def', 'abc€def'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
 
 async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
@@ -847,7 +847,7 @@ async def bar(y: str) -> str:
     async with agent.run_stream('Hello') as result:
         response = await result.get_output()
         assert response == snapshot((1, 2))
-        assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
+        assert result.usage() == snapshot(Usage(requests=2, request_tokens=2, response_tokens=4, total_tokens=6))
         assert result.all_messages() == snapshot(
             [
                 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -856,7 +856,7 @@ async def bar(y: str) -> str:
                         ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()),
                         ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()),
                     ],
-                    usage=Usage(request_tokens=2, response_tokens=4, total_tokens=6),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -872,7 +872,7 @@ async def bar(y: str) -> str:
                 ),
                 ModelResponse(
                     parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())],
-                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -1103,7 +1103,13 @@ async def get_image() -> BinaryContent:
                 ),
                 ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()),
             ],
-            usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=38,
+                response_tokens=28,
+                total_tokens=427,
+                details={'thoughts_tokens': 361, 'text_prompt_tokens': 38},
+            ),
             model_name='gemini-2.5-pro-preview-03-25',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
@@ -1127,7 +1133,13 @@ async def get_image() -> BinaryContent:
         ),
         ModelResponse(
             parts=[TextPart(content='The image shows a kiwi fruit, sliced in half.')],
-            usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=360,
+                response_tokens=11,
+                total_tokens=572,
+                details={'thoughts_tokens': 201, 'text_prompt_tokens': 102, 'image_prompt_tokens': 258},
+            ),
             model_name='gemini-2.5-pro-preview-03-25',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
@@ -1250,7 +1262,13 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
         ),
         ModelResponse(
             parts=[TextPart(content='The capital of France is Paris.\n')],
-            usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=13,
+                response_tokens=8,
+                total_tokens=21,
+                details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+            ),
             model_name='gemini-1.5-flash',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
diff --git a/tests/models/test_google.py b/tests/models/test_google.py
index 5231519d89..ca8a82a731 100644
--- a/tests/models/test_google.py
+++ b/tests/models/test_google.py
@@ -65,7 +65,15 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
 
     result = await agent.run('Hello!')
     assert result.output == snapshot('Hello there! How can I help you today?\n')
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=7,
+            response_tokens=11,
+            total_tokens=18,
+            details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -82,7 +90,13 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
             ),
             ModelResponse(
                 parts=[TextPart(content='Hello there! How can I help you today?\n')],
-                usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=7,
+                    response_tokens=11,
+                    total_tokens=18,
+                    details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -115,7 +129,15 @@ async def temperature(city: str, date: datetime.date) -> str:
 
     result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response)
     assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'})
-    assert result.usage() == snapshot(Usage(requests=2, request_tokens=224, response_tokens=35, total_tokens=259))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=2,
+            request_tokens=224,
+            response_tokens=35,
+            total_tokens=259,
+            details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -136,7 +158,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_name='temperature', args={'date': '2022-01-01', 'city': 'London'}, tool_call_id=IsStr()
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=101,
+                    response_tokens=14,
+                    total_tokens=115,
+                    details={'text_prompt_tokens': 101, 'text_candidates_tokens': 14},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -156,7 +184,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_call_id=IsStr(),
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=123,
+                    response_tokens=21,
+                    total_tokens=144,
+                    details={'text_prompt_tokens': 123, 'text_candidates_tokens': 21},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -214,7 +248,7 @@ async def get_capital(country: str) -> str:
                     request_tokens=57,
                     response_tokens=15,
                     total_tokens=173,
-                    details={'thoughts_token_count': 101},
+                    details={'thoughts_tokens': 101, 'text_prompt_tokens': 57},
                 ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
@@ -236,7 +270,13 @@ async def get_capital(country: str) -> str:
                         content='I am sorry, I cannot fulfill this request. The country you provided is not supported.'
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=104,
+                    response_tokens=18,
+                    total_tokens=122,
+                    details={'text_prompt_tokens': 104},
+                ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -493,7 +533,13 @@ def instructions() -> str:
         ),
         ModelResponse(
             parts=[TextPart(content='The capital of France is Paris.\n')],
-            usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=13,
+                response_tokens=8,
+                total_tokens=21,
+                details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+            ),
             model_name='gemini-2.0-flash',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
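
Illustrative sketch, not part of the diff: how the new modality-detail keys in `Usage.details` are derived from a Gemini usage-metadata payload. The `modality_details` helper and the `sample` payload below are hypothetical; they only mirror the loop added in `_metadata_as_usage`, and the numbers are taken from the kiwi-image test snapshot above.

# Hypothetical helper mirroring the detail-extraction loop added above (sketch only).
def modality_details(metadata: dict) -> dict[str, int]:
    details: dict[str, int] = {}
    if thoughts_token_count := metadata.get('thoughts_token_count'):
        details['thoughts_tokens'] = thoughts_token_count
    for key, value in metadata.items():
        # 'prompt_tokens_details' -> suffix 'prompt_tokens' -> keys like 'text_prompt_tokens'
        if key.endswith('_details') and value:
            suffix = key.removesuffix('_details')
            for detail in value:
                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
    return details


# Sample payload (made-up shape, numbers from the test snapshot above).
sample = {
    'prompt_token_count': 360,
    'candidates_token_count': 11,
    'thoughts_token_count': 201,
    'prompt_tokens_details': [
        {'modality': 'TEXT', 'token_count': 102},
        {'modality': 'IMAGE', 'token_count': 258},
    ],
}
assert modality_details(sample) == {
    'thoughts_tokens': 201,
    'text_prompt_tokens': 102,
    'image_prompt_tokens': 258,
}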