diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 822ed9bb50ae8..461e770c42501 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -1395,9 +1395,17 @@ def finalize( task = ti.task # Pushing xcom for each operator extra links defined on the operator only. for oe in task.operator_extra_links: - link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] - log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) - _xcom_push_to_db(ti, key=xcom_key, value=link) + try: + link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] + log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) + _xcom_push_to_db(ti, key=xcom_key, value=link) + except Exception: + log.exception( + "Failed to push an xcom for task operator extra link", + link_name=oe.name, + xcom_key=oe.xcom_key, + ti=ti, + ) if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 58dda3371ef1d..4adf3c23723fb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING from unittest import mock -from unittest.mock import patch +from unittest.mock import call, patch import pandas as pd import pytest @@ -48,6 +48,7 @@ from airflow.sdk import ( DAG, BaseOperator, + BaseOperatorLink, Connection, dag as dag_decorator, get_current_context, @@ -1723,6 +1724,93 @@ def execute(self, context): map_index=runtime_ti.map_index, ) + def test_task_failed_with_operator_extra_links( + self, create_runtime_ti, mock_supervisor_comms, time_machine + ): + """Test that operator extra links are pushed to xcoms even when task fails.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + class DummyTestOperator(BaseOperator): + operator_extra_links = (AirflowLink(),) + + def execute(self, context): + raise ValueError("Task failed intentionally") + + task = DummyTestOperator(task_id="task_with_operator_extra_links") + runtime_ti = create_runtime_ti(task=task) + context = runtime_ti.get_template_context() + runtime_ti.start_date = instant + runtime_ti.end_date = instant + + state, _, error = run(runtime_ti, context=context, log=mock.MagicMock()) + assert state == TaskInstanceState.FAILED + assert error is not None + + with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: + finalize( + runtime_ti, + log=mock.MagicMock(), + state=TaskInstanceState.FAILED, + context=context, + error=error, + ) + assert mock_xcom_set.mock_calls == [ + call( + key="_link_AirflowLink", + value="https://airflow.apache.org", + dag_id=runtime_ti.dag_id, + task_id=runtime_ti.task_id, + run_id=runtime_ti.run_id, + map_index=runtime_ti.map_index, + ) + ] + + def test_operator_extra_links_exception_handling( + self, create_runtime_ti, mock_supervisor_comms, time_machine + ): + """Test that exceptions in get_link() don't prevent other links from being pushed.""" + instant = timezone.datetime(2024, 12, 3, 10, 0) + time_machine.move_to(instant, tick=False) + + class FailingLink(BaseOperatorLink): + """A link that raises an exception when get_link is called.""" + + name = "failing_link" + + def get_link(self, operator, *, ti_key): + raise ValueError("Link generation failed") + + class DummyTestOperator(BaseOperator): + operator_extra_links = (FailingLink(), AirflowLink()) + + def execute(self, context): + pass + + task = DummyTestOperator(task_id="task_with_multiple_links") + runtime_ti = create_runtime_ti(task=task) + context = runtime_ti.get_template_context() + runtime_ti.start_date = instant + runtime_ti.end_date = instant + + with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: + finalize( + runtime_ti, + log=mock.MagicMock(), + state=TaskInstanceState.SUCCESS, + context=context, + ) + assert mock_xcom_set.mock_calls == [ + call( + key="_link_AirflowLink", + value="https://airflow.apache.org", + dag_id=runtime_ti.dag_id, + task_id=runtime_ti.task_id, + run_id=runtime_ti.run_id, + map_index=runtime_ti.map_index, + ) + ] + @pytest.mark.parametrize( ["cmd", "rendered_cmd"], [