 from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure

 if TYPE_CHECKING:
-    from tenacity import AsyncRetrying
+    from pydantic_ai.retries import RetryConfig

 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -264,7 +264,8 @@ async def evaluate(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
-        retry: AsyncRetrying | None = None,
+        retry_task: RetryConfig | None = None,
+        retry_evaluators: RetryConfig | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.

@@ -279,7 +280,8 @@ async def evaluate(
             max_concurrency: The maximum number of concurrent evaluations of the task to allow.
                 If None, all cases will be evaluated concurrently.
             progress: Whether to show a progress bar for the evaluation. Defaults to `True`.
-            retry: Optional retry configuration for the task execution.
+            retry_task: Optional retry configuration for the task execution.
+            retry_evaluators: Optional retry configuration for evaluator execution.

         Returns:
             A report containing the results of the evaluation.
@@ -295,7 +297,9 @@ async def evaluate(

         async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
             async with limiter:
-                result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
+                result = await _run_task_and_evaluators(
+                    task, case, report_case_name, self.evaluators, retry_task, retry_evaluators
+                )
                 if progress_bar and task_id is not None:  # pragma: no branch
                     progress_bar.update(task_id, advance=1)
                 return result
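For reference, the call shape after this change lets task retries and evaluator retries be configured independently. A minimal usage sketch, assuming `task_retries` and `evaluator_retries` are `RetryConfig` instances built elsewhere (how `RetryConfig` is constructed is not shown in this diff, and the task name here is hypothetical):

report = await dataset.evaluate(
    classify_ticket,                     # hypothetical task under evaluation
    retry_task=task_retries,             # retries apply only to running the task
    retry_evaluators=evaluator_retries,  # retries apply only to running the evaluators
)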
@@ -828,14 +832,14 @@ def record_attribute(self, name: str, value: Any) -> None:
 async def _run_task(
     task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
     case: Case[InputsT, OutputT, MetadataT],
-    retry: AsyncRetrying | None = None,
+    retry: RetryConfig | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.

     Args:
         task: The task to run.
         case: The case to run the task on.
-        retry: The retry strategy to use.
+        retry: The retry config to use.

     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
@@ -868,11 +872,10 @@ async def _run_once():

     async def _run_with_retries():
         if retry:
-            async for attempt in retry:
-                with attempt:
-                    return await _run_once()
-        # Note: the following line will be unreachable if retry is not None
-        return await _run_once()
+            return await retry.decorator(_run_once)()
+        else:
+            # Note: the following line will be unreachable if retry is not None
+            return await _run_once()

     task_run, task_output, duration, span_tree = await _run_with_retries()
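For context on the new code path: rather than iterating `AsyncRetrying` attempts inline, the zero-argument coroutine is wrapped by a decorator obtained from the retry config and then awaited. A rough, self-contained illustration of that wrap-then-await shape using tenacity's own `retry` decorator (this only shows the general pattern; it is not the `RetryConfig.decorator` implementation, which this diff does not include):

import asyncio

from tenacity import retry, stop_after_attempt, wait_fixed

attempts = 0


async def _run_once() -> str:
    """Stand-in for the task body: fails twice, then succeeds."""
    global attempts
    attempts += 1
    if attempts < 3:
        raise RuntimeError('transient failure')
    return 'ok'


async def main() -> None:
    # retry(...) builds a decorator; applying it to the coroutine function and then
    # calling the wrapped result mirrors `await retry.decorator(_run_once)()` above.
    wrapped = retry(stop=stop_after_attempt(3), wait=wait_fixed(0))(_run_once)
    print(await wrapped())  # prints 'ok' on the third attempt


asyncio.run(main())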
@@ -913,7 +916,8 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-    retry: AsyncRetrying | None,
+    retry_task: RetryConfig | None,
+    retry_evaluators: RetryConfig | None,
 ) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.

@@ -922,7 +926,7 @@ async def _run_task_and_evaluators(
         case: The case to run the task on.
         report_case_name: The name to use for this case in the report.
         dataset_evaluators: Evaluators from the dataset to apply to this case.
-        retry: The retry strategy to use for running the task.
+        retry_task: The retry config to use for running the task.

     Returns:
         A ReportCase containing the evaluation results.
@@ -944,7 +948,7 @@ async def _run_task_and_evaluators(
             span_id = f'{context.span_id:016x}'

         t0 = time.time()
-        scoring_context = await _run_task(task, case, retry)
+        scoring_context = await _run_task(task, case, retry_task)

         case_span.set_attribute('output', scoring_context.output)
         case_span.set_attribute('task_duration', scoring_context.duration)
@@ -956,7 +960,7 @@ async def _run_task_and_evaluators(
         evaluator_failures: list[EvaluatorFailure] = []
         if evaluators:
             evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
+                [lambda ev=ev: run_evaluator(ev, scoring_context, retry_evaluators) for ev in evaluators]
             )
             for outputs in evaluator_outputs_by_task:
                 if isinstance(outputs, EvaluatorFailure):