33from collections .abc import AsyncIterator
44from contextlib import asynccontextmanager
55from datetime import datetime , timezone
6- from unittest .mock import AsyncMock , patch
76
87import pytest
9- from aiolimiter import AsyncLimiter
108from tenacity import AsyncRetrying , RetryError , retry_if_exception_type , stop_after_attempt , wait_fixed
9+ from throttled .asyncio import RateLimiterType , Throttled , rate_limiter , store
1110
1211from pydantic_ai .messages import (
1312 ModelMessage ,
async def test_rate_limited_model_limiter_only():
    """Test RateLimitedModel with only a rate limiter configured (no retryer).

    Uses a real Throttled GCRA limiter at 10 req/s (burst 10) so the test runs
    fast while still exercising the limiter code path for both ``request`` and
    ``request_stream``.
    """
    # Create a simple model with rate limiting
    simple_model = SimpleModel()

    # Create a rate limiter - 10 requests per second to make tests fast
    throttle = Throttled(
        using=RateLimiterType.GCRA.value,
        quota=rate_limiter.per_sec(10, burst=10),
        store=store.MemoryStore(),
    )
    rate_limited_model = RateLimitedModel(simple_model, limiter=throttle)

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user prompt with rate limiting'),
            ]
        ),
    ]

    # Make a request
    response = await rate_limited_model.request(
        messages,
        model_settings=None,
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_output=True,
            output_tools=[],
        ),
    )

    # Check response is correct
    assert response.model_name == 'simple_model'
    assert len(response.parts) == 1
    assert isinstance(response.parts[0], TextPart)
    assert response.parts[0].content == 'simple response'

    # Now test with streaming
    async with rate_limited_model.request_stream(
        messages,
        model_settings=None,
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_output=True,
            output_tools=[],
        ),
    ) as response_stream:
        events = [event async for event in response_stream]

    # Verify events: SimpleModel's stream yields a PartStartEvent followed by
    # one more event (see SimpleModel definition elsewhere in this file)
    assert len(events) == 2
    assert isinstance(events[0], PartStartEvent)
    assert isinstance(events[0].part, TextPart)
    assert events[0].part.content == 'stream part 1'
425421
async def test_rate_limited_model_both_limiter_and_retryer():
    """Test RateLimitedModel with both a rate limiter and a retryer configured.

    FailingModel raises ValueError twice before succeeding, so a retryer
    stopping after 3 attempts should end in success; attempt_count verifies
    the model was actually invoked once per attempt.
    """
    # Model that fails twice, then succeeds on the third attempt
    failing_model = FailingModel(fail_count=2)

    # Configure the limiter and retryer
    throttle = Throttled(
        using=RateLimiterType.GCRA.value,
        quota=rate_limiter.per_sec(10, burst=10),
        store=store.MemoryStore(),
    )
    retry_config = AsyncRetrying(
        retry=retry_if_exception_type(ValueError),
        stop=stop_after_attempt(3),
        wait=wait_fixed(0.1),
    )

    rate_limited_model = RateLimitedModel(failing_model, limiter=throttle, retryer=retry_config)

    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                UserPromptPart('user prompt with rate limiting and retries'),
            ]
        ),
    ]

    # Make a request - should succeed after retries
    response = await rate_limited_model.request(
        messages,
        model_settings=None,
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_output=True,
            output_tools=[],
        ),
    )

    # Check response is correct
    assert isinstance(response.parts[0], TextPart)
    assert response.parts[0].content == 'success after retries'

    # Verify the model was called the right number of times (2 failures + 1 success)
    assert failing_model.attempt_count == 3

    # Reset for streaming test
    failing_model.attempt_count = 0

    # Test with streaming - should also succeed after retries
    async with rate_limited_model.request_stream(
        messages,
        model_settings=None,
        model_request_parameters=ModelRequestParameters(
            function_tools=[],
            allow_text_output=True,
            output_tools=[],
        ),
    ) as response_stream:
        events = [event async for event in response_stream]

    # Verify events exist
    assert len(events) > 0

    # Verify the model was called the right number of times
    assert failing_model.attempt_count == 3
497488
async def test_rate_limited_model_concurrent_requests():
    """Test RateLimitedModel with concurrent requests."""
    import time

    # Create several simple model instances
    simple_model = SimpleModel()

    # Create a real rate limiter that will allow 2 requests per second
    throttle = Throttled(
        using=RateLimiterType.GCRA.value,
        quota=rate_limiter.per_sec(2),  # 2 requests per second
        store=store.MemoryStore(),
    )

    rate_limited_model = RateLimitedModel(simple_model, limiter=throttle)

    # Create the message for all requests
    messages: list[ModelMessage] = [
        ModelRequest(
            parts=[
                # NOTE(review): exact prompt text fell between diff hunks — confirm against the repo
                UserPromptPart('user prompt for concurrent requests'),
            ]
        ),
    ]

    # Make 5 sequential requests and measure time.
    # Use monotonic() rather than time(): wall-clock time can jump (NTP,
    # DST) and would make the elapsed-time assertions below flaky.
    start_time = time.monotonic()

    for _ in range(5):
        response = await rate_limited_model.request(
            messages,
            model_settings=None,
            model_request_parameters=ModelRequestParameters(
                function_tools=[],
                allow_text_output=True,
                output_tools=[],
            ),
        )
        assert response.model_name == 'simple_model'

    total_time = time.monotonic() - start_time

    # With 2 requests per second and 5 requests total, with GCRA algorithm:
    # - First 2 requests go immediately (burst allowed)
    # - Wait ~0.5s, next 2 requests go
    # - Wait ~0.5s, last request goes
    # Total should be at least 1 second
    assert total_time >= 1.0, f'Expected at least 1 second for 5 requests at 2/sec, but took {total_time}s'

    # But it shouldn't take too much longer (allow some margin for processing)
    assert total_time < 2.5, f'Expected less than 2.5 seconds for 5 requests at 2/sec, but took {total_time}s'
557541
558542async def test_rate_limited_model_neither_limiter_nor_retryer ():
0 commit comments