Rate Limit and Retry for Models #1734
New file added by the PR, implementing a rate-limited (and optionally retried) model wrapper (`@@ -0,0 +1,165 @@`):

````python
from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING

from tenacity import AsyncRetrying

if TYPE_CHECKING:
    from throttled.asyncio import Throttled

from . import (
    KnownModelName,
    Model,
    ModelMessage,
    ModelRequestParameters,
    ModelResponse,
    ModelSettings,
    StreamedResponse,
)
from .wrapper import WrapperModel


@dataclass
class RateLimitedModel(WrapperModel):
    """Model which wraps another model such that requests are rate limited with throttled.

    If `retryer` is provided, it also retries requests with tenacity.

    Usage:
    ```python
    from tenacity import AsyncRetrying, stop_after_attempt
    from throttled.asyncio import Throttled, rate_limiter
    from throttled.asyncio.store import MemoryStore

    from pydantic_ai import Agent
    from pydantic_ai.models.rate_limited import RateLimitedModel

    throttle = Throttled(
        using='gcra',
        quota=rate_limiter.per_sec(1_000, burst=1_000),
        store=MemoryStore(),
    )
    model = RateLimitedModel(
        'anthropic:claude-3-7-sonnet-latest',
        limiter=throttle,
        retryer=AsyncRetrying(stop=stop_after_attempt(3)),
    )
    agent = Agent(model=model)
    ```
    """

    def __init__(
        self,
        wrapped: Model | KnownModelName,
        limiter: Throttled | None = None,
        retryer: AsyncRetrying | None = None,
    ) -> None:
        super().__init__(wrapped)
        self.limiter = limiter
        self.retryer = retryer

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        key: str = 'default',
        cost: int = 1,
        timeout: float | None = 30.0,
````
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to move these to the initializer? Or could they be values on the limiter that's passed? Users don't call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can move the parameters to the initialization stage: throttle = Throttled(
using='gcra',
quota=rate_limiter.per_sec(1_000, burst=1_000),
# store can be omitted, the global MemoryStore is provided by default.
# store=MemoryStore(),
key='default',
cost=1,
timeout=30
) |
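Following that suggestion, here is a minimal sketch of what the wrapper could look like if `key`, `cost`, and `timeout` moved to the constructor instead. The stored attributes (`self.key`, `self.cost`, `self.timeout`) and the simplified `request()` are hypothetical, not part of this PR, and the snippet assumes the same imports as the module above.

```python
# Hypothetical variant (not in this PR): the per-call limiter arguments become
# constructor arguments, so callers never pass them to request().
@dataclass
class RateLimitedModel(WrapperModel):
    def __init__(
        self,
        wrapped: Model | KnownModelName,
        limiter: Throttled | None = None,
        retryer: AsyncRetrying | None = None,
        key: str = 'default',
        cost: int = 1,
        timeout: float | None = 30.0,
    ) -> None:
        super().__init__(wrapped)
        self.limiter = limiter
        self.retryer = retryer
        self.key = key
        self.cost = cost
        self.timeout = timeout

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> ModelResponse:
        # The limiter call now reads the instance-level settings instead of
        # per-call keyword arguments.
        if self.limiter:
            await self.limiter.limit(self.key, self.cost, self.timeout)
        return await super().request(messages, model_settings, model_request_parameters)
```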
The file continues:

```python
    ) -> ModelResponse:
        """Make a request to the model.

        Args:
            messages: The messages to send to the model.
            model_settings: The settings to use for the model.
            model_request_parameters: The parameters to use for the model.
            key: The key to use in the rate limiter store.
            cost: The cost to use for the rate limiter.
            timeout: The timeout to use for the rate limiter. Important: if timeout is
                not provided or set to -1, the rate limiter will return immediately.
        """
        if self.retryer:
            async for attempt in self.retryer:
                with attempt:
                    if self.limiter:
                        await self.limiter.limit(key, cost, timeout)
                        return await super().request(
                            messages,
                            model_settings,
                            model_request_parameters,
                        )
                    else:
                        return await super().request(
                            messages,
                            model_settings,
                            model_request_parameters,
                        )
            raise RuntimeError('Model request failed after all retries')
        else:
            if self.limiter:
                await self.limiter.limit(key, cost, timeout)
```
Review comment (on the `self.limiter.limit(...)` call): Could we consider verifying the rate limit result (i.e. whether `result.limited` is set) before executing the request? For example:

```python
if self.limiter:
    result = await self.limiter.limit(key, cost, timeout)
    # 💡 Check if the limit has been exceeded.
    if result.limited:
        raise RuntimeError('Rate limit exceeded.')
    return await super().request(...)
```

I am concerned that skipping the check and still executing the request after exceeding the rate limit may cause the model to encounter unpredictable third-party exception errors.

Reply: Yes, good point, I will adjust.
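One way to act on that check is to surface the rejection as an exception that the tenacity retryer is configured to retry with backoff. The sketch below assumes the `result.limited` attribute shown in the comment above; the `RateLimitExceeded` exception and `_acquire` helper are illustrative names, not part of the PR.

```python
from tenacity import (
    AsyncRetrying,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class RateLimitExceeded(RuntimeError):
    """Raised when the throttled limiter rejects a request."""


async def _acquire(limiter, key: str, cost: int, timeout: float | None) -> None:
    # Wait for capacity (up to `timeout`), then fail loudly if still limited.
    result = await limiter.limit(key, cost, timeout)
    if result.limited:
        raise RateLimitExceeded(f'rate limit exceeded for key {key!r}')


# A retryer that only retries rate-limit rejections, with exponential backoff.
retryer = AsyncRetrying(
    retry=retry_if_exception_type(RateLimitExceeded),
    wait=wait_exponential(multiplier=0.5, max=10),
    stop=stop_after_attempt(3),
)
```

With this wiring, a rejected acquisition is retried after a backoff instead of being sent straight to the provider.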
The file continues:

```python
                return await super().request(
                    messages,
                    model_settings,
                    model_request_parameters,
                )
            else:
                return await super().request(
                    messages,
                    model_settings,
                    model_request_parameters,
                )

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
        key: str = 'default',
        cost: int = 1,
        timeout: float | None = 30.0,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed request to the model.

        Args:
            messages: The messages to send to the model.
            model_settings: The settings to use for the model.
            model_request_parameters: The parameters to use for the model.
            key: The key to use in the rate limiter store.
            cost: The cost to use for the rate limiter.
            timeout: The timeout to use for the rate limiter. Important: if timeout is
                not provided or set to -1, the rate limiter will return immediately.
        """
        if self.retryer:
            async for attempt in self.retryer:
                with attempt:
                    if self.limiter:
                        await self.limiter.limit(key, cost, timeout)
                        async with super().request_stream(
                            messages, model_settings, model_request_parameters
                        ) as response_stream:
                            yield response_stream
                    else:
                        async with super().request_stream(
                            messages, model_settings, model_request_parameters
                        ) as response_stream:
                            yield response_stream
        else:
            if self.limiter:
                await self.limiter.limit(key, cost, timeout)
                async with super().request_stream(
                    messages, model_settings, model_request_parameters
                ) as response_stream:
                    yield response_stream
            else:
                async with super().request_stream(
                    messages, model_settings, model_request_parameters
                ) as response_stream:
                    yield response_stream
```
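For reference, a minimal end-to-end usage sketch under the assumptions of the docstring example: an Anthropic API key configured in the environment, the module importable as `pydantic_ai.models.rate_limited`, and the store omitted per the reviewer's note that a global in-memory store is the default. The prompt and limits are arbitrary.

```python
import asyncio

from tenacity import AsyncRetrying, stop_after_attempt
from throttled.asyncio import Throttled, rate_limiter

from pydantic_ai import Agent
from pydantic_ai.models.rate_limited import RateLimitedModel


async def main() -> None:
    # Allow at most 10 requests per second (default in-memory store).
    throttle = Throttled(using='gcra', quota=rate_limiter.per_sec(10, burst=10))
    model = RateLimitedModel(
        'anthropic:claude-3-7-sonnet-latest',
        limiter=throttle,
        retryer=AsyncRetrying(stop=stop_after_attempt(3)),
    )
    agent = Agent(model=model)
    result = await agent.run('Say hello.')
    print(result)


if __name__ == '__main__':
    asyncio.run(main())
```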
`pyproject.toml`:

```diff
@@ -55,6 +55,8 @@ dependencies = [
     "exceptiongroup; python_version < '3.11'",
     "opentelemetry-api>=1.28.0",
     "typing-inspection>=0.4.0",
+    "tenacity>=9.1.2",
+    "throttled-py>=2.2.0"
```
Review comment (on the new dependencies): @Kludex What do you think of these being included by default? Should we put them in an optional dependency group?

Reply: Let's add these to a new optional dependency group called …
The diff continues:

```diff
 ]

 [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies]
```
Review comment (on the `key: str = 'default'` default): What's the implication of this value being the same for all models (unless it's overwritten)? Should we use a different value for each model/agent?

Reply: @grll @DouweM The target of rate limiting is the LLM's API, which is different for each model and needs to be independent. Can we use `self.model_name` as the default key, unless a specific key already exists for the limiter?

Reply: @ZhuoZhuoCrayon I think that makes sense, but let's add in the provider as well: `{self.system}:{self.model_name}`
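A sketch of what that suggested default could look like, assuming `WrapperModel` exposes the wrapped model's `system` and `model_name` properties as the thread implies. The subclass name and `_default_key` helper are illustrative, not part of the PR.

```python
from pydantic_ai.models.rate_limited import RateLimitedModel  # module added by this PR


class PerModelKeyRateLimitedModel(RateLimitedModel):
    """Sketch: derive the limiter key from the wrapped model instead of 'default'."""

    def _default_key(self) -> str:
        # Provider + model name, as suggested in the review thread.
        return f'{self.system}:{self.model_name}'

    async def request(self, messages, model_settings, model_request_parameters):
        # Route every request through the base implementation with a
        # model-specific key, so different models don't share one bucket.
        return await super().request(
            messages,
            model_settings,
            model_request_parameters,
            key=self._default_key(),
        )
```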