From 654bf758a9e44deb4f8d09f2efd25c7d56432d03 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 28 Dec 2025 23:25:22 +0100 Subject: [PATCH 01/10] Remove global from task_runner supervisor-comms --- .../src/airflow/dag_processing/processor.py | 2 +- .../src/airflow/jobs/triggerer_job_runner.py | 4 +- airflow-core/src/airflow/models/connection.py | 9 +- airflow-core/src/airflow/models/variable.py | 17 ++-- .../tests/unit/jobs/test_triggerer_job.py | 2 +- .../tests/unit/models/test_connection.py | 4 +- .../src/tests_common/pytest_plugin.py | 4 +- task-sdk/src/airflow/sdk/bases/xcom.py | 24 +++--- .../sdk/definitions/asset/decorators.py | 4 +- .../src/airflow/sdk/execution_time/comms.py | 4 +- .../src/airflow/sdk/execution_time/context.py | 18 ++-- .../src/airflow/sdk/execution_time/hitl.py | 12 +-- .../sdk/execution_time/lazy_sequence.py | 10 +-- .../execution_time/secrets/execution_api.py | 36 ++++---- .../airflow/sdk/execution_time/supervisor.py | 34 ++++---- .../airflow/sdk/execution_time/task_runner.py | 83 ++++++++++++------- task-sdk/src/airflow/sdk/log.py | 5 +- .../task_sdk/execution_time/test_context.py | 2 +- .../task_sdk/execution_time/test_secrets.py | 16 ++-- .../execution_time/test_supervisor.py | 28 +++---- 20 files changed, 173 insertions(+), 145 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 6492e4bab777b..476bd02f69196 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -194,7 +194,7 @@ def _parse_file_entrypoint(): if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - task_runner.SUPERVISOR_COMMS = comms_decoder + task_runner.set_supervisor_comms(comms_decoder) log = structlog.get_logger(logger_name="task") result = _parse_file(msg, log) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 2508e4a281ab7..3b87e168ad9bc 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -930,7 +930,7 @@ async def init_comms(self): """ Set up the communications pipe between this process and the supervisor. - This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that will + This also sets up the supervisor-comms so that TaskSDK code can work as expected too (but that will need to be wrapped in an ``sync_to_async()`` call) """ from airflow.sdk.execution_time import task_runner @@ -943,7 +943,7 @@ async def init_comms(self): async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = self.comms_decoder + task_runner.set_supervisor_comms(self.comms_decoder) msg = await self.comms_decoder._aget_response(expect_id=0) diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 64303f191942a..c4b70b15bcc4d 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -20,7 +20,6 @@ import json import logging import re -import sys import warnings from contextlib import suppress from json import JSONDecodeError @@ -506,7 +505,9 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -590,7 +591,9 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 65808259fae47..0dee21bfdc5dc 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -20,7 +20,6 @@ import contextlib import json import logging -import sys import warnings from typing import TYPE_CHECKING, Any @@ -154,7 +153,9 @@ def get( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -214,7 +215,9 @@ def set( # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -345,7 +348,9 @@ def update( # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -411,7 +416,9 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized + + if is_supervisor_comms_initialized(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index df31c9225c272..b8ec067340035 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -360,7 +360,7 @@ def fn(moment): ... assert "got an unexpected keyword argument 'not_exists_arg'" in str(err) @pytest.mark.asyncio - @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) + @patch("airflow.sdk.execution_time.task_runner._SupervisorCommsHolder.comms", create=True) async def test_invalid_trigger(self, supervisor_builder): """Test the behaviour when we try to run an invalid Trigger""" workload = workloads.RunTrigger.model_construct( diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 1f14c325ea9e5..54054c3832429 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -361,7 +361,7 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): mock_get.return_value = expected_connection mock_task_runner = mock.MagicMock() - mock_task_runner.SUPERVISOR_COMMS = True + mock_task_runner._SupervisorCommsHolder.comms = True with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}): result = Connection.get_connection_from_secrets("test_conn") @@ -373,7 +373,7 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection): """Test the get_connection_from_secrets method with Task SDK not found path.""" mock_task_runner = mock.MagicMock() - mock_task_runner.SUPERVISOR_COMMS = True + mock_task_runner._SupervisorCommsHolder.comms = True mock_task_sdk_connection.get.side_effect = AirflowRuntimeError( error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 9dd6e782193f1..6e9096dc9590a 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2239,12 +2239,12 @@ def mock_supervisor_comms(monkeypatch): # core and TaskSDK is finished if CommsDecoder := getattr(comms, "CommsDecoder", None): comms = mock.create_autospec(CommsDecoder) - monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + monkeypatch.setattr(task_runner._SupervisorCommsHolder, "comms", comms, raising=False) else: CommsDecoder = getattr(task_runner, "CommsDecoder") comms = mock.create_autospec(CommsDecoder) comms.send = comms.get_message - monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) + monkeypatch.setattr(task_runner._SupervisorCommsHolder, "comms", comms, raising=False) yield comms diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 81cbfbed478e6..6567b12198a97 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -72,7 +72,7 @@ def set( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms value = cls.serialize_value( value=value, @@ -83,7 +83,7 @@ def set( map_index=map_index, ) - SUPERVISOR_COMMS.send( + supervisor_comms().send( SetXCom( key=key, value=value, @@ -117,9 +117,9 @@ def _set_xcom_in_db( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - SUPERVISOR_COMMS.send( + supervisor_comms().send( SetXCom( key=key, value=value, @@ -190,9 +190,9 @@ def _get_xcom_db_ref( :param key: A key for the XCom. If provided, only XCom with matching keys will be returned. Pass *None* (default) to remove the filter. """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( GetXCom( key=key, dag_id=dag_id, @@ -243,9 +243,9 @@ def get_one( specified Dag run is returned. If *True*, the latest matching XCom is returned regardless of the run it belongs to. """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( GetXCom( key=key, dag_id=dag_id, @@ -299,9 +299,9 @@ def get_all( returned regardless of the run they belong to. :return: List of all XCom values if found. """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( msg=GetXComSequenceSlice( key=key, dag_id=dag_id, @@ -360,7 +360,7 @@ def delete( map_index: int | None = None, ) -> None: """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms xcom_result = cls._get_xcom_db_ref( key=key, @@ -370,7 +370,7 @@ def delete( map_index=map_index, ) cls.purge(xcom_result) - SUPERVISOR_COMMS.send( + supervisor_comms().send( DeleteXCom( key=key, dag_id=dag_id, diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 26205f0c58335..172784d10a9b3 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -71,10 +71,10 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms def _fetch_asset(name: str) -> Asset: - resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + resp = supervisor_comms().send(GetAssetByName(name=name)) if resp is None: raise RuntimeError("Empty non-error response received") if isinstance(resp, ErrorResponse): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 15755e640d97e..5904ff317b2c9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -368,9 +368,9 @@ def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any: def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - response = SUPERVISOR_COMMS.send(GetDagRun(dag_id=dag_id, run_id=run_id)) + response = supervisor_comms().send(GetDagRun(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunResult) return response diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index db5a75e10c18d..a27f7d3f81164 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -286,7 +286,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ from airflow.sdk.execution_time.comms import PutVariable from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms # check for write conflicts on the worker for secrets_backend in ensure_secrets_backend_loaded(): @@ -317,7 +317,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) + supervisor_comms().send(PutVariable(key=key, value=value, description=description)) # Invalidate cache after setting the variable SecretCache.invalidate_variable(key) @@ -331,9 +331,9 @@ def _delete_variable(key: str) -> None: # keep Task SDK as a separate package than execution time mods. from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.comms import DeleteVariable - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) + msg = supervisor_comms().send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -458,7 +458,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset GetAssetByUri, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms msg: ToSupervisor if name: @@ -468,7 +468,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset else: raise ValueError("Either name or uri must be provided") - resp = SUPERVISOR_COMMS.send(msg) + resp = supervisor_comms().send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -619,7 +619,7 @@ def _asset_events(self) -> list[AssetEventResult]: GetAssetEventByAssetAlias, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms query_dict: dict[str, Any] = { "after": self._after, @@ -635,7 +635,7 @@ def _asset_events(self) -> list[AssetEventResult]: if self._asset_name is None and self._asset_uri is None: raise ValueError("Either asset_name or asset_uri must be provided") msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict) - resp = SUPERVISOR_COMMS.send(msg) + resp = supervisor_comms().send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -788,7 +788,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: PrevSuccessfulDagRunResult, ) - msg = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) + msg = task_runner.supervisor_comms().send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/hitl.py b/task-sdk/src/airflow/sdk/execution_time/hitl.py index 07f94e63c0529..6240a518b6654 100644 --- a/task-sdk/src/airflow/sdk/execution_time/hitl.py +++ b/task-sdk/src/airflow/sdk/execution_time/hitl.py @@ -46,9 +46,9 @@ def upsert_hitl_detail( params: dict[str, Any] | None = None, assigned_users: list[HITLUser] | None = None, ) -> None: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - SUPERVISOR_COMMS.send( + supervisor_comms().send( msg=CreateHITLDetailPayload( ti_id=ti_id, options=options, @@ -71,9 +71,9 @@ def update_hitl_detail_response( chosen_options: list[str], params_input: dict[str, Any], ) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( msg=UpdateHITLDetail( ti_id=ti_id, chosen_options=chosen_options, @@ -86,9 +86,9 @@ def update_hitl_detail_response( def get_hitl_detail_content_detail(ti_id: UUID) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms - response = SUPERVISOR_COMMS.send(msg=GetHITLDetailResponse(ti_id=ti_id)) + response = supervisor_comms().send(msg=GetHITLDetailResponse(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(response, HITLDetailResponse) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 4efb0b71368ca..689b95d4bdc5a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -90,11 +90,11 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: if self._len is None: from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms task = self._xcom_arg.operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, @@ -123,13 +123,13 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: XComSequenceIndexResult, XComSequenceSliceResult, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms from airflow.sdk.execution_time.xcom import XCom if isinstance(key, slice): start, stop, step = _coerce_slice(key) source = (xcom_arg := self._xcom_arg).operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, @@ -150,7 +150,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") source = (xcom_arg := self._xcom_arg).operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_comms().send( GetXComSequenceItem( key=xcom_arg.key, dag_id=source.dag_id, diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py index a44b23d06dc6d..1de2d2d3393fd 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py @@ -30,14 +30,14 @@ class ExecutionAPISecretsBackend(BaseSecretsBackend): """ Secrets backend for client contexts (workers, DAG processors, triggerers). - Routes connection and variable requests through SUPERVISOR_COMMS to the + Routes connection and variable requests through supervisor-comms to the Execution API server. This backend should only be registered in client processes, not in API server/scheduler processes. """ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None: """ - Get connection URI via SUPERVISOR_COMMS. + Get connection URI via supervisor-comms. Not used since we override get_connection directly. """ @@ -45,7 +45,7 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None: # type: ignore[override] """ - Return connection object by routing through SUPERVISOR_COMMS. + Return connection object by routing through supervisor-comms. :param conn_id: connection id :param team_name: Name of the team associated to the task trying to access the connection. @@ -54,10 +54,10 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms try: - msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) + msg = supervisor_comms().send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -86,13 +86,13 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti # Fall through to the general exception handler for other RuntimeErrors return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None def get_variable(self, key: str, team_name: str | None = None) -> str | None: """ - Return variable value by routing through SUPERVISOR_COMMS. + Return variable value by routing through supervisor-comms. :param key: Variable key :param team_name: Name of the team associated to the task trying to access the variable. @@ -100,10 +100,10 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms try: - msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) + msg = supervisor_comms().send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred @@ -114,23 +114,23 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: return msg.value # Already a string | None return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None async def aget_connection(self, conn_id: str) -> Connection | None: # type: ignore[override] """ - Return connection object asynchronously via SUPERVISOR_COMMS. + Return connection object asynchronously via supervisor-comms. :param conn_id: connection id :return: Connection object or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms try: - msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) + msg = await supervisor_comms().asend(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -139,22 +139,22 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign # Convert ExecutionAPI response to SDK Connection return _process_connection_result_conn(msg) except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None async def aget_variable(self, key: str) -> str | None: """ - Return variable value asynchronously via SUPERVISOR_COMMS. + Return variable value asynchronously via supervisor-comms. :param key: Variable key :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_comms try: - msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key)) + msg = await supervisor_comms().asend(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred @@ -165,6 +165,6 @@ async def aget_variable(self, key: str) -> str | None: return msg.value # Already a string | None return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index b87131aa7336d..8b6992d5898b4 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -845,7 +845,7 @@ def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | Non Connection object or None if not found. """ # Since we need to use the API Client directly, we can't use Connection.get as that would try to use - # SUPERVISOR_COMMS + # supervisor-comms # TODO: Store in the SecretsCache if its enabled - see #48858 @@ -888,7 +888,7 @@ def _remote_logging_conn(client: Client): connection it needs, now, directly from the API client, and store it in an env var, so that when the logging hook tries to get the connection it can find it easily from the env vars. - This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the + This is needed as the BaseHook.get_connection looks for supervisor-comms, but we are still in the supervisor process when this is needed, so that doesn't exist yet. The connection details are fetched eagerly on every invocation to avoid retaining @@ -1689,7 +1689,7 @@ def run_trigger_in_process(cls, *, trigger, ti): Run a trigger in-process for testing, similar to how we run tasks. This creates a minimal supervisor instance specifically for trigger execution - and ensures the trigger has access to SUPERVISOR_COMMS for connection access. + and ensures the trigger has access to supervisor-comms for connection access. """ # Create a minimal supervisor instance for trigger execution supervisor = cls( @@ -1723,34 +1723,34 @@ def final_state(self): @contextmanager def set_supervisor_comms(temp_comms): """ - Temporarily override `SUPERVISOR_COMMS` in the `task_runner` module. + Temporarily override `supervisor_comms()` in the `task_runner` module. This is used to simulate task-runner ↔ supervisor communication in-process, by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`) in place of the real inter-process communication layer. Some parts of the code (e.g. models.Variable.get) check for the presence - of `task_runner.SUPERVISOR_COMMS` to determine if the code is running in a Task SDK execution context. + of `task_runner.supervisor_comms()` to determine if the code is running in a Task SDK execution context. This override ensures those code paths behave correctly during in-process tests. """ from airflow.sdk.execution_time import task_runner sentinel = object() - old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + old = getattr(task_runner._SupervisorCommsHolder, "comms", sentinel) if temp_comms is not None: - task_runner.SUPERVISOR_COMMS = temp_comms + task_runner._SupervisorCommsHolder.comms = temp_comms elif old is not sentinel: - delattr(task_runner, "SUPERVISOR_COMMS") + task_runner._SupervisorCommsHolder.comms = None try: yield finally: if old is sentinel: - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") + if task_runner._SupervisorCommsHolder.comms is not None: + task_runner._SupervisorCommsHolder.comms = None else: - task_runner.SUPERVISOR_COMMS = old + task_runner._SupervisorCommsHolder.comms = old def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: @@ -1910,13 +1910,13 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: Initialize secrets backend with auto-detected context. Detection strategy: - 1. SUPERVISOR_COMMS exists and is set → client chain (ExecutionAPISecretsBackend) + 1. supervisor-comms exists and is set → client chain (ExecutionAPISecretsBackend) 2. _AIRFLOW_PROCESS_CONTEXT=server env var → server chain (MetastoreBackend) 3. Neither → fallback chain (only env vars + external backends, no MetastoreBackend) - Client contexts: task runner in worker (has SUPERVISOR_COMMS) + Client contexts: task runner in worker (has supervisor-comms) Server contexts: API server, scheduler (set _AIRFLOW_PROCESS_CONTEXT=server) - Fallback contexts: supervisor, unknown contexts (no SUPERVISOR_COMMS, no env var) + Fallback contexts: supervisor, unknown contexts (no supervisor-comms, no env var) The fallback chain ensures supervisor can use external secrets (AWS Secrets Manager, Vault, etc.) while falling back to API client, without trying MetastoreBackend. @@ -1926,12 +1926,12 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: from airflow.sdk.configuration import ensure_secrets_loaded from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS - # 1. Check for client context (SUPERVISOR_COMMS) + # 1. Check for client context (supervisor-comms) try: from airflow.sdk.execution_time import task_runner - if hasattr(task_runner, "SUPERVISOR_COMMS") and task_runner.SUPERVISOR_COMMS is not None: - # Client context: task runner with SUPERVISOR_COMMS + if task_runner.is_supervisor_comms_initialized(): + # Client context: task runner with supervisor-comms return ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS) except (ImportError, AttributeError): pass diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6057706699889..8a909ab95a9ba 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -453,7 +453,7 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) ) @@ -477,7 +477,7 @@ def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: if dag_run.logical_date is None: return None - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) ) @@ -511,7 +511,7 @@ def get_previous_ti( if effective_logical_date is None and dag_run and dag_run.logical_date: effective_logical_date = dag_run.logical_date - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( msg=GetPreviousTI( dag_id=self.dag_id, task_id=self.task_id, @@ -537,7 +537,7 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( GetTICount( dag_id=dag_id, map_index=map_index, @@ -564,7 +564,7 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( GetTaskStates( dag_id=dag_id, map_index=map_index, @@ -583,7 +583,7 @@ def get_task_states( @staticmethod def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: """Return task breadcrumbs for the given dag run.""" - response = SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) + response = supervisor_comms().send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, TaskBreadcrumbsResult) return response.breadcrumbs @@ -596,7 +596,7 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of Dag runs matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_comms().send( GetDRCount( dag_id=dag_id, logical_dates=logical_dates, @@ -613,7 +613,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the Dag run with the given Run ID.""" - response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) + response = supervisor_comms().send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -732,7 +732,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, +# This global class will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # # Why it needs to be a global: @@ -740,9 +740,30 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # to the parent process during task execution. # - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the # deeply nested execution stack. -# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily +# - By defining this as a static class with accessors, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. -SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] +# Not perfect but getter than a global variable. +class _SupervisorCommsHolder: + comms: CommsDecoder[ToTask, ToSupervisor] | None = None + + +def supervisor_comms() -> CommsDecoder[ToTask, ToSupervisor]: + """Get the global supervisor comms instance.""" + if _SupervisorCommsHolder.comms is None: + raise RuntimeError("Supervisor comms not initialized yet") + return _SupervisorCommsHolder.comms + + +def set_supervisor_comms(comms: CommsDecoder[ToTask, ToSupervisor]) -> None: + """Set the global supervisor comms instance.""" + if _SupervisorCommsHolder.comms is not None: + raise RuntimeError("Supervisor comms already initialized") + _SupervisorCommsHolder.comms = comms + + +def is_supervisor_comms_initialized() -> bool: + """Check if the global supervisor comms instance is initialized.""" + return _SupervisorCommsHolder.comms is not None # State machine! @@ -807,7 +828,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.debug("Using serialized startup message from environment", msg=msg) else: # normal entry point - msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] + msg = supervisor_comms()._get_response() # type: ignore[assignment] if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -839,12 +860,12 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() - os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) + os.set_inheritable(supervisor_comms().socket.fileno(), True) # Import main directly from the module instead of re-executing the file. # This ensures that when other parts modules import # airflow.sdk.execution_time.task_runner, they get the same module instance - # with the properly initialized SUPERVISOR_COMMS global variable. + # with the properly initialized supervisor_comms instance. # If we re-executed the module with `python -m`, it would load as __main__ and future # imports would get a fresh copy without the initialized globals. rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" @@ -987,7 +1008,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) + supervisor_comms().send(msg=SetRenderedFields(rendered_fields=rendered_fields)) # Try to render map_index_template early with available context (will be re-rendered after execution) # This provides a partial label during task execution for templates using pre-execution context @@ -996,7 +1017,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_map_index := _render_map_index(context, ti=ti, log=log): ti.rendered_map_index = rendered_map_index log.debug("Sending early rendered map index", length=len(rendered_map_index)) - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) + supervisor_comms().send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) except Exception: log.debug( "Early rendering of map_index_template failed, will retry after task execution", exc_info=True @@ -1020,7 +1041,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) + inactive_assets_resp = supervisor_comms().send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -1139,7 +1160,7 @@ def _on_term(signum, frame): ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - SUPERVISOR_COMMS.send( + supervisor_comms().send( msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) ) raise @@ -1148,7 +1169,7 @@ def _on_term(signum, frame): ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) + supervisor_comms().send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) _push_xcom_if_needed(result, ti, log) @@ -1156,7 +1177,7 @@ def _on_term(signum, frame): except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + supervisor_comms().send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -1218,7 +1239,7 @@ def _on_term(signum, frame): error = e finally: if msg: - SUPERVISOR_COMMS.send(msg=msg) + supervisor_comms().send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1282,7 +1303,7 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - comms_msg = SUPERVISOR_COMMS.send( + comms_msg = supervisor_comms().send( TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, @@ -1349,7 +1370,7 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - comms_msg = SUPERVISOR_COMMS.send( + comms_msg = supervisor_comms().send( GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) ) if TYPE_CHECKING: @@ -1635,7 +1656,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) + supervisor_comms().send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1684,8 +1705,7 @@ def finalize( def main(): log = structlog.get_logger(logger_name="task") - global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) + set_supervisor_comms(CommsDecoder[ToTask, ToSupervisor](log=log)) Stats.initialize( is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"), @@ -1711,9 +1731,9 @@ def main(): finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates. - if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: + if is_supervisor_comms_initialized() and supervisor_comms().socket: with suppress(Exception): - SUPERVISOR_COMMS.socket.close() + supervisor_comms().socket.close() def reinit_supervisor_comms() -> None: @@ -1726,15 +1746,14 @@ def reinit_supervisor_comms() -> None: """ import socket - if "SUPERVISOR_COMMS" not in globals(): - global SUPERVISOR_COMMS + if not is_supervisor_comms_initialized(): log = structlog.get_logger(logger_name="task") fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) + set_supervisor_comms(CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))) - logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) + logs = supervisor_comms().send(ResendLoggingFD()) if isinstance(logs, SentFDs): from airflow.sdk.log import configure_logging diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 9e0835a1efb44..13ba908ebe627 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -258,13 +258,12 @@ def mask_secret(secret: JsonValue, name: str | None = None) -> None: _secrets_masker().add_mask(secret, name) - with suppress(Exception): + with suppress(Exception, RuntimeError): # Try to tell supervisor (only if in task execution context) from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import MaskSecret - if comms := getattr(task_runner, "SUPERVISOR_COMMS", None): - comms.send(MaskSecret(value=secret, name=name)) + task_runner.supervisor_comms().send(MaskSecret(value=secret, name=name)) def reset_logging(): diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index a512062ff05cc..1cd7790c10569 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -891,7 +891,7 @@ def get_connection(self, conn_id: str) -> Connection | None: result = await _async_get_connection("test_conn") assert result == sample_connection - # Should not have tried SUPERVISOR_COMMS since secrets backend had the connection + # Should not have tried supervisor-comms since secrets backend had the connection mock_supervisor_comms.send.assert_not_called() mock_supervisor_comms.asend.assert_not_called() diff --git a/task-sdk/tests/task_sdk/execution_time/test_secrets.py b/task-sdk/tests/task_sdk/execution_time/test_secrets.py index 8f9745b0ffeca..3a8ef10cd0e10 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_secrets.py +++ b/task-sdk/tests/task_sdk/execution_time/test_secrets.py @@ -27,7 +27,7 @@ class TestExecutionAPISecretsBackend: """Test ExecutionAPISecretsBackend.""" def test_get_connection_via_supervisor_comms(self, mock_supervisor_comms): - """Test that connection is retrieved via SUPERVISOR_COMMS.""" + """Test that connection is retrieved via supervisor-comms.""" from airflow.sdk.api.datamodels._generated import ConnectionResponse from airflow.sdk.execution_time.comms import ConnectionResult @@ -67,7 +67,7 @@ def test_get_connection_not_found(self, mock_supervisor_comms): mock_supervisor_comms.send.assert_called_once() def test_get_variable_via_supervisor_comms(self, mock_supervisor_comms): - """Test that variable is retrieved via SUPERVISOR_COMMS.""" + """Test that variable is retrieved via supervisor-comms.""" from airflow.sdk.execution_time.comms import VariableResult # Mock variable response @@ -125,7 +125,7 @@ def test_runtime_error_triggers_greenback_fallback(self, mocker, mock_supervisor """ Test that RuntimeError from async_to_sync triggers greenback fallback. - This test verifies the fix for issue #57145: when SUPERVISOR_COMMS.send() + This test verifies the fix for issue #57145: when supervisor_comms().send() raises the specific RuntimeError about async_to_sync in an event loop, the backend catches it and uses greenback to call aget_connection(). """ @@ -161,7 +161,7 @@ def greenback_await_side_effect(coro): # Mock aget_connection to return the expected connection directly. # We need to mock this because the real aget_connection would try to - # use SUPERVISOR_COMMS.asend which is not set up for this test. + # use supervisor_comms().asend which is not set up for this test. async def mock_aget_connection(self, conn_id): return expected_conn @@ -183,7 +183,7 @@ class TestContextDetection: """Test context detection in ensure_secrets_backend_loaded.""" def test_client_context_with_supervisor_comms(self, mock_supervisor_comms): - """Client context: SUPERVISOR_COMMS set → uses worker chain.""" + """Client context: supervisor_comms() set → uses worker chain.""" from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded backends = ensure_secrets_backend_loaded() @@ -198,7 +198,7 @@ def test_server_context_with_env_var(self, monkeypatch): from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded monkeypatch.setenv("_AIRFLOW_PROCESS_CONTEXT", "server") - # Ensure SUPERVISOR_COMMS is not available + # Ensure supervisor-comms is not available if "airflow.sdk.execution_time.task_runner" in sys.modules: monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") @@ -208,12 +208,12 @@ def test_server_context_with_env_var(self, monkeypatch): assert "ExecutionAPISecretsBackend" not in backend_classes def test_fallback_context_no_markers(self, monkeypatch): - """Fallback context: no SUPERVISOR_COMMS, no env var → only env vars + external.""" + """Fallback context: no supervisor-comms, no env var → only env vars + external.""" import sys from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - # Ensure no SUPERVISOR_COMMS + # Ensure no supervisor-comms if "airflow.sdk.execution_time.task_runner" in sys.modules: monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b45f57cd4256e..438de3c22591e 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2639,37 +2639,37 @@ class DummyComms: @pytest.fixture(autouse=True) def cleanup_supervisor_comms(self): # Ensure clean state before/after test - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") + if task_runner._SupervisorCommsHolder.comms: + task_runner._SupervisorCommsHolder.comms = None yield - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") + if task_runner._SupervisorCommsHolder.comms: + task_runner._SupervisorCommsHolder.comms = None def test_set_supervisor_comms_overrides_and_restores(self): - task_runner.SUPERVISOR_COMMS = self.DummyComms() - original = task_runner.SUPERVISOR_COMMS + task_runner._SupervisorCommsHolder.comms = self.DummyComms() + original = task_runner._SupervisorCommsHolder.comms replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner.SUPERVISOR_COMMS is replacement - assert task_runner.SUPERVISOR_COMMS is original + assert task_runner._SupervisorCommsHolder.comms is replacement + assert task_runner._SupervisorCommsHolder.comms is original def test_set_supervisor_comms_sets_temporarily_when_not_set(self): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert task_runner._SupervisorCommsHolder.comms is None replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner.SUPERVISOR_COMMS is replacement - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert task_runner._SupervisorCommsHolder.comms is replacement + assert task_runner._SupervisorCommsHolder.comms is None def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert task_runner._SupervisorCommsHolder.comms is None # This will delete an attribute that isn't set, and restore it likewise with set_supervisor_comms(None): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert task_runner._SupervisorCommsHolder.comms is None - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert task_runner._SupervisorCommsHolder.comms is None class TestInProcessTestSupervisor: From 665d640a0d358aa39ccea32e8558884f570eca98 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Mon, 29 Dec 2025 09:43:51 +0100 Subject: [PATCH 02/10] Fix pytests --- airflow-core/tests/unit/models/test_connection.py | 1 - devel-common/src/tests_common/pytest_plugin.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 54054c3832429..b493144285e63 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -383,7 +383,6 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn with pytest.raises(AirflowNotFoundException): Connection.get_connection_from_secrets("test_conn") - @mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None}) @mock.patch("airflow.sdk.Connection") @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection") @mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection") diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 6e9096dc9590a..fea49e45d6d4e 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2227,7 +2227,7 @@ def override_caplog(request): @pytest.fixture def mock_supervisor_comms(monkeypatch): # for back-compat - from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS if not AIRFLOW_V_3_0_PLUS: yield None @@ -2239,12 +2239,15 @@ def mock_supervisor_comms(monkeypatch): # core and TaskSDK is finished if CommsDecoder := getattr(comms, "CommsDecoder", None): comms = mock.create_autospec(CommsDecoder) - monkeypatch.setattr(task_runner._SupervisorCommsHolder, "comms", comms, raising=False) else: CommsDecoder = getattr(task_runner, "CommsDecoder") comms = mock.create_autospec(CommsDecoder) comms.send = comms.get_message + + if AIRFLOW_V_3_2_PLUS: monkeypatch.setattr(task_runner._SupervisorCommsHolder, "comms", comms, raising=False) + else: + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) yield comms From ea79dbd109c7daecb11d2b440d79e0e37c90dd39 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Wed, 31 Dec 2025 16:27:29 +0100 Subject: [PATCH 03/10] Rework supervisor_comms().send() to supervisor_send() --- task-sdk/src/airflow/sdk/bases/xcom.py | 88 +++++-------------- .../sdk/definitions/asset/decorators.py | 4 +- .../src/airflow/sdk/execution_time/comms.py | 4 +- .../src/airflow/sdk/execution_time/context.py | 18 ++-- .../src/airflow/sdk/execution_time/hitl.py | 12 +-- .../sdk/execution_time/lazy_sequence.py | 10 +-- .../execution_time/secrets/execution_api.py | 16 ++-- .../airflow/sdk/execution_time/task_runner.py | 68 +++++++------- task-sdk/src/airflow/sdk/log.py | 2 +- .../task_sdk/execution_time/test_secrets.py | 6 +- .../execution_time/test_task_runner.py | 4 +- 11 files changed, 96 insertions(+), 136 deletions(-) diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 6567b12198a97..7b76ef85b59b3 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -72,18 +72,13 @@ def set( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send value = cls.serialize_value( - value=value, - key=key, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, + value=value, key=key, task_id=task_id, dag_id=dag_id, run_id=run_id, map_index=map_index ) - supervisor_comms().send( + supervisor_send( SetXCom( key=key, value=value, @@ -97,14 +92,7 @@ def set( @classmethod def _set_xcom_in_db( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, + cls, key: str, value: Any, *, dag_id: str, task_id: str, run_id: str, map_index: int = -1 ) -> None: """ Store an XCom value directly in the metadata database. @@ -117,17 +105,10 @@ def _set_xcom_in_db( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - supervisor_comms().send( - SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + supervisor_send( + SetXCom(key=key, value=value, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index), ) @classmethod @@ -160,13 +141,7 @@ def get_value( @classmethod def _get_xcom_db_ref( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, + cls, *, key: str, dag_id: str, task_id: str, run_id: str, map_index: int | None = None ) -> XComResult: """ Retrieve an XCom value, optionally meeting certain criteria. @@ -190,16 +165,10 @@ def _get_xcom_db_ref( :param key: A key for the XCom. If provided, only XCom with matching keys will be returned. Pass *None* (default) to remove the filter. """ - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = supervisor_comms().send( - GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + msg = supervisor_send( + GetXCom(key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index), ) if not isinstance(msg, XComResult): @@ -243,9 +212,9 @@ def get_one( specified Dag run is returned. If *True*, the latest matching XCom is returned regardless of the run it belongs to. """ - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = supervisor_comms().send( + msg = supervisor_send( GetXCom( key=key, dag_id=dag_id, @@ -299,9 +268,9 @@ def get_all( returned regardless of the run they belong to. :return: List of all XCom values if found. """ - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = supervisor_comms().send( + msg = supervisor_send( msg=GetXComSequenceSlice( key=key, dag_id=dag_id, @@ -351,31 +320,14 @@ def purge(cls, xcom: XComResult, *args) -> None: pass @classmethod - def delete( - cls, - key: str, - task_id: str, - dag_id: str, - run_id: str, - map_index: int | None = None, - ) -> None: + def delete(cls, key: str, task_id: str, dag_id: str, run_id: str, map_index: int | None = None) -> None: """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send xcom_result = cls._get_xcom_db_ref( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, + key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index ) cls.purge(xcom_result) - supervisor_comms().send( - DeleteXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + supervisor_send( + DeleteXCom(key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index) ) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 172784d10a9b3..8facc2df99910 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -71,10 +71,10 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send def _fetch_asset(name: str) -> Asset: - resp = supervisor_comms().send(GetAssetByName(name=name)) + resp = supervisor_send(GetAssetByName(name=name)) if resp is None: raise RuntimeError("Empty non-error response received") if isinstance(resp, ErrorResponse): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 5904ff317b2c9..7043127e92203 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -368,9 +368,9 @@ def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any: def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun: - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - response = supervisor_comms().send(GetDagRun(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(GetDagRun(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunResult) return response diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index a27f7d3f81164..63b9947f43218 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -286,7 +286,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ from airflow.sdk.execution_time.comms import PutVariable from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send # check for write conflicts on the worker for secrets_backend in ensure_secrets_backend_loaded(): @@ -317,7 +317,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - supervisor_comms().send(PutVariable(key=key, value=value, description=description)) + supervisor_send(PutVariable(key=key, value=value, description=description)) # Invalidate cache after setting the variable SecretCache.invalidate_variable(key) @@ -331,9 +331,9 @@ def _delete_variable(key: str) -> None: # keep Task SDK as a separate package than execution time mods. from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.comms import DeleteVariable - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = supervisor_comms().send(DeleteVariable(key=key)) + msg = supervisor_send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -458,7 +458,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset GetAssetByUri, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send msg: ToSupervisor if name: @@ -468,7 +468,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset else: raise ValueError("Either name or uri must be provided") - resp = supervisor_comms().send(msg) + resp = supervisor_send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -619,7 +619,7 @@ def _asset_events(self) -> list[AssetEventResult]: GetAssetEventByAssetAlias, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send query_dict: dict[str, Any] = { "after": self._after, @@ -635,7 +635,7 @@ def _asset_events(self) -> list[AssetEventResult]: if self._asset_name is None and self._asset_uri is None: raise ValueError("Either asset_name or asset_uri must be provided") msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict) - resp = supervisor_comms().send(msg) + resp = supervisor_send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -788,7 +788,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: PrevSuccessfulDagRunResult, ) - msg = task_runner.supervisor_comms().send(GetPrevSuccessfulDagRun(ti_id=ti_id)) + msg = task_runner.supervisor_send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/hitl.py b/task-sdk/src/airflow/sdk/execution_time/hitl.py index 6240a518b6654..ce0cc53030fb7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/hitl.py +++ b/task-sdk/src/airflow/sdk/execution_time/hitl.py @@ -46,9 +46,9 @@ def upsert_hitl_detail( params: dict[str, Any] | None = None, assigned_users: list[HITLUser] | None = None, ) -> None: - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - supervisor_comms().send( + supervisor_send( msg=CreateHITLDetailPayload( ti_id=ti_id, options=options, @@ -71,9 +71,9 @@ def update_hitl_detail_response( chosen_options: list[str], params_input: dict[str, Any], ) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - response = supervisor_comms().send( + response = supervisor_send( msg=UpdateHITLDetail( ti_id=ti_id, chosen_options=chosen_options, @@ -86,9 +86,9 @@ def update_hitl_detail_response( def get_hitl_detail_content_detail(ti_id: UUID) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send - response = supervisor_comms().send(msg=GetHITLDetailResponse(ti_id=ti_id)) + response = supervisor_send(msg=GetHITLDetailResponse(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(response, HITLDetailResponse) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 689b95d4bdc5a..5639fa596d5e7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -90,11 +90,11 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: if self._len is None: from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send task = self._xcom_arg.operator - msg = supervisor_comms().send( + msg = supervisor_send( GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, @@ -123,13 +123,13 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: XComSequenceIndexResult, XComSequenceSliceResult, ) - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send from airflow.sdk.execution_time.xcom import XCom if isinstance(key, slice): start, stop, step = _coerce_slice(key) source = (xcom_arg := self._xcom_arg).operator - msg = supervisor_comms().send( + msg = supervisor_send( GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, @@ -150,7 +150,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") source = (xcom_arg := self._xcom_arg).operator - msg = supervisor_comms().send( + msg = supervisor_send( GetXComSequenceItem( key=xcom_arg.key, dag_id=source.dag_id, diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py index 1de2d2d3393fd..0140671083196 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py @@ -54,10 +54,10 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send try: - msg = supervisor_comms().send(GetConnection(conn_id=conn_id)) + msg = supervisor_send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -100,10 +100,10 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_send try: - msg = supervisor_comms().send(GetVariable(key=key)) + msg = supervisor_send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred @@ -127,10 +127,10 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_asend try: - msg = await supervisor_comms().asend(GetConnection(conn_id=conn_id)) + msg = await supervisor_asend(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -151,10 +151,10 @@ async def aget_variable(self, key: str) -> str | None: :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import supervisor_comms + from airflow.sdk.execution_time.task_runner import supervisor_asend try: - msg = await supervisor_comms().asend(GetVariable(key=key)) + msg = await supervisor_asend(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 8a909ab95a9ba..a4c8458810119 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -453,9 +453,7 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - response = supervisor_comms().send( - msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) - ) + response = supervisor_send(msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)) if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -477,7 +475,7 @@ def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: if dag_run.logical_date is None: return None - response = supervisor_comms().send( + response = supervisor_send( msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) ) @@ -511,7 +509,7 @@ def get_previous_ti( if effective_logical_date is None and dag_run and dag_run.logical_date: effective_logical_date = dag_run.logical_date - response = supervisor_comms().send( + response = supervisor_send( msg=GetPreviousTI( dag_id=self.dag_id, task_id=self.task_id, @@ -537,7 +535,7 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - response = supervisor_comms().send( + response = supervisor_send( GetTICount( dag_id=dag_id, map_index=map_index, @@ -564,7 +562,7 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - response = supervisor_comms().send( + response = supervisor_send( GetTaskStates( dag_id=dag_id, map_index=map_index, @@ -583,7 +581,7 @@ def get_task_states( @staticmethod def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: """Return task breadcrumbs for the given dag run.""" - response = supervisor_comms().send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, TaskBreadcrumbsResult) return response.breadcrumbs @@ -596,7 +594,7 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of Dag runs matching the given criteria.""" - response = supervisor_comms().send( + response = supervisor_send( GetDRCount( dag_id=dag_id, logical_dates=logical_dates, @@ -613,7 +611,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the Dag run with the given Run ID.""" - response = supervisor_comms().send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -747,7 +745,21 @@ class _SupervisorCommsHolder: comms: CommsDecoder[ToTask, ToSupervisor] | None = None -def supervisor_comms() -> CommsDecoder[ToTask, ToSupervisor]: +def supervisor_send(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as convenience for get_supervisor_comms().send().""" + if _SupervisorCommsHolder.comms is None: + raise RuntimeError("Supervisor comms not initialized yet") + return _SupervisorCommsHolder.comms.send(msg) + + +async def supervisor_asend(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as convenience for get_supervisor_comms().asend().""" + if _SupervisorCommsHolder.comms is None: + raise RuntimeError("Supervisor comms not initialized yet") + return await _SupervisorCommsHolder.comms.asend(msg) + + +def get_supervisor_comms() -> CommsDecoder[ToTask, ToSupervisor]: """Get the global supervisor comms instance.""" if _SupervisorCommsHolder.comms is None: raise RuntimeError("Supervisor comms not initialized yet") @@ -828,7 +840,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.debug("Using serialized startup message from environment", msg=msg) else: # normal entry point - msg = supervisor_comms()._get_response() # type: ignore[assignment] + msg = get_supervisor_comms()._get_response() # type: ignore[assignment] if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -860,7 +872,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() - os.set_inheritable(supervisor_comms().socket.fileno(), True) + os.set_inheritable(get_supervisor_comms().socket.fileno(), True) # Import main directly from the module instead of re-executing the file. # This ensures that when other parts modules import @@ -1008,7 +1020,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - supervisor_comms().send(msg=SetRenderedFields(rendered_fields=rendered_fields)) + supervisor_send(msg=SetRenderedFields(rendered_fields=rendered_fields)) # Try to render map_index_template early with available context (will be re-rendered after execution) # This provides a partial label during task execution for templates using pre-execution context @@ -1017,7 +1029,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_map_index := _render_map_index(context, ti=ti, log=log): ti.rendered_map_index = rendered_map_index log.debug("Sending early rendered map index", length=len(rendered_map_index)) - supervisor_comms().send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) except Exception: log.debug( "Early rendering of map_index_template failed, will retry after task execution", exc_info=True @@ -1041,7 +1053,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - inactive_assets_resp = supervisor_comms().send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) + inactive_assets_resp = supervisor_send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -1160,16 +1172,14 @@ def _on_term(signum, frame): ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - supervisor_comms().send( - msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) - ) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) raise else: # If the task succeeded, render normally to let rendering error bubble up. previous_rendered_map_index = ti.rendered_map_index ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - supervisor_comms().send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) _push_xcom_if_needed(result, ti, log) @@ -1177,7 +1187,7 @@ def _on_term(signum, frame): except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - supervisor_comms().send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + supervisor_send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -1239,7 +1249,7 @@ def _on_term(signum, frame): error = e finally: if msg: - supervisor_comms().send(msg=msg) + supervisor_send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1303,7 +1313,7 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - comms_msg = supervisor_comms().send( + comms_msg = supervisor_send( TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, @@ -1370,9 +1380,7 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - comms_msg = supervisor_comms().send( - GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) - ) + comms_msg = supervisor_send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)) if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1656,7 +1664,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - supervisor_comms().send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) + supervisor_send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1731,9 +1739,9 @@ def main(): finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates. - if is_supervisor_comms_initialized() and supervisor_comms().socket: + if is_supervisor_comms_initialized() and get_supervisor_comms().socket: with suppress(Exception): - supervisor_comms().socket.close() + get_supervisor_comms().socket.close() def reinit_supervisor_comms() -> None: @@ -1753,7 +1761,7 @@ def reinit_supervisor_comms() -> None: set_supervisor_comms(CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))) - logs = supervisor_comms().send(ResendLoggingFD()) + logs = supervisor_send(ResendLoggingFD()) if isinstance(logs, SentFDs): from airflow.sdk.log import configure_logging diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 13ba908ebe627..ee2125d0ca8d1 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -263,7 +263,7 @@ def mask_secret(secret: JsonValue, name: str | None = None) -> None: from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import MaskSecret - task_runner.supervisor_comms().send(MaskSecret(value=secret, name=name)) + task_runner.supervisor_send(MaskSecret(value=secret, name=name)) def reset_logging(): diff --git a/task-sdk/tests/task_sdk/execution_time/test_secrets.py b/task-sdk/tests/task_sdk/execution_time/test_secrets.py index 3a8ef10cd0e10..fbddcffa55f10 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_secrets.py +++ b/task-sdk/tests/task_sdk/execution_time/test_secrets.py @@ -125,7 +125,7 @@ def test_runtime_error_triggers_greenback_fallback(self, mocker, mock_supervisor """ Test that RuntimeError from async_to_sync triggers greenback fallback. - This test verifies the fix for issue #57145: when supervisor_comms().send() + This test verifies the fix for issue #57145: when supervisor_send() raises the specific RuntimeError about async_to_sync in an event loop, the backend catches it and uses greenback to call aget_connection(). """ @@ -161,7 +161,7 @@ def greenback_await_side_effect(coro): # Mock aget_connection to return the expected connection directly. # We need to mock this because the real aget_connection would try to - # use supervisor_comms().asend which is not set up for this test. + # use supervisor_asend which is not set up for this test. async def mock_aget_connection(self, conn_id): return expected_conn @@ -183,7 +183,7 @@ class TestContextDetection: """Test context detection in ensure_secrets_backend_loaded.""" def test_client_context_with_supervisor_comms(self, mock_supervisor_comms): - """Client context: supervisor_comms() set → uses worker chain.""" + """Client context: get_supervisor_comms() set → uses worker chain.""" from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded backends = ensure_secrets_backend_loaded() diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 9cf63c654aaf5..bb33e2ffef2fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2563,7 +2563,7 @@ def execute(self, context): assert ( call( - msg=SetRenderedFields( + SetRenderedFields( rendered_fields={ "env_vars": "Truncated. You can change this behaviour in [core]max_templated_field_length. \"{'TEST_URL_0': '***', 'TEST_URL_1': '***', 'TEST_URL_10': '***', 'TEST_URL_11': '***', 'TEST_URL_12': '***', 'TEST_URL_13': '***', 'TEST_URL_14': '***', 'TEST_URL_15': '***', 'TEST_URL_16': '***', 'TEST_URL_17': '***', 'TEST_URL_18': '***', 'TEST_URL_19': '***', 'TEST_URL_2': '***', 'TEST_URL_20': '***', 'TEST_URL_21': '***', 'TEST_URL_22': '***', 'TEST_URL_23': '***', 'TEST_URL_24': '***', 'TEST_URL_25': '***', 'TEST_URL_26': '***', 'TEST_URL_27': '***', 'TEST_URL_28': '***', 'TEST_URL_29': '***', 'TEST_URL_3': '***', 'TEST_URL_30': '***', 'TEST_URL_31': '***', 'TEST_URL_32': '***', 'TEST_URL_33': '***', 'TEST_URL_34': '***', 'TEST_URL_35': '***', 'TEST_URL_36': '***', 'TEST_URL_37': '***', 'TEST_URL_38': '***', 'TEST_URL_39': '***', 'TEST_URL_4': '***', 'TEST_URL_40': '***', 'TEST_URL_41': '***', 'TEST_URL_42': '***', 'TEST_URL_43': '***', 'TEST_URL_44': '***', 'TEST_URL_45': '***', 'TEST_URL_46': '***', 'TEST_URL_47': '***', 'TEST_URL_48': '***', 'TEST_URL_49': '***', 'TEST_URL_5': '***', 'TEST_URL_6': '***', 'TEST_URL_7': '***', 'TEST_URL_8': '***', 'TEST_URL_9': '***'}\"... ", "region": "us-west-2", @@ -3073,7 +3073,7 @@ def execute(self, context): assert ( call( - msg=SetRenderedFields( + SetRenderedFields( rendered_fields={"username": "***", "region": "us-west-2"}, type="SetRenderedFields", ) From 14056f21d77f7b6077ef8b66edd41a8f06bd8852 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Wed, 31 Dec 2025 16:48:36 +0100 Subject: [PATCH 04/10] Review feedback Kaxil and Amogh --- airflow-core/src/airflow/models/base.py | 12 ++++++ airflow-core/src/airflow/models/connection.py | 16 ++----- airflow-core/src/airflow/models/variable.py | 42 +++---------------- .../airflow/sdk/execution_time/task_runner.py | 6 +-- 4 files changed, 23 insertions(+), 53 deletions(-) diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index cc9853330625a..2fa69abfb4ce9 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import sys from typing import TYPE_CHECKING, Any from sqlalchemy import Integer, MetaData, String, text @@ -98,3 +99,14 @@ class TaskInstanceDependencies(Base): dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) run_id: Mapped[str] = mapped_column(StringID(), nullable=False) map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) + + +def has_execution_context() -> bool: + """Check if we are in an execution context (Task, Dag Parse or Triggerer perhaps).""" + # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still + # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big + # back-compat layer + + # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) + # and should use the Task SDK API server path + return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "_SupervisorCommsHolder") diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index c4b70b15bcc4d..4ce9bb0791404 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -33,7 +33,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.models.base import ID_LEN, Base +from airflow.models.base import ID_LEN, Base, has_execution_context from airflow.models.crypto import get_fernet from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -499,15 +499,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) :param team_name: Team name associated to the task trying to access the connection (if any) :return: connection """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -591,9 +583,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 0dee21bfdc5dc..641b9804c68c3 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -29,7 +29,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded -from airflow.models.base import ID_LEN, Base +from airflow.models.base import ID_LEN, Base, has_execution_context from airflow.models.crypto import get_fernet from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin @@ -147,15 +147,7 @@ def get( :param deserialize_json: Deserialize the value to a Python dict :param team_name: Team name associated to the task trying to access the variable (if any) """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -209,15 +201,7 @@ def set( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -342,15 +326,7 @@ def update( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -410,15 +386,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non :param team_name: Team name associated to the task trying to delete the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - from airflow.sdk.execution_time.task_runner import is_supervisor_comms_initialized - - if is_supervisor_comms_initialized(): + if has_execution_context(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index a4c8458810119..cfb1afc01892f 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -748,21 +748,21 @@ class _SupervisorCommsHolder: def supervisor_send(msg: ToSupervisor) -> ToTask | None: """Send a message to the supervisor as convenience for get_supervisor_comms().send().""" if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet") + raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") return _SupervisorCommsHolder.comms.send(msg) async def supervisor_asend(msg: ToSupervisor) -> ToTask | None: """Send a message to the supervisor as convenience for get_supervisor_comms().asend().""" if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet") + raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") return await _SupervisorCommsHolder.comms.asend(msg) def get_supervisor_comms() -> CommsDecoder[ToTask, ToSupervisor]: """Get the global supervisor comms instance.""" if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet") + raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") return _SupervisorCommsHolder.comms From c37728b7f61a6bb46ec6703c26d379dac4127bf1 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Wed, 31 Dec 2025 17:15:58 +0100 Subject: [PATCH 05/10] Harden check for execution context --- airflow-core/src/airflow/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index 2fa69abfb4ce9..b0f53adc8aae4 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -109,4 +109,4 @@ def has_execution_context() -> bool: # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "_SupervisorCommsHolder") + return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "_SupervisorCommsHolder.comms") From 172be4e0f178be018394c97af0ca79102566df3d Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 11 Jan 2026 12:01:02 +0100 Subject: [PATCH 06/10] Rework supervisor comms to a singleton pattern --- airflow-core/src/airflow/models/base.py | 6 +- .../tests/unit/jobs/test_triggerer_job.py | 2 +- .../tests/unit/models/test_connection.py | 4 +- .../src/tests_common/pytest_plugin.py | 30 ++++++- .../airflow/sdk/execution_time/supervisor.py | 35 ++++---- .../airflow/sdk/execution_time/task_runner.py | 85 ++++++++++++------- .../task_sdk/execution_time/test_secrets.py | 15 +--- .../execution_time/test_supervisor.py | 46 +++++----- 8 files changed, 126 insertions(+), 97 deletions(-) diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index b0f53adc8aae4..ca6b146ee014c 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -109,4 +109,8 @@ def has_execution_context() -> bool: # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) # and should use the Task SDK API server path - return hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "_SupervisorCommsHolder.comms") + return ( + hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SupervisorComms") + and sys.modules["airflow.sdk.execution_time.task_runner"].SupervisorComms._comms + is not sys.modules["airflow.sdk.execution_time.task_runner"]._UnsetComms + ) diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index b8ec067340035..ccbd44ff4c205 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -360,7 +360,7 @@ def fn(moment): ... assert "got an unexpected keyword argument 'not_exists_arg'" in str(err) @pytest.mark.asyncio - @patch("airflow.sdk.execution_time.task_runner._SupervisorCommsHolder.comms", create=True) + @patch("airflow.sdk.execution_time.task_runner.SupervisorComms._comms", create=True) async def test_invalid_trigger(self, supervisor_builder): """Test the behaviour when we try to run an invalid Trigger""" workload = workloads.RunTrigger.model_construct( diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index b493144285e63..eb330ff4b5636 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -361,7 +361,7 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): mock_get.return_value = expected_connection mock_task_runner = mock.MagicMock() - mock_task_runner._SupervisorCommsHolder.comms = True + mock_task_runner.SupervisorComms._comms = True with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}): result = Connection.get_connection_from_secrets("test_conn") @@ -373,7 +373,7 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection): """Test the get_connection_from_secrets method with Task SDK not found path.""" mock_task_runner = mock.MagicMock() - mock_task_runner._SupervisorCommsHolder.comms = True + mock_task_runner.SupervisorComms._comms = True mock_task_sdk_connection.get.side_effect = AirflowRuntimeError( error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index fea49e45d6d4e..f129a871859e6 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2245,10 +2245,36 @@ def mock_supervisor_comms(monkeypatch): comms.send = comms.get_message if AIRFLOW_V_3_2_PLUS: - monkeypatch.setattr(task_runner._SupervisorCommsHolder, "comms", comms, raising=False) + svcomms = task_runner.SupervisorComms() + old = svcomms.get_comms() + svcomms.set_comms(comms) + yield comms + svcomms.set_comms(old) else: monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) - yield comms + yield comms + + +@pytest.fixture +def mock_unset_supervisor_comms(monkeypatch): + # for back-compat + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS + + if not AIRFLOW_V_3_0_PLUS: + yield None + return + + from airflow.sdk.execution_time import comms, task_runner + + if AIRFLOW_V_3_2_PLUS: + svcomms = task_runner.SupervisorComms() + old = svcomms.get_comms() + svcomms.reset_comms() + yield comms + svcomms.set_comms(old) + else: + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None, raising=False) + yield comms @pytest.fixture diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 8b6992d5898b4..6f041f08cb153 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -64,6 +64,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, CreateHITLDetailPayload, DagRunResult, @@ -134,6 +135,7 @@ from airflow.executors.workloads import BundleInfo from airflow.sdk.bases.secrets_backend import BaseSecretsBackend from airflow.sdk.definitions.connection import Connection + from airflow.sdk.execution_time.task_runner import ToTask from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI @@ -1621,7 +1623,7 @@ def start( # type: ignore[override] from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, finalize, run supervisor.comms = InProcessSupervisorComms(supervisor=supervisor) - with set_supervisor_comms(supervisor.comms): + with set_supervisor_comms(supervisor.comms): # type: ignore[arg-type] supervisor.ti = what # type: ignore[assignment] # We avoid calling `task_runner.startup()` because we are already inside a @@ -1721,36 +1723,31 @@ def final_state(self): @contextmanager -def set_supervisor_comms(temp_comms): +def set_supervisor_comms(temp_comms: CommsDecoder[ToTask, ToSupervisor] | None): """ - Temporarily override `supervisor_comms()` in the `task_runner` module. + Temporarily override `SupervisorComms()` in the `task_runner` module. This is used to simulate task-runner ↔ supervisor communication in-process, by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`) in place of the real inter-process communication layer. Some parts of the code (e.g. models.Variable.get) check for the presence - of `task_runner.supervisor_comms()` to determine if the code is running in a Task SDK execution context. + of `task_runner.SupervisorComms()` to determine if the code is running in a Task SDK execution context. This override ensures those code paths behave correctly during in-process tests. """ - from airflow.sdk.execution_time import task_runner - - sentinel = object() - old = getattr(task_runner._SupervisorCommsHolder, "comms", sentinel) + from airflow.sdk.execution_time.task_runner import SupervisorComms - if temp_comms is not None: - task_runner._SupervisorCommsHolder.comms = temp_comms - elif old is not sentinel: - task_runner._SupervisorCommsHolder.comms = None + svcomms = SupervisorComms() + old = svcomms.get_comms() + if temp_comms: + svcomms.set_comms(temp_comms) + else: + svcomms.reset_comms() try: yield finally: - if old is sentinel: - if task_runner._SupervisorCommsHolder.comms is not None: - task_runner._SupervisorCommsHolder.comms = None - else: - task_runner._SupervisorCommsHolder.comms = old + svcomms.set_comms(old) def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: @@ -1928,9 +1925,9 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: # 1. Check for client context (supervisor-comms) try: - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.task_runner import SupervisorComms - if task_runner.is_supervisor_comms_initialized(): + if SupervisorComms().is_initialized(): # Client context: task runner with supervisor-comms return ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS) except (ImportError, AttributeError): diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index cfb1afc01892f..7cbab3bb6853c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -741,41 +741,59 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # - By defining this as a static class with accessors, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. # Not perfect but getter than a global variable. -class _SupervisorCommsHolder: - comms: CommsDecoder[ToTask, ToSupervisor] | None = None +class _UnsetComms(CommsDecoder[ToTask, ToSupervisor]): + def __init__(self): + self.id_counter = self.socket = None # type: ignore[assignment] + def send(self, msg: ToSupervisor) -> None: + raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() before using.") -def supervisor_send(msg: ToSupervisor) -> ToTask | None: - """Send a message to the supervisor as convenience for get_supervisor_comms().send().""" - if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") - return _SupervisorCommsHolder.comms.send(msg) + async def asend(self, msg: ToSupervisor) -> None: + raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() before using.") -async def supervisor_asend(msg: ToSupervisor) -> ToTask | None: - """Send a message to the supervisor as convenience for get_supervisor_comms().asend().""" - if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") - return await _SupervisorCommsHolder.comms.asend(msg) +class SupervisorComms: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + cls._instance._comms = _UnsetComms() + return cls._instance + + def get_comms(self) -> CommsDecoder[ToTask, ToSupervisor]: + """Get the global supervisor comms instance.""" + return self._comms + + def set_comms(self, comms: CommsDecoder[ToTask, ToSupervisor]) -> None: + """Set the global supervisor comms instance.""" + self._comms = comms + def reset_comms(self) -> None: + """Reset the global supervisor comms instance to initial state.""" + self._comms = _UnsetComms() -def get_supervisor_comms() -> CommsDecoder[ToTask, ToSupervisor]: - """Get the global supervisor comms instance.""" - if _SupervisorCommsHolder.comms is None: - raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() instead.") - return _SupervisorCommsHolder.comms + def is_initialized(self) -> bool: + """Check if the global supervisor comms instance is initialized.""" + return type(self._comms) is not _UnsetComms + def send(self, msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor.""" + return self._comms.send(msg) -def set_supervisor_comms(comms: CommsDecoder[ToTask, ToSupervisor]) -> None: - """Set the global supervisor comms instance.""" - if _SupervisorCommsHolder.comms is not None: - raise RuntimeError("Supervisor comms already initialized") - _SupervisorCommsHolder.comms = comms + async def asend(self, msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor asynchronously.""" + return await self._comms.asend(msg) -def is_supervisor_comms_initialized() -> bool: - """Check if the global supervisor comms instance is initialized.""" - return _SupervisorCommsHolder.comms is not None +def supervisor_send(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as convenience for SupervisorComms().send().""" + return SupervisorComms().send(msg) + + +async def supervisor_asend(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as convenience for SupervisorComms().asend().""" + return await SupervisorComms().asend(msg) # State machine! @@ -840,7 +858,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.debug("Using serialized startup message from environment", msg=msg) else: # normal entry point - msg = get_supervisor_comms()._get_response() # type: ignore[assignment] + msg = SupervisorComms().get_comms()._get_response() # type: ignore[assignment] if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -872,7 +890,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() - os.set_inheritable(get_supervisor_comms().socket.fileno(), True) + os.set_inheritable(SupervisorComms().get_comms().socket.fileno(), True) # Import main directly from the module instead of re-executing the file. # This ensures that when other parts modules import @@ -1713,7 +1731,7 @@ def finalize( def main(): log = structlog.get_logger(logger_name="task") - set_supervisor_comms(CommsDecoder[ToTask, ToSupervisor](log=log)) + SupervisorComms().set_comms(CommsDecoder[ToTask, ToSupervisor](log=log)) Stats.initialize( is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"), @@ -1739,9 +1757,10 @@ def main(): finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates. - if is_supervisor_comms_initialized() and get_supervisor_comms().socket: + svcomms = SupervisorComms() + if svcomms.is_initialized() and svcomms.get_comms().socket: with suppress(Exception): - get_supervisor_comms().socket.close() + svcomms.get_comms().socket.close() def reinit_supervisor_comms() -> None: @@ -1754,12 +1773,14 @@ def reinit_supervisor_comms() -> None: """ import socket - if not is_supervisor_comms_initialized(): + if not SupervisorComms().is_initialized(): log = structlog.get_logger(logger_name="task") fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) - set_supervisor_comms(CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd))) + SupervisorComms().set_comms( + CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) + ) logs = supervisor_send(ResendLoggingFD()) if isinstance(logs, SentFDs): diff --git a/task-sdk/tests/task_sdk/execution_time/test_secrets.py b/task-sdk/tests/task_sdk/execution_time/test_secrets.py index fbddcffa55f10..7b12e74d8c5d0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_secrets.py +++ b/task-sdk/tests/task_sdk/execution_time/test_secrets.py @@ -191,32 +191,21 @@ def test_client_context_with_supervisor_comms(self, mock_supervisor_comms): assert "ExecutionAPISecretsBackend" in backend_classes assert "MetastoreBackend" not in backend_classes - def test_server_context_with_env_var(self, monkeypatch): + def test_server_context_with_env_var(self, monkeypatch, mock_unset_supervisor_comms): """Server context: env var set → uses server chain.""" - import sys - from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded monkeypatch.setenv("_AIRFLOW_PROCESS_CONTEXT", "server") - # Ensure supervisor-comms is not available - if "airflow.sdk.execution_time.task_runner" in sys.modules: - monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") backends = ensure_secrets_backend_loaded() backend_classes = [type(b).__name__ for b in backends] assert "MetastoreBackend" in backend_classes assert "ExecutionAPISecretsBackend" not in backend_classes - def test_fallback_context_no_markers(self, monkeypatch): + def test_fallback_context_no_markers(self, monkeypatch, mock_unset_supervisor_comms): """Fallback context: no supervisor-comms, no env var → only env vars + external.""" - import sys - from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - # Ensure no supervisor-comms - if "airflow.sdk.execution_time.task_runner" in sys.modules: - monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") - # Ensure no env var monkeypatch.delenv("_AIRFLOW_PROCESS_CONTEXT", raising=False) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 438de3c22591e..72945e67bf0df 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -293,6 +293,14 @@ def subprocess_main(): instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 78901) time_machine.move_to(instant, tick=False) + # Depending on if another test made the env dirty we need to suppress the fork() warning + import warnings + + warnings.filterwarnings( + "ignore", + message="This process .* is multi-threaded, use of fork.* may lead to deadlocks in the child.", + ) + proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, @@ -2636,40 +2644,24 @@ class TestSetSupervisorComms: class DummyComms: pass - @pytest.fixture(autouse=True) - def cleanup_supervisor_comms(self): - # Ensure clean state before/after test - if task_runner._SupervisorCommsHolder.comms: - task_runner._SupervisorCommsHolder.comms = None - yield - if task_runner._SupervisorCommsHolder.comms: - task_runner._SupervisorCommsHolder.comms = None - - def test_set_supervisor_comms_overrides_and_restores(self): - task_runner._SupervisorCommsHolder.comms = self.DummyComms() - original = task_runner._SupervisorCommsHolder.comms + def test_set_supervisor_comms_overrides_and_restores(self, mock_unset_supervisor_comms): + svcomm = task_runner.SupervisorComms() + original = svcomm.get_comms() replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner._SupervisorCommsHolder.comms is replacement - assert task_runner._SupervisorCommsHolder.comms is original + assert svcomm.get_comms() is replacement + assert svcomm.get_comms() is original - def test_set_supervisor_comms_sets_temporarily_when_not_set(self): - assert task_runner._SupervisorCommsHolder.comms is None + def test_set_supervisor_comms_sets_temporarily_when_not_set(self, mock_unset_supervisor_comms): + svcomm = task_runner.SupervisorComms() + assert svcomm.is_initialized() is False replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner._SupervisorCommsHolder.comms is replacement - assert task_runner._SupervisorCommsHolder.comms is None - - def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): - assert task_runner._SupervisorCommsHolder.comms is None - - # This will delete an attribute that isn't set, and restore it likewise - with set_supervisor_comms(None): - assert task_runner._SupervisorCommsHolder.comms is None - - assert task_runner._SupervisorCommsHolder.comms is None + assert svcomm.get_comms() is replacement + assert svcomm.is_initialized() is True + assert svcomm.is_initialized() is False class TestInProcessTestSupervisor: From 21d73e207112283c513ab0f257c3af62c8207df4 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 11 Jan 2026 12:11:32 +0100 Subject: [PATCH 07/10] Rework has_execution_context() to use env _AIRFLOW_PROCESS_CONTEXT --- airflow-core/src/airflow/models/base.py | 19 +++++-------------- airflow-core/src/airflow/models/connection.py | 6 +++--- airflow-core/src/airflow/models/variable.py | 10 +++++----- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index ca6b146ee014c..154c9a4be7324 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -17,7 +17,7 @@ # under the License. from __future__ import annotations -import sys +import os from typing import TYPE_CHECKING, Any from sqlalchemy import Integer, MetaData, String, text @@ -101,16 +101,7 @@ class TaskInstanceDependencies(Base): map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) -def has_execution_context() -> bool: - """Check if we are in an execution context (Task, Dag Parse or Triggerer perhaps).""" - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - return ( - hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SupervisorComms") - and sys.modules["airflow.sdk.execution_time.task_runner"].SupervisorComms._comms - is not sys.modules["airflow.sdk.execution_time.task_runner"]._UnsetComms - ) +def is_client_process_context() -> bool: + """Check if we are in an execution context (Task, Dag Parser or Triggerer perhaps).""" + process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower() + return process_context == "client" diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 4ce9bb0791404..cc2a56931e1e7 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -33,7 +33,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.models.base import ID_LEN, Base, has_execution_context +from airflow.models.base import ID_LEN, Base, is_client_process_context from airflow.models.crypto import get_fernet from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -499,7 +499,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) :param team_name: Team name associated to the task trying to access the connection (if any) :return: connection """ - if has_execution_context(): + if is_client_process_context(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -583,7 +583,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if has_execution_context(): + if is_client_process_context(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 641b9804c68c3..8ddd88daa0254 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -29,7 +29,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded -from airflow.models.base import ID_LEN, Base, has_execution_context +from airflow.models.base import ID_LEN, Base, is_client_process_context from airflow.models.crypto import get_fernet from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin @@ -147,7 +147,7 @@ def get( :param deserialize_json: Deserialize the value to a Python dict :param team_name: Team name associated to the task trying to access the variable (if any) """ - if has_execution_context(): + if is_client_process_context(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -201,7 +201,7 @@ def set( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - if has_execution_context(): + if is_client_process_context(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -326,7 +326,7 @@ def update( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - if has_execution_context(): + if is_client_process_context(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -386,7 +386,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non :param team_name: Team name associated to the task trying to delete the variable (if any) :param session: optional session, use if provided or create a new one """ - if has_execution_context(): + if is_client_process_context(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", From 6522fb7d836622caf40fb169b047b587bb77658e Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 11 Jan 2026 12:18:15 +0100 Subject: [PATCH 08/10] Fix code comments --- task-sdk/src/airflow/sdk/execution_time/task_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 7cbab3bb6853c..baba0e201bfbc 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -733,12 +733,12 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # This global class will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # -# Why it needs to be a global: +# Why it needs to be a global singleton class: # - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests # to the parent process during task execution. # - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the # deeply nested execution stack. -# - By defining this as a static class with accessors, it ensures that this communication mechanism is readily +# - By defining this as a singleton class with accessors, it ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. # Not perfect but getter than a global variable. class _UnsetComms(CommsDecoder[ToTask, ToSupervisor]): From 4e940bec459a1c71b4ccf330efe06d3af13ac3b9 Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 11 Jan 2026 13:33:22 +0100 Subject: [PATCH 09/10] Fix references in Airflow core --- airflow-core/src/airflow/dag_processing/processor.py | 7 ++++--- .../src/airflow/jobs/triggerer_job_runner.py | 4 ++-- .../src/airflow/sdk/execution_time/task_runner.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 476bd02f69196..627483ef90388 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -183,10 +183,11 @@ def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import comms, task_runner + from airflow.sdk.execution_time.comms import CommsDecoder + from airflow.sdk.execution_time.task_runner import SupervisorComms # Parse DAG file, send JSON back up! - comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + comms_decoder = CommsDecoder[ToDagProcessor, ToManager]( body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) @@ -194,7 +195,7 @@ def _parse_file_entrypoint(): if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - task_runner.set_supervisor_comms(comms_decoder) + SupervisorComms().set_comms(comms_decoder) log = structlog.get_logger(logger_name="task") result = _parse_file(msg, log) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 3b87e168ad9bc..9cb6d4359c133 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -933,7 +933,7 @@ async def init_comms(self): This also sets up the supervisor-comms so that TaskSDK code can work as expected too (but that will need to be wrapped in an ``sync_to_async()`` call) """ - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.task_runner import SupervisorComms # Yes, we read and write to stdin! It's a socket, not a normal stdin. reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) @@ -943,7 +943,7 @@ async def init_comms(self): async_reader=reader, ) - task_runner.set_supervisor_comms(self.comms_decoder) + SupervisorComms().set_comms(self.comms_decoder) msg = await self.comms_decoder._aget_response(expect_id=0) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index baba0e201bfbc..4741de18272d9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -746,10 +746,14 @@ def __init__(self): self.id_counter = self.socket = None # type: ignore[assignment] def send(self, msg: ToSupervisor) -> None: - raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() before using.") + raise RuntimeError( + "Supervisor comms not initialized yet. Call SupervisorComms().set_comms() before using." + ) async def asend(self, msg: ToSupervisor) -> None: - raise RuntimeError("Supervisor comms not initialized yet. Call set_supervisor_comms() before using.") + raise RuntimeError( + "Supervisor comms not initialized yet. Call SupervisorComms().set_comms() before using." + ) class SupervisorComms: @@ -761,11 +765,11 @@ def __new__(cls, *args, **kwargs): cls._instance._comms = _UnsetComms() return cls._instance - def get_comms(self) -> CommsDecoder[ToTask, ToSupervisor]: + def get_comms(self) -> CommsDecoder: """Get the global supervisor comms instance.""" return self._comms - def set_comms(self, comms: CommsDecoder[ToTask, ToSupervisor]) -> None: + def set_comms(self, comms: CommsDecoder) -> None: """Set the global supervisor comms instance.""" self._comms = comms From c103ddf892bb26c18f17d3b7690f13945a5458ca Mon Sep 17 00:00:00 2001 From: Jens Scheffler Date: Sun, 11 Jan 2026 18:59:21 +0100 Subject: [PATCH 10/10] Fix pytest --- airflow-core/tests/unit/models/test_connection.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index eb330ff4b5636..de84f27983956 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import os import re import sys from typing import TYPE_CHECKING @@ -353,6 +354,7 @@ def test_extra_dejson(self): } @mock.patch("airflow.sdk.Connection.get") + @mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"}) def test_get_connection_from_secrets_task_sdk_success(self, mock_get): """Test the get_connection_from_secrets method with Task SDK success path.""" from airflow.sdk import Connection as SDKConnection @@ -361,7 +363,6 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): mock_get.return_value = expected_connection mock_task_runner = mock.MagicMock() - mock_task_runner.SupervisorComms._comms = True with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}): result = Connection.get_connection_from_secrets("test_conn") @@ -370,10 +371,10 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): assert result.conn_type == "test_type" @mock.patch("airflow.sdk.Connection") + @mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"}) def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection): """Test the get_connection_from_secrets method with Task SDK not found path.""" mock_task_runner = mock.MagicMock() - mock_task_runner.SupervisorComms._comms = True mock_task_sdk_connection.get.side_effect = AirflowRuntimeError( error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)