Skip to content

Commit 34ec681

Browse files
committed
Iteration on error handling ... more to do
1 parent 6b9feb6 commit 34ec681

File tree

4 files changed

+114
-29
lines changed

4 files changed

+114
-29
lines changed

pydantic_ai_slim/pydantic_ai/retries.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,85 @@
1313

1414
from __future__ import annotations
1515

16+
from dataclasses import dataclass
17+
1618
from httpx import AsyncBaseTransport, AsyncHTTPTransport, BaseTransport, HTTPTransport, Request, Response
19+
from pydantic_core import PydanticUndefinedType as Undefined
1720

1821
try:
19-
from tenacity import AsyncRetrying, Retrying
22+
from tenacity import AsyncRetrying, Retrying, WrappedFn
2023
except ImportError as _import_error:
2124
raise ImportError(
2225
'Please install `tenacity` to use the retries utilities, '
2326
'you can use the `retries` optional group — `pip install "pydantic-ai-slim[retries]"`'
2427
) from _import_error
2528

26-
27-
__all__ = ['TenacityTransport', 'AsyncTenacityTransport', 'wait_retry_after']
28-
29+
from collections.abc import Awaitable
2930
from datetime import datetime, timezone
3031
from email.utils import parsedate_to_datetime
31-
from typing import Callable, cast
32+
from typing import TYPE_CHECKING, Any, Callable, cast
3233

3334
from httpx import HTTPStatusError
34-
from tenacity import RetryCallState, wait_exponential
35+
from tenacity import RetryCallState, RetryError, retry, wait_exponential
36+
37+
if TYPE_CHECKING:
38+
from tenacity.asyncio.retry import RetryBaseT
39+
from tenacity.retry import RetryBaseT as SyncRetryBaseT
40+
from tenacity.stop import StopBaseT
41+
from tenacity.wait import WaitBaseT
42+
43+
__all__ = ['RetryConfig', 'TenacityTransport', 'AsyncTenacityTransport', 'wait_retry_after']
44+
45+
UNDEFINED = Undefined()
46+
47+
48+
@dataclass
class RetryConfig:
    """Retry settings matching the arguments of tenacity's `retry` function and `AsyncRetrying`/`Retrying` classes."""

    # In tenacity these options may not be None and their defaults are private, so None is used here
    # purely as a "not provided" marker; such options are left out of the generated kwargs.
    sleep: Callable[[int | float], None | Awaitable[None]] | None = None
    stop: StopBaseT | None = None
    wait: WaitBaseT | None = None
    retry: SyncRetryBaseT | RetryBaseT | None = None
    before: Callable[[RetryCallState], None | Awaitable[None]] | None = None
    after: Callable[[RetryCallState], None | Awaitable[None]] | None = None

    # These options have public types and default values in tenacity, which are repeated here as-is.
    before_sleep: Callable[[RetryCallState], None | Awaitable[None]] | None = None
    reraise: bool = False
    retry_error_cls: type[RetryError] = RetryError
    retry_error_callback: Callable[[RetryCallState], Any | Awaitable[Any]] | None = None

    def tenacity_kwargs(self) -> dict[str, Any]:
        """Assemble the keyword arguments to pass to tenacity.

        Returns:
            A dict that always contains `before_sleep`, `reraise`, `retry_error_cls` and
            `retry_error_callback`, plus each of `sleep`, `stop`, `wait`, `retry`, `before` and `after`
            whose value is not None.
        """
        forwarded: dict[str, Any] = {
            'before_sleep': self.before_sleep,
            'reraise': self.reraise,
            'retry_error_cls': self.retry_error_cls,
            'retry_error_callback': self.retry_error_callback,
        }
        # Only forward the None-as-marker options that were actually set, in a fixed order.
        for option in ('sleep', 'stop', 'wait', 'retry', 'before', 'after'):
            value = getattr(self, option)
            if value is not None:
                forwarded[option] = value
        return forwarded

    def tenacity_decorator(self, function: WrappedFn) -> WrappedFn:
        """Apply tenacity's `retry` decorator, populated from this config, to the provided function.

        Returns:
            A wrapped version of the function that uses this configuration for tenacity-based retrying when called.
        """
        decorate = retry(**self.tenacity_kwargs())
        return decorate(function)
3595

3696

3797
class TenacityTransport(BaseTransport):
@@ -76,7 +136,7 @@ class TenacityTransport(BaseTransport):
76136

77137
def __init__(
78138
self,
79-
controller: Retrying,
139+
controller: RetryConfig | Retrying,
80140
wrapped: BaseTransport | None = None,
81141
validate_response: Callable[[Response], None] | None = None,
82142
):
@@ -97,7 +157,10 @@ def handle_request(self, request: Request) -> Response:
97157
RuntimeError: If the retry controller did not make any attempts.
98158
Exception: Any exception raised by the wrapped transport or validation function.
99159
"""
100-
for attempt in self.controller:
160+
controller = (
161+
self.controller if isinstance(self.controller, Retrying) else Retrying(**self.controller.tenacity_kwargs())
162+
)
163+
for attempt in controller:
101164
with attempt:
102165
response = self.wrapped.handle_request(request)
103166
if self.validate_response:
@@ -147,7 +210,7 @@ class AsyncTenacityTransport(AsyncBaseTransport):
147210

148211
def __init__(
149212
self,
150-
controller: AsyncRetrying,
213+
controller: RetryConfig | AsyncRetrying,
151214
wrapped: AsyncBaseTransport | None = None,
152215
validate_response: Callable[[Response], None] | None = None,
153216
):
@@ -168,7 +231,12 @@ async def handle_async_request(self, request: Request) -> Response:
168231
RuntimeError: If the retry controller did not make any attempts.
169232
Exception: Any exception raised by the wrapped transport or validation function.
170233
"""
171-
async for attempt in self.controller:
234+
controller = (
235+
self.controller
236+
if isinstance(self.controller, AsyncRetrying)
237+
else AsyncRetrying(**self.controller.tenacity_kwargs())
238+
)
239+
async for attempt in controller:
172240
with attempt:
173241
response = await self.wrapped.handle_async_request(request)
174242
if self.validate_response:

pydantic_evals/pydantic_evals/dataset.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
4949

5050
if TYPE_CHECKING:
51-
from tenacity import AsyncRetrying
51+
from pydantic_ai.retries import RetryConfig
5252

5353
if sys.version_info < (3, 11):
5454
from exceptiongroup import ExceptionGroup # pragma: lax no cover
@@ -264,7 +264,8 @@ async def evaluate(
264264
name: str | None = None,
265265
max_concurrency: int | None = None,
266266
progress: bool = True,
267-
retry: AsyncRetrying | None = None,
267+
retry_task: RetryConfig | None = None,
268+
retry_evaluators: RetryConfig | None = None,
268269
) -> EvaluationReport[InputsT, OutputT, MetadataT]:
269270
"""Evaluates the test cases in the dataset using the given task.
270271
@@ -279,7 +280,8 @@ async def evaluate(
279280
max_concurrency: The maximum number of concurrent evaluations of the task to allow.
280281
If None, all cases will be evaluated concurrently.
281282
progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
282-
retry: Optional retry configuration for the task execution.
283+
retry_task: Optional retry configuration for the task execution.
284+
retry_evaluators: Optional retry configuration for evaluator execution.
283285
284286
Returns:
285287
A report containing the results of the evaluation.
@@ -295,7 +297,9 @@ async def evaluate(
295297

296298
async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
297299
async with limiter:
298-
result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
300+
result = await _run_task_and_evaluators(
301+
task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
302+
)
299303
if progress_bar and task_id is not None: # pragma: no branch
300304
progress_bar.update(task_id, advance=1)
301305
return result
@@ -828,14 +832,14 @@ def record_attribute(self, name: str, value: Any) -> None:
828832
async def _run_task(
829833
task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
830834
case: Case[InputsT, OutputT, MetadataT],
831-
retry: AsyncRetrying | None = None,
835+
retry: RetryConfig | None = None,
832836
) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
833837
"""Run a task on a case and return the context for evaluators.
834838
835839
Args:
836840
task: The task to run.
837841
case: The case to run the task on.
838-
retry: The retry strategy to use.
842+
retry: The retry config to use.
839843
840844
Returns:
841845
An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
@@ -868,11 +872,10 @@ async def _run_once():
868872

869873
async def _run_with_retries():
    """Run `_run_once`, retrying according to `retry` when a retry config was provided."""
    if retry is None:
        return await _run_once()
    # `RetryConfig` exposes `tenacity_decorator` (it has no `decorator` attribute); this is the same
    # method `run_evaluator` uses to wrap evaluator calls.
    return await retry.tenacity_decorator(_run_once)()
876879

877880
task_run, task_output, duration, span_tree = await _run_with_retries()
878881

@@ -913,7 +916,8 @@ async def _run_task_and_evaluators(
913916
case: Case[InputsT, OutputT, MetadataT],
914917
report_case_name: str,
915918
dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
916-
retry: AsyncRetrying | None,
919+
retry_task: RetryConfig | None,
920+
retry_evaluators: RetryConfig | None,
917921
) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
918922
"""Run a task on a case and evaluate the results.
919923
@@ -922,7 +926,7 @@ async def _run_task_and_evaluators(
922926
case: The case to run the task on.
923927
report_case_name: The name to use for this case in the report.
924928
dataset_evaluators: Evaluators from the dataset to apply to this case.
925-
retry: The retry strategy to use for running the task.
929+
retry_task: The retry config to use for running the task.
retry_evaluators: The retry config to use for running the evaluators.
926930
927931
Returns:
928932
A ReportCase containing the evaluation results.
@@ -944,7 +948,7 @@ async def _run_task_and_evaluators(
944948
span_id = f'{context.span_id:016x}'
945949

946950
t0 = time.time()
947-
scoring_context = await _run_task(task, case, retry)
951+
scoring_context = await _run_task(task, case, retry_task)
948952

949953
case_span.set_attribute('output', scoring_context.output)
950954
case_span.set_attribute('task_duration', scoring_context.duration)
@@ -956,7 +960,7 @@ async def _run_task_and_evaluators(
956960
evaluator_failures: list[EvaluatorFailure] = []
957961
if evaluators:
958962
evaluator_outputs_by_task = await task_group_gather(
959-
[lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
963+
[lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
960964
)
961965
for outputs in evaluator_outputs_by_task:
962966
if isinstance(outputs, EvaluatorFailure):

pydantic_evals/pydantic_evals/evaluators/_run_evaluator.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import traceback
44
from collections.abc import Mapping
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
import logfire_api
88
from pydantic import (
@@ -21,6 +21,12 @@
2121
EvaluatorOutput,
2222
)
2323

24+
if TYPE_CHECKING:
25+
# TODO: pydantic_evals should not import from pydantic_ai.
26+
# We need a clean way to pass retry behavior into the evaluators.
27+
# The difficulty is that the retry utilities will probably be needed in both pydantic_ai and pydantic_evals.
28+
from pydantic_ai.retries import RetryConfig
29+
2430
# while waiting for https://github.com/pydantic/logfire/issues/745
2531
try:
2632
import logfire._internal.stack_info
@@ -40,7 +46,9 @@
4046

4147

4248
async def run_evaluator(
43-
evaluator: Evaluator[InputsT, OutputT, MetadataT], ctx: EvaluatorContext[InputsT, OutputT, MetadataT]
49+
evaluator: Evaluator[InputsT, OutputT, MetadataT],
50+
ctx: EvaluatorContext[InputsT, OutputT, MetadataT],
51+
retry: RetryConfig | None = None,
4452
) -> list[EvaluationResult] | EvaluatorFailure:
4553
"""Run an evaluator and return the results.
4654
@@ -50,19 +58,24 @@ async def run_evaluator(
5058
Args:
5159
evaluator: The evaluator to run.
5260
ctx: The context containing the inputs, outputs, and metadata for evaluation.
61+
retry: The retry configuration to use for running the evaluator.
5362
5463
Returns:
5564
A list of evaluation results, or an evaluator failure if an exception is raised during its execution.
5665
5766
Raises:
5867
ValueError: If the evaluator returns a value of an invalid type.
5968
"""
69+
evaluate = evaluator.evaluate_async
70+
if retry is not None:
71+
evaluate = retry.tenacity_decorator(evaluate)
72+
6073
try:
6174
with _logfire.span(
6275
'evaluator: {evaluator_name}',
6376
evaluator_name=evaluator.get_default_evaluation_name(),
6477
):
65-
raw_results = await evaluator.evaluate_async(ctx)
78+
raw_results = await evaluate(ctx)
6679

6780
try:
6881
results = _EVALUATOR_OUTPUT_ADAPTER.validate_python(raw_results)

tests/evals/test_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ async def mock_async_task(inputs: TaskInput) -> TaskOutput:
316316
return TaskOutput(answer='Paris')
317317
return TaskOutput(answer='Unknown') # pragma: no cover
318318

319-
report = await example_dataset.evaluate(mock_async_task, retry=AsyncRetrying(stop=stop_after_attempt(3)))
319+
report = await example_dataset.evaluate(mock_async_task, retry_task=AsyncRetrying(stop=stop_after_attempt(3)))
320320

321321
assert attempt == 3
322322

0 commit comments

Comments
 (0)