diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 312a8a2fca..87a49a468f 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -356,6 +356,12 @@ async def _make_request(
         message_history = await _process_message_history(
             ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
         )
+
+        if ctx.deps.usage_limits and ctx.deps.usage_limits.pre_request_token_check_with_overhead:
+            token_count = await ctx.deps.model.count_tokens(message_history)
+
+            ctx.deps.usage_limits.check_tokens(_usage.Usage(request_tokens=token_count.total_tokens))
+
         model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
         ctx.state.usage.incr(_usage.Usage())
diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index 379d70efd7..afa0ec7f17 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -1143,3 +1143,30 @@ def tool_call_id(self) -> str:
 HandleResponseEvent = Annotated[
     Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')
 ]
+
+
+@dataclass(repr=False)
+class BaseCountTokensResponse:
+    """Structured response for token count API calls from various model providers."""
+
+    total_tokens: int | None = field(
+        default=None, metadata={'description': 'Total number of tokens counted in the messages.'}
+    )
+    """Total number of tokens counted in the messages."""
+
+    model_name: str | None = field(default=None)
+    """Name of the model that provided the token count."""
+
+    vendor_details: dict[str, Any] | None = field(default=None)
+    """Vendor-specific token count details (e.g., cached_content_token_count for Gemini)."""
+
+    vendor_id: str | None = field(default=None)
+    """Vendor request ID for tracking the token count request."""
+
+    timestamp: datetime = field(default_factory=_now_utc)
+    """Timestamp of the token count response."""
+
+    error: str | None = field(default=None)
+    """Error message if the token count request failed."""
+
+    __repr__ = _utils.dataclasses_no_defaults_repr
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 6cdcbfbd64..fa8e5b7266 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -24,7 +24,15 @@
 from .._output import OutputObjectDefinition
 from .._parts_manager import ModelResponsePartsManager
 from ..exceptions import UserError
-from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
+from ..messages import (
+    BaseCountTokensResponse,
+    FileUrl,
+    ModelMessage,
+    ModelRequest,
+    ModelResponse,
+    ModelResponseStreamEvent,
+    VideoUrl,
+)
 from ..output import OutputMode
 from ..profiles._json_schema import JsonSchemaTransformer
 from ..settings import ModelSettings
@@ -382,6 +390,14 @@ async def request(
         """Make a request to the model."""
         raise NotImplementedError()
 
+    @abstractmethod
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Count the number of tokens that would be used by a request to the model."""
+        raise NotImplementedError()
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 02f9111c2d..361ee01f47 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -13,6 +13,7 @@
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -165,6 +166,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the AnthropicModel."""
+        raise NotImplementedError('Token counting is not supported by AnthropicModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
index b63ed4e1f9..746e8d66d4 100644
--- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py
+++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -17,6 +17,7 @@
 from pydantic_ai import _utils, usage
 from pydantic_ai.messages import (
     AudioUrl,
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -258,6 +259,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the BedrockConverseModel."""
+        raise NotImplementedError('Token counting is not supported by BedrockConverseModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py
index 4243ef492a..087d47941a 100644
--- a/pydantic_ai_slim/pydantic_ai/models/cohere.py
+++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -11,6 +11,7 @@
 from .. import ModelHTTPError, usage
 from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
 from ..messages import (
+    BaseCountTokensResponse,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -149,6 +150,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the CohereModel."""
+        raise NotImplementedError('Token counting is not supported by CohereModel')
+
     @property
     def model_name(self) -> CohereModelName:
         """The model name."""
diff --git a/pydantic_ai_slim/pydantic_ai/models/fallback.py b/pydantic_ai_slim/pydantic_ai/models/fallback.py
index 4455defce3..80957d4520 100644
--- a/pydantic_ai_slim/pydantic_ai/models/fallback.py
+++ b/pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -13,7 +13,7 @@
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 if TYPE_CHECKING:
-    from ..messages import ModelMessage, ModelResponse
+    from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse
     from ..settings import ModelSettings
 
 
@@ -77,6 +77,13 @@ async def request(
 
         raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the FallbackModel."""
+        raise NotImplementedError('Token counting is not supported by FallbackModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py
index c48873f046..f7e5831730 100644
--- a/pydantic_ai_slim/pydantic_ai/models/function.py
+++ b/pydantic_ai_slim/pydantic_ai/models/function.py
@@ -16,6 +16,8 @@
 from .. import _utils, usage
 from .._utils import PeekableAsyncStream
 from ..messages import (
+    AudioUrl,
+    BaseCountTokensResponse,
     BinaryContent,
     ModelMessage,
     ModelRequest,
@@ -139,6 +141,13 @@ async def request(
         response.usage.requests = 1
         return response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the FunctionModel."""
+        raise NotImplementedError('Token counting is not supported by FunctionModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 4ac07f8ada..ee4694c911 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -19,6 +19,7 @@
 from .._output import OutputObjectDefinition
 from ..exceptions import UserError
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     FileUrl,
     ModelMessage,
@@ -158,6 +159,13 @@ async def request(
         response = _gemini_response_ta.validate_json(data)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the GeminiModel."""
+        raise NotImplementedError('Token counting is not supported by GeminiModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 082f5ba566..517d9cad45 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -14,6 +14,7 @@
 from .._output import OutputObjectDefinition
 from ..exceptions import UserError
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     FileUrl,
     ModelMessage,
@@ -48,6 +49,7 @@
     from google.genai.types import (
         ContentDict,
         ContentUnionDict,
+        CountTokensResponse,
         FunctionCallDict,
         FunctionCallingConfigDict,
         FunctionCallingConfigMode,
@@ -181,6 +183,18 @@ async def request(
         response = await self._generate_content(messages, False, model_settings, model_request_parameters)
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        check_allow_model_requests()
+        _, contents = await self._map_messages(messages)
+        response = await self.client.aio.models.count_tokens(
+            model=self._model_name,
+            contents=contents,
+        )
+        return self._process_count_tokens_response(response)
+
     @asynccontextmanager
     async def request_stream(
         self,
@@ -338,6 +352,26 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]):
             _timestamp=first_chunk.create_time or _utils.now_utc(),
         )
 
+    def _process_count_tokens_response(
+        self,
+        response: CountTokensResponse,
+    ) -> BaseCountTokensResponse:
+        """Process the Gemini token count response into a `BaseCountTokensResponse`."""
+        if not hasattr(response, 'total_tokens') or response.total_tokens is None:
+            raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response))
+
+        vendor_details: dict[str, Any] | None = None
+        if getattr(response, 'cached_content_token_count', None) is not None:
+            vendor_details = {}
+            vendor_details['cached_content_token_count'] = response.cached_content_token_count
+
+        return BaseCountTokensResponse(
+            total_tokens=response.total_tokens,
+            model_name=self._model_name,
+            vendor_details=vendor_details if vendor_details else None,
+            vendor_id=getattr(response, 'request_id', None),
+        )
+
     async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
         contents: list[ContentUnionDict] = []
         system_parts: list[PartDict] = []
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index ffca84b447..f6f28e554d 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -14,6 +14,7 @@
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
 from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -160,6 +161,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the GroqModel."""
+        raise NotImplementedError('Token counting is not supported by GroqModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
index 4b3c2ff404..4ad9da66cf 100644
--- a/pydantic_ai_slim/pydantic_ai/models/huggingface.py
+++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -16,6 +16,7 @@
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
 from ..messages import (
     AudioUrl,
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -140,6 +141,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the HuggingFaceModel."""
+        raise NotImplementedError('Token counting is not supported by HuggingFaceModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
index ebfaac92d0..bead387c07 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
@@ -6,7 +6,7 @@
 from typing import TYPE_CHECKING, cast
 
 from .. import _mcp, exceptions, usage
-from ..messages import ModelMessage, ModelResponse
+from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from . import Model, ModelRequestParameters, StreamedResponse
@@ -70,6 +70,13 @@ async def request(
                 f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.'
             )
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the MCPSamplingModel."""
+        raise NotImplementedError('Token counting is not supported by MCPSamplingModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index ca73558bca..800333efe8 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -16,6 +16,7 @@
 from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
 from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -167,6 +168,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the MistralModel."""
+        raise NotImplementedError('Token counting is not supported by MistralModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 35dca2e03d..7d46579888 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -20,6 +20,7 @@
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..messages import (
     AudioUrl,
+    BaseCountTokensResponse,
     BinaryContent,
     DocumentUrl,
     ImageUrl,
@@ -248,6 +249,13 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the OpenAIModel."""
+        raise NotImplementedError('Token counting is not supported by OpenAIModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
@@ -672,6 +680,13 @@ async def request(
         )
         return self._process_response(response)
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        """Token counting is not supported by the OpenAIResponsesModel."""
+        raise NotImplementedError('Token counting is not supported by OpenAIResponsesModel')
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py
index eebe00d440..05061ba38f 100644
--- a/pydantic_ai_slim/pydantic_ai/models/test.py
+++ b/pydantic_ai_slim/pydantic_ai/models/test.py
@@ -13,6 +13,7 @@
 from .. import _utils
 from ..messages import (
+    BaseCountTokensResponse,
     ModelMessage,
     ModelRequest,
     ModelResponse,
@@ -112,6 +113,12 @@ async def request(
         model_response.usage.requests = 1
         return model_response
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        return BaseCountTokensResponse(total_tokens=1)
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/wrapper.py b/pydantic_ai_slim/pydantic_ai/models/wrapper.py
index cc91f9c725..d05b390dc2 100644
--- a/pydantic_ai_slim/pydantic_ai/models/wrapper.py
+++ b/pydantic_ai_slim/pydantic_ai/models/wrapper.py
@@ -6,7 +6,7 @@
 from functools import cached_property
 from typing import Any
 
-from ..messages import ModelMessage, ModelResponse
+from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse
 from ..profiles import ModelProfile
 from ..settings import ModelSettings
 from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
@@ -29,6 +29,9 @@ def __init__(self, wrapped: Model | KnownModelName):
     async def request(self, *args: Any, **kwargs: Any) -> ModelResponse:
         return await self.wrapped.request(*args, **kwargs)
 
+    async def count_tokens(self, *args: Any, **kwargs: Any) -> BaseCountTokensResponse:
+        return await self.wrapped.count_tokens(*args, **kwargs)
+
     @asynccontextmanager
     async def request_stream(
         self,
diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py
index c3f4c1885b..39e7a15d2c 100644
--- a/pydantic_ai_slim/pydantic_ai/usage.py
+++ b/pydantic_ai_slim/pydantic_ai/usage.py
@@ -28,6 +28,8 @@ class Usage:
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
     """Any extra details returned by the model."""
+    eager_request_tokens_check: bool = False
+    """Whether to check request token counts eagerly, before each model request."""
 
     def incr(self, incr_usage: Usage) -> None:
         """Increment the usage in place.
@@ -96,6 +98,10 @@ class UsageLimits:
     """The maximum number of tokens allowed in responses from the model."""
     total_tokens_limit: int | None = None
    """The maximum number of tokens allowed in requests and responses combined."""
+    pre_request_token_check_with_overhead: bool = False
+    """If True, perform a token counting pass before sending the request to the model,
+    to enforce `request_tokens_limit` ahead of time. This may incur additional overhead
+    (from calling the model's `count_tokens` method) and is disabled by default."""
 
     def has_token_limits(self) -> bool:
         """Returns `True` if this instance places any limits on token counts.
diff --git a/tests/evals/test_evaluators.py b/tests/evals/test_evaluators.py
index 235296c4ad..cfe373b97f 100644
--- a/tests/evals/test_evaluators.py
+++ b/tests/evals/test_evaluators.py
@@ -8,7 +8,7 @@
 from pydantic import BaseModel, TypeAdapter
 from pydantic_core import to_jsonable_python
 
-from pydantic_ai.messages import ModelMessage, ModelResponse
+from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage, ModelResponse
 from pydantic_ai.models import Model, ModelRequestParameters
 from pydantic_ai.settings import ModelSettings
@@ -125,6 +125,12 @@ async def request(
     ) -> ModelResponse:
         raise NotImplementedError
 
+    async def count_tokens(
+        self,
+        messages: list[ModelMessage],
+    ) -> BaseCountTokensResponse:
+        raise NotImplementedError
+
     @property
     def model_name(self) -> str:
         return 'my-model'
diff --git a/tests/models/test_google.py b/tests/models/test_google.py
index 7e1f372bcc..80016c5307 100644
--- a/tests/models/test_google.py
+++ b/tests/models/test_google.py
@@ -1393,3 +1393,30 @@ class CountryLanguage(BaseModel):
                 ),
             ]
         )
+
+
+async def test_google_model_count_tokens(allow_model_requests: None, google_provider: GoogleProvider):
+    model = GoogleModel('gemini-1.5-flash', provider=google_provider)
+
+    messages = [
+        ModelRequest(
+            parts=[
+                SystemPromptPart(content='You are a helpful chatbot.', timestamp=IsDatetime()),
+                UserPromptPart(content='What was the temperature in London 1st January 2022?', timestamp=IsDatetime()),
+            ]
+        ),
+        ModelResponse(
+            parts=[
+                ToolCallPart(
+                    tool_name='temperature',
+                    args={'date': '2022-01-01', 'city': 'London'},
+                    tool_call_id='test_id',
+                )
+            ],
+            model_name='gemini-1.5-flash',
+            timestamp=IsDatetime(),
+            vendor_details={'finish_reason': 'STOP'},
+        ),
+    ]
+    result = await model.count_tokens(messages)
+    assert result.total_tokens == snapshot(7)
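
Reviewer note: a rough usage sketch of the new pre-request check, assuming the existing `Agent.run` / `UsageLimits` public API. The model name, prompt, and limit value below are illustrative and not part of this patch.

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

# Illustrative agent; any model whose class implements count_tokens() works,
# which in this patch is only GoogleModel.
agent = Agent('google-gla:gemini-1.5-flash')


async def main() -> None:
    result = await agent.run(
        'What is the capital of France?',
        usage_limits=UsageLimits(
            request_tokens_limit=500,
            # New in this patch: call Model.count_tokens() *before* the request is
            # sent, so the limit is enforced eagerly at the cost of an extra provider call.
            pre_request_token_check_with_overhead=True,
        ),
    )
    print(result.output)


if __name__ == '__main__':
    asyncio.run(main())
```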
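
A sketch of calling the new `Model.count_tokens()` directly. Only `GoogleModel` implements it in this patch; every other model raises `NotImplementedError`, so callers should be ready to fall back. The helper function below is hypothetical, not part of the diff.

```python
from pydantic_ai.messages import ModelMessage
from pydantic_ai.models import Model


async def total_tokens_or_none(model: Model, messages: list[ModelMessage]) -> int | None:
    """Hypothetical helper: return the provider-reported token count, or None if unsupported."""
    try:
        response = await model.count_tokens(messages)  # BaseCountTokensResponse
    except NotImplementedError:
        # All built-in models except GoogleModel currently raise here.
        return None
    return response.total_tokens
```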
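
Because `count_tokens` is added to `Model` as an `@abstractmethod`, third-party `Model` subclasses outside this repo will no longer instantiate until they add an override. A minimal sketch (the subclass name is illustrative, and its other abstract members such as `request` and `model_name` are omitted):

```python
from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage
from pydantic_ai.models import Model


class MyCustomModel(Model):  # illustrative; request, model_name, etc. omitted
    async def count_tokens(self, messages: list[ModelMessage]) -> BaseCountTokensResponse:
        # Mirror the built-in providers that have no token-counting endpoint.
        raise NotImplementedError('Token counting is not supported by MyCustomModel')
```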