6 changes: 6 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -356,6 +356,12 @@ async def _make_request(
message_history = await _process_message_history(
ctx.state.message_history, ctx.deps.history_processors, build_run_context(ctx)
)

if ctx.deps.usage_limits and ctx.deps.usage_limits.pre_request_token_check_with_overhead:
token_count = await ctx.deps.model.count_tokens(message_history)

ctx.deps.usage_limits.check_tokens(_usage.Usage(request_tokens=token_count.total_tokens))

model_response = await ctx.deps.model.request(message_history, model_settings, model_request_parameters)
ctx.state.usage.incr(_usage.Usage())

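The check above only fires when the new `pre_request_token_check_with_overhead` flag is set on the run's usage limits. A minimal sketch of how that would look from user code, assuming the flag is exposed as a `UsageLimits` field in this PR alongside the existing `request_tokens_limit`; the model name and prompt are illustrative:

```python
from pydantic_ai import Agent
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.usage import UsageLimits

# GoogleModel is the only model in this diff with a working count_tokens().
agent = Agent(GoogleModel('gemini-1.5-flash'))

result = agent.run_sync(
    'Summarise this very long document ...',
    usage_limits=UsageLimits(
        request_tokens_limit=1_000,
        # Assumed new flag from this PR: when set, the agent graph counts the
        # outgoing message history first and check_tokens() raises
        # UsageLimitExceeded before the provider request is sent.
        pre_request_token_check_with_overhead=True,
    ),
)
print(result.output)
```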
27 changes: 27 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -1143,3 +1143,30 @@ def tool_call_id(self) -> str:
HandleResponseEvent = Annotated[
Union[FunctionToolCallEvent, FunctionToolResultEvent], pydantic.Discriminator('event_kind')
]


@dataclass(repr=False)
class BaseCountTokensResponse:
"""Structured response for token count API calls from various model providers."""

total_tokens: int | None = field(
default=None, metadata={'description': 'Total number of tokens counted in the messages.'}
)
"""Total number of tokens counted in the messages."""

model_name: str | None = field(default=None)
"""Name of the model that provided the token count."""

vendor_details: dict[str, Any] | None = field(default=None)
"""Vendor-specific token count details (e.g., cached_content_token_count for Gemini)."""

vendor_id: str | None = field(default=None)
"""Vendor request ID for tracking the token count request."""

timestamp: datetime = field(default_factory=_now_utc)
"""Timestamp of the token count response."""

error: str | None = field(default=None)
"""Error message if the token count request failed."""

__repr__ = _utils.dataclasses_no_defaults_repr
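For orientation, a small illustrative construction of the new response type; the values are made up and only the fields defined above are used:

```python
from pydantic_ai.messages import BaseCountTokensResponse

# Illustrative values only; a provider implementation fills these in.
count = BaseCountTokensResponse(
    total_tokens=1234,
    model_name='gemini-1.5-flash',
    vendor_details={'cached_content_token_count': 0},
)
assert count.total_tokens == 1234
print(count)  # dataclasses_no_defaults_repr keeps unset defaults out of the repr
```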
18 changes: 17 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -24,7 +24,15 @@
from .._output import OutputObjectDefinition
from .._parts_manager import ModelResponsePartsManager
from ..exceptions import UserError
from ..messages import FileUrl, ModelMessage, ModelRequest, ModelResponse, ModelResponseStreamEvent, VideoUrl
from ..messages import (
BaseCountTokensResponse,
FileUrl,
ModelMessage,
ModelRequest,
ModelResponse,
ModelResponseStreamEvent,
VideoUrl,
)
from ..output import OutputMode
from ..profiles._json_schema import JsonSchemaTransformer
from ..settings import ModelSettings
@@ -382,6 +390,14 @@ async def request(
"""Make a request to the model."""
raise NotImplementedError()

@abstractmethod
async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Make a request to the model."""
raise NotImplementedError()

@asynccontextmanager
async def request_stream(
self,
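Because `count_tokens` is added as an abstract method, every concrete `Model` must now implement it. A minimal sketch of what a custom model subclass would add, mirroring the provider stubs further down this diff (`MyCustomModel` and the elided members are hypothetical):

```python
from pydantic_ai.messages import BaseCountTokensResponse, ModelMessage
from pydantic_ai.models import Model


class MyCustomModel(Model):
    # model_name, system, request(), etc. are omitted here for brevity.

    async def count_tokens(self, messages: list[ModelMessage]) -> BaseCountTokensResponse:
        """Token counting is not supported by MyCustomModel."""
        raise NotImplementedError('Token counting is not supported by MyCustomModel')
```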
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -13,6 +13,7 @@
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -165,6 +166,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the AnthropicModel."""
raise NotImplementedError('Token counting is not supported by AnthropicModel')

@asynccontextmanager
async def request_stream(
self,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -17,6 +17,7 @@
from pydantic_ai import _utils, usage
from pydantic_ai.messages import (
AudioUrl,
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -258,6 +259,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the BedrockConverseModel."""
raise NotImplementedError('Token counting is not supported by BedrockConverseModel')

@asynccontextmanager
async def request_stream(
self,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/cohere.py
@@ -11,6 +11,7 @@
from .. import ModelHTTPError, usage
from .._utils import generate_tool_call_id as _generate_tool_call_id, guard_tool_call_id as _guard_tool_call_id
from ..messages import (
BaseCountTokensResponse,
ModelMessage,
ModelRequest,
ModelResponse,
@@ -149,6 +150,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the CohereModel."""
raise NotImplementedError('Token counting is not supported by CohereModel')

@property
def model_name(self) -> CohereModelName:
"""The model name."""
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/fallback.py
@@ -13,7 +13,7 @@
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model

if TYPE_CHECKING:
from ..messages import ModelMessage, ModelResponse
from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse
from ..settings import ModelSettings


@@ -77,6 +77,13 @@ async def request(

raise FallbackExceptionGroup('All models from FallbackModel failed', exceptions)

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the FallbackModel."""
raise NotImplementedError('Token counting is not supported by FallbackModel')

@asynccontextmanager
async def request_stream(
self,
9 changes: 9 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/function.py
@@ -16,6 +16,8 @@
from .. import _utils, usage
from .._utils import PeekableAsyncStream
from ..messages import (
AudioUrl,
BaseCountTokensResponse,
BinaryContent,
ModelMessage,
ModelRequest,
@@ -139,6 +141,13 @@ async def request(
response.usage.requests = 1
return response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the FunctionModel."""
raise NotImplementedError('Token counting is not supported by FunctionModel')

@asynccontextmanager
async def request_stream(
self,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -19,6 +19,7 @@
from .._output import OutputObjectDefinition
from ..exceptions import UserError
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
FileUrl,
ModelMessage,
@@ -158,6 +159,13 @@ async def request(
response = _gemini_response_ta.validate_json(data)
return self._process_response(response)

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the GeminiModel."""
raise NotImplementedError('Token counting is not supported by GeminiModel')

@asynccontextmanager
async def request_stream(
self,
34 changes: 34 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -14,6 +14,7 @@
from .._output import OutputObjectDefinition
from ..exceptions import UserError
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
FileUrl,
ModelMessage,
@@ -48,6 +49,7 @@
from google.genai.types import (
ContentDict,
ContentUnionDict,
CountTokensResponse,
FunctionCallDict,
FunctionCallingConfigDict,
FunctionCallingConfigMode,
@@ -181,6 +183,18 @@ async def request(
response = await self._generate_content(messages, False, model_settings, model_request_parameters)
return self._process_response(response)

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Count the tokens in the given messages using the Google GenAI count_tokens endpoint."""
check_allow_model_requests()
_, contents = await self._map_messages(messages)
response = self.client.models.count_tokens(
model=self._model_name,
contents=contents,
)
return self._process_count_tokens_response(response)

@asynccontextmanager
async def request_stream(
self,
@@ -338,6 +352,26 @@ async def _process_streamed_response(self, response: AsyncIterator[GenerateConte
_timestamp=first_chunk.create_time or _utils.now_utc(),
)

def _process_count_tokens_response(
self,
response: CountTokensResponse,
) -> BaseCountTokensResponse:
"""Process Gemini token count response into BaseCountTokensResponse."""
if not hasattr(response, 'total_tokens') or response.total_tokens is None:
raise UnexpectedModelBehavior('Total tokens missing from Gemini response', str(response))

vendor_details: dict[str, Any] | None = None
if hasattr(response, 'cached_content_token_count'):
vendor_details = {'cached_content_token_count': response.cached_content_token_count}

return BaseCountTokensResponse(
total_tokens=response.total_tokens,
model_name=self._model_name,
vendor_details=vendor_details if vendor_details else None,
vendor_id=getattr(response, 'request_id', None),
)

async def _map_messages(self, messages: list[ModelMessage]) -> tuple[ContentDict | None, list[ContentUnionDict]]:
contents: list[ContentUnionDict] = []
system_parts: list[PartDict] = []
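A usage sketch for the one model that actually implements token counting in this PR. The model name and prompt are illustrative, and it assumes Google API credentials are available via the environment:

```python
import asyncio

from pydantic_ai.messages import ModelRequest, UserPromptPart
from pydantic_ai.models.google import GoogleModel


async def main() -> None:
    model = GoogleModel('gemini-1.5-flash')  # assumes API credentials in the environment
    messages = [ModelRequest(parts=[UserPromptPart(content='Hello, world!')])]
    response = await model.count_tokens(messages)
    print(response.total_tokens, response.model_name, response.vendor_details)


asyncio.run(main())
```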
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -14,6 +14,7 @@
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage
from .._utils import guard_tool_call_id as _guard_tool_call_id, number_to_datetime
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -160,6 +161,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the GroqModel."""
raise NotImplementedError('Token counting is not supported by GroqModel')

@asynccontextmanager
async def request_stream(
self,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
@@ -16,6 +16,7 @@
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc
from ..messages import (
AudioUrl,
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -140,6 +141,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the HuggingFaceModel."""
raise NotImplementedError('Token counting is not supported by HuggingFaceModel')

@asynccontextmanager
async def request_stream(
self,
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mcp_sampling.py
@@ -6,7 +6,7 @@
from typing import TYPE_CHECKING, cast

from .. import _mcp, exceptions, usage
from ..messages import ModelMessage, ModelResponse
from ..messages import BaseCountTokensResponse, ModelMessage, ModelResponse
from ..settings import ModelSettings
from . import Model, ModelRequestParameters, StreamedResponse

@@ -70,6 +70,13 @@ async def request(
f'Unexpected result from MCP sampling, expected "assistant" role, got {result.role}.'
)

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the MCPSamplingModel."""
raise NotImplementedError('Token counting is not supported by MCPSamplingModel')

@asynccontextmanager
async def request_stream(
self,
8 changes: 8 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -16,6 +16,7 @@
from .. import ModelHTTPError, UnexpectedModelBehavior, _utils
from .._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..messages import (
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -167,6 +168,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the MistralModel."""
raise NotImplementedError('Token counting is not supported by MistralModel')

@asynccontextmanager
async def request_stream(
self,
15 changes: 15 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -20,6 +20,7 @@
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from ..messages import (
AudioUrl,
BaseCountTokensResponse,
BinaryContent,
DocumentUrl,
ImageUrl,
@@ -248,6 +249,13 @@ async def request(
model_response.usage.requests = 1
return model_response

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the OpenAIModel."""
raise NotImplementedError('Token counting is not supported by OpenAIModel')

@asynccontextmanager
async def request_stream(
self,
@@ -672,6 +680,13 @@ async def request(
)
return self._process_response(response)

async def count_tokens(
self,
messages: list[ModelMessage],
) -> BaseCountTokensResponse:
"""Token counting is not supported by the OpenAIResponsesModel."""
raise NotImplementedError('Token counting is not supported by the OpenAIResponsesModel')

@asynccontextmanager
async def request_stream(
self,