diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 17a3451155e37..a441a0e275ac5 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -392,6 +392,7 @@ def get_first_reschedule_date(self, context: Context) -> datetime | None: def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int | None = None) -> None: + """Push a XCom through XCom.set, which pushes to XCom Backend if configured.""" # Private function, as we don't want to expose the ability to manually set `mapped_length` to SDK # consumers XCom.set( @@ -405,6 +406,18 @@ def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int ) +def _xcom_push_to_db(ti: RuntimeTaskInstance, key: str, value: Any) -> None: + """Push a XCom directly to metadata DB, bypassing custom xcom_backend.""" + XCom._set_xcom_in_db( + key=key, + value=value, + dag_id=ti.dag_id, + task_id=ti.task_id, + run_id=ti.run_id, + map_index=ti.map_index, + ) + + def parse(what: StartupDetails) -> RuntimeTaskInstance: # TODO: Task-SDK: # Using DagBag here is about 98% wrong, but it'll do for now @@ -944,7 +957,7 @@ def finalize( for oe in task.operator_extra_links: link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type] log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key) - _xcom_push(ti, key=xcom_key, value=link) + _xcom_push_to_db(ti, key=xcom_key, value=link) if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py b/task-sdk/src/airflow/sdk/execution_time/xcom.py index c7bf4a9be3793..a10d831e2ee01 100644 --- a/task-sdk/src/airflow/sdk/execution_time/xcom.py +++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py @@ -77,6 +77,42 @@ def set( ), ) + @classmethod + def _set_xcom_in_db( + cls, + key: str, + value: Any, + *, + dag_id: str, + task_id: str, + run_id: str, + map_index: int = -1, + ) -> None: + """ + Store an XCom value directly in the metadata database. + + :param key: Key to store the XCom. + :param value: XCom value to store. + :param dag_id: DAG ID. + :param task_id: Task ID. + :param run_id: DAG run ID for the task. + :param map_index: Optional map index to assign XCom for a mapped task. + The default is ``-1`` (set for a non-mapped task). + """ + from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + + SUPERVISOR_COMMS.send_request( + log=log, + msg=SetXCom( + key=key, + value=value, + dag_id=dag_id, + task_id=task_id, + run_id=run_id, + map_index=map_index, + ), + ) + @classmethod def get_value( cls, diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 8e6180dd11aa2..941804c9acad2 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -1384,21 +1384,21 @@ def execute(self, context): log=mock.ANY, ) - finalize(runtime_ti, context=context, log=mock.MagicMock(), state=TerminalTIState.SUCCESS) - - mock_supervisor_comms.send_request.assert_any_call( - msg=SetXCom( + with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set: + finalize( + runtime_ti, + log=mock.MagicMock(), + state=TerminalTIState.SUCCESS, + context=runtime_ti.get_template_context(), + ) + mock_xcom_set.assert_called_once_with( key="_link_AirflowLink", value="https://airflow.apache.org", - dag_id="test_dag", - run_id="test_run", - task_id="task_with_operator_extra_links", - map_index=-1, - mapped_length=None, - type="SetXCom", - ), - log=mock.ANY, - ) + dag_id=runtime_ti.dag_id, + task_id=runtime_ti.task_id, + run_id=runtime_ti.run_id, + map_index=runtime_ti.map_index, + ) def test_overwrite_rtif_after_execution_sets_rtif(self, create_runtime_ti, mock_supervisor_comms): """Test that the RTIF is overwritten after execution for certain operators."""