diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index d4d09c8e20c48..10043b5baad4f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -957,9 +957,10 @@ def _on_term(signum, frame): # If AirflowFailException is raised, task should not retry. # If a sensor in reschedule mode reaches timeout, task should not retry. log.exception("Task failed with exception") + ti.end_date = datetime.now(tz=timezone.utc) msg = TaskState( state=TaskInstanceState.FAILED, - end_date=datetime.now(tz=timezone.utc), + end_date=ti.end_date, rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.FAILED @@ -974,9 +975,10 @@ def _on_term(signum, frame): # updated already be another UI API. So, these exceptions should ideally never be thrown. # If these are thrown, we should mark the TI state as failed. log.exception("Task failed with exception") + ti.end_date = datetime.now(tz=timezone.utc) msg = TaskState( state=TaskInstanceState.FAILED, - end_date=datetime.now(tz=timezone.utc), + end_date=ti.end_date, rendered_map_index=ti.rendered_map_index, ) state = TaskInstanceState.FAILED @@ -1003,10 +1005,12 @@ def _handle_current_task_success( context: Context, ti: RuntimeTaskInstance, ) -> tuple[SucceedTask, TaskInstanceState]: + end_date = datetime.now(tz=timezone.utc) + ti.end_date = end_date task_outlets = list(_build_asset_profiles(ti.task.outlets)) outlet_events = list(_serialize_outlet_events(context["outlet_events"])) msg = SucceedTask( - end_date=datetime.now(tz=timezone.utc), + end_date=end_date, task_outlets=task_outlets, outlet_events=outlet_events, rendered_map_index=ti.rendered_map_index, @@ -1018,11 +1022,15 @@ def _handle_current_task_failed( ti: RuntimeTaskInstance, ) -> tuple[RetryTask, TaskInstanceState] | tuple[TaskState, TaskInstanceState]: end_date = datetime.now(tz=timezone.utc) + ti.end_date = end_date if ti._ti_context_from_server and ti._ti_context_from_server.should_retry: return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY - return TaskState( - state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index - ), TaskInstanceState.FAILED + return ( + TaskState( + state=TaskInstanceState.FAILED, end_date=end_date, rendered_map_index=ti.rendered_map_index + ), + TaskInstanceState.FAILED, + ) def _handle_trigger_dag_run( diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 2556845dac897..fc64d02fbc72a 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2566,6 +2566,133 @@ class CustomOperator(BaseOperator): assert state == expected_state assert collected_results == expected_results + def test_task_runner_on_failure_callback_context(self, create_runtime_ti): + """Test that on_failure_callback context has end_date and duration.""" + from airflow.exceptions import AirflowException + + def failure_callback(context): + ti = context["task_instance"] + assert isinstance(ti.end_date, datetime) + duration = (ti.end_date - ti.start_date).total_seconds() + assert duration is not None + assert duration >= 0 + + class FailingOperator(BaseOperator): + def execute(self, context): + raise AirflowException("Failing task") + + task = FailingOperator(task_id="failing_task", on_failure_callback=failure_callback) + runtime_ti = create_runtime_ti(dag_id="dag", task=task) + log = mock.MagicMock() + context = runtime_ti.get_template_context() + state, _, error = run(runtime_ti, context, log) + finalize(runtime_ti, state, context, log, error) + + assert state == TaskInstanceState.FAILED + + def test_task_runner_on_success_callback_context(self, create_runtime_ti): + """Test that on_success_callback context has end_date and duration.""" + callback_data = {} # Store callback data for inspection + + def success_callback(context): + ti = context["task_instance"] + callback_data["end_date"] = ti.end_date + callback_data["duration"] = (ti.end_date - ti.start_date).total_seconds() if ti.end_date else None + callback_data["start_date"] = ti.start_date + + class SuccessOperator(BaseOperator): + def execute(self, context): + return "success" + + task = SuccessOperator(task_id="success_task", on_success_callback=success_callback) + runtime_ti = create_runtime_ti(dag_id="dag", task=task) + log = mock.MagicMock() + context = runtime_ti.get_template_context() + + state, _, error = run(runtime_ti, context, log) + finalize(runtime_ti, state, context, log, error) + + assert state == TaskInstanceState.SUCCESS + + # Verify callback was called and data was captured + assert "end_date" in callback_data, "Success callback should have been called" + assert isinstance(callback_data["end_date"], datetime), ( + f"end_date should be datetime, got {type(callback_data['end_date'])}" + ) + assert callback_data["duration"] is not None, ( + f"duration should not be None, got {callback_data['duration']}" + ) + assert callback_data["duration"] >= 0, f"duration should be >= 0, got {callback_data['duration']}" + + def test_task_runner_both_callbacks_have_timing_info(self, create_runtime_ti): + """Test that both success and failure callbacks receive accurate timing information.""" + import time + + from airflow.exceptions import AirflowException + + success_data = {} + failure_data = {} + + def success_callback(context): + ti = context["task_instance"] + success_data["end_date"] = ti.end_date + success_data["start_date"] = ti.start_date + success_data["duration"] = (ti.end_date - ti.start_date).total_seconds() if ti.end_date else None + + def failure_callback(context): + ti = context["task_instance"] + failure_data["end_date"] = ti.end_date + failure_data["start_date"] = ti.start_date + failure_data["duration"] = (ti.end_date - ti.start_date).total_seconds() if ti.end_date else None + + # Test success callback + class SuccessOperator(BaseOperator): + def execute(self, context): + time.sleep(0.01) # Add small delay to ensure measurable duration + return "success" + + success_task = SuccessOperator(task_id="success_task", on_success_callback=success_callback) + success_runtime_ti = create_runtime_ti(dag_id="dag", task=success_task) + success_log = mock.MagicMock() + success_context = success_runtime_ti.get_template_context() + + success_state, _, success_error = run(success_runtime_ti, success_context, success_log) + finalize(success_runtime_ti, success_state, success_context, success_log, success_error) + + # Test failure callback + class FailureOperator(BaseOperator): + def execute(self, context): + time.sleep(0.01) # Add small delay to ensure measurable duration + raise AirflowException("Test failure") + + failure_task = FailureOperator(task_id="failure_task", on_failure_callback=failure_callback) + failure_runtime_ti = create_runtime_ti(dag_id="dag", task=failure_task) + failure_log = mock.MagicMock() + failure_context = failure_runtime_ti.get_template_context() + + failure_state, _, failure_error = run(failure_runtime_ti, failure_context, failure_log) + finalize(failure_runtime_ti, failure_state, failure_context, failure_log, failure_error) + + # Assertions for success callback + assert success_state == TaskInstanceState.SUCCESS + assert "end_date" in success_data, "Success callback should have been called" + assert isinstance(success_data["end_date"], datetime) + assert isinstance(success_data["start_date"], datetime) + assert success_data["duration"] is not None + assert success_data["duration"] >= 0.01, ( + f"Success duration should be >= 0.01, got {success_data['duration']}" + ) + + # Assertions for failure callback + assert failure_state == TaskInstanceState.FAILED + assert "end_date" in failure_data, "Failure callback should have been called" + assert isinstance(failure_data["end_date"], datetime) + assert isinstance(failure_data["start_date"], datetime) + assert failure_data["duration"] is not None + assert failure_data["duration"] >= 0.01, ( + f"Failure duration should be >= 0.01, got {failure_data['duration']}" + ) + @pytest.mark.parametrize( "callback_to_test, execute_impl, should_retry, expected_state, expected_results, extra_exceptions", [