@@ -20,7 +20,7 @@
 from dataclasses import dataclass, field
 from inspect import iscoroutinefunction
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, Union, cast
+from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast
 
 import anyio
 import logfire_api
@@ -41,9 +41,13 @@
 from .evaluators._spec import EvaluatorSpec
 from .evaluators.common import DEFAULT_EVALUATORS
 from .evaluators.context import EvaluatorContext
+from .evaluators.evaluator import EvaluatorFailure
 from .otel import SpanTree
 from .otel._context_subtree import context_subtree
-from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate
+from .reporting import EvaluationReport, ReportCase, ReportCaseAggregate, ReportCaseFailure
+
+if TYPE_CHECKING:
+    from tenacity import AsyncRetrying
 
 if sys.version_info < (3, 11):
     from exceptiongroup import ExceptionGroup  # pragma: lax no cover
@@ -84,6 +88,7 @@
 
 
 _REPORT_CASES_ADAPTER = TypeAdapter(list[ReportCase])
+_REPORT_CASE_FAILURES_ADAPTER = TypeAdapter(list[ReportCaseFailure])
 _REPORT_CASE_AGGREGATE_ADAPTER = TypeAdapter(ReportCaseAggregate)
 
 
@@ -171,11 +176,6 @@ def __init__(
         self.evaluators = list(evaluators)
 
 
-# TODO: Consider making one or more of the following changes to this type:
-# * Add `task: Callable[[InputsT], Awaitable[OutputT]` as a field
-# * Add `inputs_type`, `output_type`, etc. as kwargs on `__init__`
-# * Rename to `Evaluation`
-# TODO: Allow `task` to be sync _or_ async
 class Dataset(BaseModel, Generic[InputsT, OutputT, MetadataT], extra='forbid', arbitrary_types_allowed=True):
     """A dataset of test [cases][pydantic_evals.Case].
 
@@ -263,6 +263,7 @@ async def evaluate(
         name: str | None = None,
         max_concurrency: int | None = None,
         progress: bool = True,
+        retry: AsyncRetrying | None = None,
     ) -> EvaluationReport[InputsT, OutputT, MetadataT]:
         """Evaluates the test cases in the dataset using the given task.
 
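For context, a minimal sketch of how the new `retry` parameter might be used by calling code. The tenacity helpers shown are the standard ones the `AsyncRetrying` annotation refers to; `my_dataset` and `my_task` are hypothetical placeholders, not names from this diff:

from tenacity import AsyncRetrying, retry_if_exception_type, stop_after_attempt, wait_exponential

async def run_eval(my_dataset, my_task):
    # Retry each case's task up to 3 attempts on TimeoutError, with exponential backoff between tries.
    retrying = AsyncRetrying(
        retry=retry_if_exception_type(TimeoutError),
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=0.5),
    )
    return await my_dataset.evaluate(my_task, retry=retrying)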
@@ -292,24 +293,30 @@ async def evaluate(
 
                 async def _handle_case(case: Case[InputsT, OutputT, MetadataT], report_case_name: str):
                     async with limiter:
-                        result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators)
+                        result = await _run_task_and_evaluators(task, case, report_case_name, self.evaluators, retry)
                         if progress_bar and task_id is not None:  # pragma: no branch
                             progress_bar.update(task_id, advance=1)
                         return result
 
+                cases_and_failures = await task_group_gather(
+                    [
+                        lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
+                        for i, case in enumerate(self.cases, 1)
+                    ]
+                )
                 report = EvaluationReport(
                     name=name,
-                    cases=await task_group_gather(
-                        [
-                            lambda case=case, i=i: _handle_case(case, case.name or f'Case {i}')
-                            for i, case in enumerate(self.cases, 1)
-                        ]
-                    ),
+                    cases=[x for x in cases_and_failures if isinstance(x, ReportCase)],
+                    failures=[x for x in cases_and_failures if isinstance(x, ReportCaseFailure)],
                 )
             # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
             eval_span.set_attribute('cases', _REPORT_CASES_ADAPTER.dump_python(report.cases))
+            # TODO(DavidM): This attribute will be too big in general; remove it once we can use child spans in details panel:
+            eval_span.set_attribute('failures', _REPORT_CASE_FAILURES_ADAPTER.dump_python(report.failures))
             # TODO(DavidM): Remove this 'averages' attribute once we compute it in the details panel
-            eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(report.averages()))
+            averages = report.averages()
+            if averages:
+                eval_span.set_attribute('averages', _REPORT_CASE_AGGREGATE_ADAPTER.dump_python(averages))
         return report
 
     def evaluate_sync(
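With this change, per-case failures are collected on the report instead of aborting the whole run. A minimal sketch of consuming the split results; the `name` and `error_msg` fields are the ones populated on `ReportCaseFailure` later in this diff, while `report` is assumed to be the value returned by `Dataset.evaluate`:

def summarize(report):
    # Successful cases and failed cases now arrive on separate lists.
    print(f'{len(report.cases)} cases succeeded, {len(report.failures)} failed')
    for failure in report.failures:
        # Each failure carries the case name plus a formatted 'ExceptionType: message' string.
        print(f'{failure.name}: {failure.error_msg}')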
@@ -817,38 +824,55 @@ def record_attribute(self, name: str, value: Any) -> None:
 
 
 async def _run_task(
-    task: Callable[[InputsT], Awaitable[OutputT] | OutputT], case: Case[InputsT, OutputT, MetadataT]
+    task: Callable[[InputsT], Awaitable[OutputT] | OutputT],
+    case: Case[InputsT, OutputT, MetadataT],
+    retry: AsyncRetrying | None = None,
 ) -> EvaluatorContext[InputsT, OutputT, MetadataT]:
     """Run a task on a case and return the context for evaluators.
 
     Args:
         task: The task to run.
         case: The case to run the task on.
+        retry: The retry strategy to use.
 
     Returns:
         An EvaluatorContext containing the inputs, actual output, expected output, and metadata.
 
     Raises:
         Exception: Any exception raised by the task.
     """
-    task_run = _TaskRun()
-    if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
-        raise RuntimeError('A task run has already been entered. Task runs should not be nested')
 
-    # Note: the current behavior is for task execution errors to just bubble up all the way and kill the evaluation.
-    # Should we handle them for the user in some way? If so, I guess we'd want to do that here.
-    token = _CURRENT_TASK_RUN.set(task_run)
-    try:
-        with _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span:
-            with context_subtree() as span_tree:
+    async def _run_once():
+        task_run_ = _TaskRun()
+        if _CURRENT_TASK_RUN.get() is not None:  # pragma: no cover
+            raise RuntimeError('A task run has already been entered. Task runs should not be nested')
+
+        token = _CURRENT_TASK_RUN.set(task_run_)
+        try:
+            with (
+                _logfire.span('execute {task}', task=get_unwrapped_function_name(task)) as task_span,
+                context_subtree() as span_tree_,
+            ):
                 t0 = time.perf_counter()
                 if iscoroutinefunction(task):
-                    task_output = cast(OutputT, await task(case.inputs))
+                    task_output_ = cast(OutputT, await task(case.inputs))
                 else:
-                    task_output = cast(OutputT, await to_thread.run_sync(task, case.inputs))
+                    task_output_ = cast(OutputT, await to_thread.run_sync(task, case.inputs))
                 fallback_duration = time.perf_counter() - t0
-    finally:
-        _CURRENT_TASK_RUN.reset(token)
+                duration_ = _get_span_duration(task_span, fallback_duration)
+                return task_run_, task_output_, duration_, span_tree_
+        finally:
+            _CURRENT_TASK_RUN.reset(token)
+
+    async def _run_with_retries():
+        if retry:
+            async for attempt in retry:
+                with attempt:
+                    return await _run_once()
+        # Note: the following line will be unreachable if retry is not None
+        return await _run_once()
+
+    task_run, task_output, duration, span_tree = await _run_with_retries()
 
     if isinstance(span_tree, SpanTree):  # pragma: no branch
         # TODO: Question: Should we make this metric-attributes functionality more user-configurable in some way before merging?
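For reference, a standalone sketch of the tenacity iteration protocol that `_run_with_retries` leans on: `AsyncRetrying` yields one attempt context manager per try, exceptions raised inside `with attempt:` are recorded and trigger the next try, and tenacity raises `RetryError` once the stop condition is exhausted. The `flaky_call` argument is a stand-in, not a name from this diff:

from tenacity import AsyncRetrying, stop_after_attempt

async def call_with_retries(flaky_call):
    # Up to three tries; a successful `return` inside the attempt ends the loop early.
    async for attempt in AsyncRetrying(stop=stop_after_attempt(3)):
        with attempt:
            return await flaky_call()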
@@ -863,6 +887,7 @@ async def _run_task(
                 if not isinstance(v, (int, float)):
                     continue
                 # TODO: Revisit this choice to strip the prefix..
+                # TODO: Use the span-tracking-of-metrics functionality to simplify this implementation
                 if k.startswith('gen_ai.usage.details.'):
                     task_run.increment_metric(k.removeprefix('gen_ai.usage.details.'), v)
                 elif k.startswith('gen_ai.usage.'):
@@ -874,7 +899,7 @@ async def _run_task(
         metadata=case.metadata,
         expected_output=case.expected_output,
         output=task_output,
-        duration=_get_span_duration(task_span, fallback_duration),
+        duration=duration,
         _span_tree=span_tree,
         attributes=task_run.attributes,
         metrics=task_run.metrics,
@@ -886,7 +911,8 @@ async def _run_task_and_evaluators(
     case: Case[InputsT, OutputT, MetadataT],
     report_case_name: str,
     dataset_evaluators: list[Evaluator[InputsT, OutputT, MetadataT]],
-) -> ReportCase[InputsT, OutputT, MetadataT]:
+    retry: AsyncRetrying | None,
+) -> ReportCase[InputsT, OutputT, MetadataT] | ReportCaseFailure[InputsT, OutputT, MetadataT]:
     """Run a task on a case and evaluate the results.
 
     Args:
@@ -898,60 +924,75 @@ async def _run_task_and_evaluators(
     Returns:
         A ReportCase containing the evaluation results.
     """
-    with _logfire.span(
-        'case: {case_name}',
-        task_name=get_unwrapped_function_name(task),
-        case_name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-    ) as case_span:
-        t0 = time.time()
-        scoring_context = await _run_task(task, case)
-
-        case_span.set_attribute('output', scoring_context.output)
-        case_span.set_attribute('task_duration', scoring_context.duration)
-        case_span.set_attribute('metrics', scoring_context.metrics)
-        case_span.set_attribute('attributes', scoring_context.attributes)
-
-        evaluators = case.evaluators + dataset_evaluators
-        evaluator_outputs: list[EvaluationResult] = []
-        if evaluators:
-            evaluator_outputs_by_task = await task_group_gather(
-                [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
-            )
-            evaluator_outputs += [out for outputs in evaluator_outputs_by_task for out in outputs]
-
-        assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
-        case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
-        case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
-        case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
-
-        context = case_span.context
-        if context is None:  # pragma: no cover
-            trace_id = ''
-            span_id = ''
-        else:
-            trace_id = f'{context.trace_id:032x}'
-            span_id = f'{context.span_id:016x}'
-        fallback_duration = time.time() - t0
-
-    return ReportCase[InputsT, OutputT, MetadataT](
-        name=report_case_name,
-        inputs=case.inputs,
-        metadata=case.metadata,
-        expected_output=case.expected_output,
-        output=scoring_context.output,
-        metrics=scoring_context.metrics,
-        attributes=scoring_context.attributes,
-        scores=scores,
-        labels=labels,
-        assertions=assertions,
-        task_duration=scoring_context.duration,
-        total_duration=_get_span_duration(case_span, fallback_duration),
-        trace_id=trace_id,
-        span_id=span_id,
-    )
+    trace_id = ''
+    span_id = ''
+    try:
+        with _logfire.span(
+            'case: {case_name}',
+            task_name=get_unwrapped_function_name(task),
+            case_name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+        ) as case_span:
+            context = case_span.context
+            if context is not None:  # pragma: no cover
+                trace_id = f'{context.trace_id:032x}'
+                span_id = f'{context.span_id:016x}'
+
+            t0 = time.time()
+            scoring_context = await _run_task(task, case, retry)
+
+            case_span.set_attribute('output', scoring_context.output)
+            case_span.set_attribute('task_duration', scoring_context.duration)
+            case_span.set_attribute('metrics', scoring_context.metrics)
+            case_span.set_attribute('attributes', scoring_context.attributes)
+
+            evaluators = case.evaluators + dataset_evaluators
+            evaluator_outputs: list[EvaluationResult] = []
+            evaluator_failures: list[EvaluatorFailure] = []
+            if evaluators:
+                evaluator_outputs_by_task = await task_group_gather(
+                    [lambda ev=ev: run_evaluator(ev, scoring_context) for ev in evaluators]
+                )
+                flattened = [out for outputs in evaluator_outputs_by_task for out in outputs]
+                evaluator_outputs += [o for o in flattened if not isinstance(o, EvaluatorFailure)]
+                evaluator_failures += [o for o in flattened if isinstance(o, EvaluatorFailure)]
+
+            assertions, scores, labels = _group_evaluator_outputs_by_type(evaluator_outputs)
+            case_span.set_attribute('assertions', _evaluation_results_adapter.dump_python(assertions))
+            case_span.set_attribute('scores', _evaluation_results_adapter.dump_python(scores))
+            case_span.set_attribute('labels', _evaluation_results_adapter.dump_python(labels))
+
+            fallback_duration = time.time() - t0
+
+        return ReportCase[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            output=scoring_context.output,
+            metrics=scoring_context.metrics,
+            attributes=scoring_context.attributes,
+            scores=scores,
+            labels=labels,
+            assertions=assertions,
+            task_duration=scoring_context.duration,
+            total_duration=_get_span_duration(case_span, fallback_duration),
+            trace_id=trace_id,
+            span_id=span_id,
+            evaluator_failures=evaluator_failures,
+        )
+    except Exception as exc:
+        return ReportCaseFailure[InputsT, OutputT, MetadataT](
+            name=report_case_name,
+            inputs=case.inputs,
+            metadata=case.metadata,
+            expected_output=case.expected_output,
+            error_msg=f'{type(exc).__name__}: {exc}',
+            trace_id=trace_id,
+            span_id=span_id,
+        )
 
 
 _evaluation_results_adapter = TypeAdapter(Mapping[str, EvaluationResult])
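One more hedged sketch, this time of the evaluator-failure split shown above: evaluator results that come back as `EvaluatorFailure` are separated from scores, labels, and assertions and stored on the new `evaluator_failures` field of `ReportCase`. Here `report` is assumed to be an `EvaluationReport`:

def cases_with_broken_evaluators(report):
    # Cases whose task succeeded but where at least one evaluator raised during scoring.
    return [case.name for case in report.cases if case.evaluator_failures]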