diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 33583ffe9c7b8..a515c1a7212ee 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1292,7 +1292,8 @@ def main(): bundle_name=ti.bundle_instance.name, bundle_version=ti.bundle_instance.version, ): - state, msg, error = run(ti, context, log) + state, _, error = run(ti, context, log) + context["exception"] = error finalize(ti, state, context, log, error) except KeyboardInterrupt: log = structlog.get_logger(logger_name="task") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 09d4d9709e789..4b47fb57aa558 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2277,6 +2277,9 @@ def execute(self, context): @pytest.mark.usefixtures("mock_supervisor_comms") class TestTaskRunnerCallsCallbacks: + class _Failure(Exception): + """Exception raised in a failed execution and received by the failure callback.""" + def _execute_success(self, context): self.results.append("execute success") @@ -2288,7 +2291,7 @@ def _execute_skipped(self, context): def _execute_failure(self, context): self.results.append("execute failure") - raise Exception("sorry!") + raise self._Failure("sorry!") @pytest.mark.parametrize( "execute_impl, should_retry, expected_state, expected_results", @@ -2336,6 +2339,10 @@ def test_task_runner_calls_callback( def custom_callback(context, *, kind): collected_results.append(f"on-{kind} callback") + def failure_callback(context): + custom_callback(context, kind="failure") + assert isinstance(context["exception"], self._Failure) + class CustomOperator(BaseOperator): results = collected_results execute = execute_impl @@ -2345,7 +2352,7 @@ class CustomOperator(BaseOperator): on_execute_callback=functools.partial(custom_callback, kind="execute"), on_skipped_callback=functools.partial(custom_callback, kind="skipped"), on_success_callback=functools.partial(custom_callback, kind="success"), - on_failure_callback=functools.partial(custom_callback, kind="failure"), + on_failure_callback=failure_callback, on_retry_callback=functools.partial(custom_callback, kind="retry"), ) runtime_ti = create_runtime_ti(dag_id="dag", task=task, should_retry=should_retry)