diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index 6492e4bab777b..627483ef90388 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -183,10 +183,11 @@ def _parse_file_entrypoint(): import structlog - from airflow.sdk.execution_time import comms, task_runner + from airflow.sdk.execution_time.comms import CommsDecoder + from airflow.sdk.execution_time.task_runner import SupervisorComms # Parse DAG file, send JSON back up! - comms_decoder = comms.CommsDecoder[ToDagProcessor, ToManager]( + comms_decoder = CommsDecoder[ToDagProcessor, ToManager]( body_decoder=TypeAdapter[ToDagProcessor](ToDagProcessor), ) @@ -194,7 +195,7 @@ def _parse_file_entrypoint(): if not isinstance(msg, DagFileParseRequest): raise RuntimeError(f"Required first message to be a DagFileParseRequest, it was {msg}") - task_runner.SUPERVISOR_COMMS = comms_decoder + SupervisorComms().set_comms(comms_decoder) log = structlog.get_logger(logger_name="task") result = _parse_file(msg, log) diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py index 2508e4a281ab7..9cb6d4359c133 100644 --- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py +++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py @@ -930,10 +930,10 @@ async def init_comms(self): """ Set up the communications pipe between this process and the supervisor. - This also sets up the SUPERVISOR_COMMS so that TaskSDK code can work as expected too (but that will + This also sets up the supervisor-comms so that TaskSDK code can work as expected too (but that will need to be wrapped in an ``sync_to_async()`` call) """ - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.task_runner import SupervisorComms # Yes, we read and write to stdin! It's a socket, not a normal stdin. reader, writer = await asyncio.open_connection(sock=socket(fileno=0)) @@ -943,7 +943,7 @@ async def init_comms(self): async_reader=reader, ) - task_runner.SUPERVISOR_COMMS = self.comms_decoder + SupervisorComms().set_comms(self.comms_decoder) msg = await self.comms_decoder._aget_response(expect_id=0) diff --git a/airflow-core/src/airflow/models/base.py b/airflow-core/src/airflow/models/base.py index cc9853330625a..154c9a4be7324 100644 --- a/airflow-core/src/airflow/models/base.py +++ b/airflow-core/src/airflow/models/base.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import os from typing import TYPE_CHECKING, Any from sqlalchemy import Integer, MetaData, String, text @@ -98,3 +99,9 @@ class TaskInstanceDependencies(Base): dag_id: Mapped[str] = mapped_column(StringID(), nullable=False) run_id: Mapped[str] = mapped_column(StringID(), nullable=False) map_index: Mapped[int] = mapped_column(Integer, nullable=False, server_default=text("-1")) + + +def is_client_process_context() -> bool: + """Check whether we are in a client execution context (e.g. task runner, Dag processor or triggerer).""" + process_context = os.environ.get("_AIRFLOW_PROCESS_CONTEXT", "").lower() + return process_context == "client" diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py index 64303f191942a..cc2a56931e1e7 100644 --- a/airflow-core/src/airflow/models/connection.py +++ b/airflow-core/src/airflow/models/connection.py @@ -20,7 +20,6 @@ import json import logging import re -import sys import warnings from contextlib import suppress from json import JSONDecodeError @@ -34,7 +33,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded from airflow.exceptions import AirflowException, AirflowNotFoundException -from airflow.models.base import ID_LEN, Base +from airflow.models.base import ID_LEN, Base, is_client_process_context from airflow.models.crypto import get_fernet from airflow.utils.helpers import prune_dict from airflow.utils.log.logging_mixin import LoggingMixin @@ -500,13 +499,7 @@ def get_connection_from_secrets(cls, conn_id: str, team_name: str | None = None) :param team_name: Team name associated to the task trying to access the connection (if any) :return: connection """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now.
This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): from airflow.sdk import Connection as TaskSDKConnection from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType @@ -590,7 +583,7 @@ def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[s @classmethod def from_json(cls, value, conn_id=None) -> Connection: - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): from airflow.sdk import Connection as TaskSDKConnection warnings.warn( diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py index 65808259fae47..8ddd88daa0254 100644 --- a/airflow-core/src/airflow/models/variable.py +++ b/airflow-core/src/airflow/models/variable.py @@ -20,7 +20,6 @@ import contextlib import json import logging -import sys import warnings from typing import TYPE_CHECKING, Any @@ -30,7 +29,7 @@ from airflow._shared.secrets_masker import mask_secret from airflow.configuration import conf, ensure_secrets_loaded -from airflow.models.base import ID_LEN, Base +from airflow.models.base import ID_LEN, Base, is_client_process_context from airflow.models.crypto import get_fernet from airflow.secrets.metastore import MetastoreBackend from airflow.utils.log.logging_mixin import LoggingMixin @@ -148,13 +147,7 @@ def get( :param deserialize_json: Deserialize the value to a Python dict :param team_name: Team name associated to the task trying to access the variable (if any) """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): warnings.warn( "Using Variable.get from `airflow.models` is deprecated." "Please use `get` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -208,13 +201,7 @@ def set( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means we are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): warnings.warn( "Using Variable.set from `airflow.models` is deprecated." 
"Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead", @@ -339,13 +326,7 @@ def update( :param team_name: Team name associated to the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): warnings.warn( "Using Variable.update from `airflow.models` is deprecated." "Please use `set` on Variable from sdk(`airflow.sdk.Variable`) instead as it is an upsert.", @@ -405,13 +386,7 @@ def delete(key: str, team_name: str | None = None, session: Session | None = Non :param team_name: Team name associated to the task trying to delete the variable (if any) :param session: optional session, use if provided or create a new one """ - # TODO: This is not the best way of having compat, but it's "better than erroring" for now. This still - # means SQLA etc is loaded, but we can't avoid that unless/until we add import shims as a big - # back-compat layer - - # If this is set it means are in some kind of execution context (Task, Dag Parse or Triggerer perhaps) - # and should use the Task SDK API server path - if hasattr(sys.modules.get("airflow.sdk.execution_time.task_runner"), "SUPERVISOR_COMMS"): + if is_client_process_context(): warnings.warn( "Using Variable.delete from `airflow.models` is deprecated." "Please use `delete` on Variable from sdk(`airflow.sdk.Variable`) instead", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index df31c9225c272..ccbd44ff4c205 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -360,7 +360,7 @@ def fn(moment): ... assert "got an unexpected keyword argument 'not_exists_arg'" in str(err) @pytest.mark.asyncio - @patch("airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True) + @patch("airflow.sdk.execution_time.task_runner.SupervisorComms._comms", create=True) async def test_invalid_trigger(self, supervisor_builder): """Test the behaviour when we try to run an invalid Trigger""" workload = workloads.RunTrigger.model_construct( diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py index 1f14c325ea9e5..de84f27983956 100644 --- a/airflow-core/tests/unit/models/test_connection.py +++ b/airflow-core/tests/unit/models/test_connection.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import os import re import sys from typing import TYPE_CHECKING @@ -353,6 +354,7 @@ def test_extra_dejson(self): } @mock.patch("airflow.sdk.Connection.get") + @mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"}) def test_get_connection_from_secrets_task_sdk_success(self, mock_get): """Test the get_connection_from_secrets method with Task SDK success path.""" from airflow.sdk import Connection as SDKConnection @@ -361,7 +363,6 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): mock_get.return_value = expected_connection mock_task_runner = mock.MagicMock() - mock_task_runner.SUPERVISOR_COMMS = True with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}): result = Connection.get_connection_from_secrets("test_conn") @@ -370,10 +371,10 @@ def test_get_connection_from_secrets_task_sdk_success(self, mock_get): assert result.conn_type == "test_type" @mock.patch("airflow.sdk.Connection") + @mock.patch.dict(os.environ, {"_AIRFLOW_PROCESS_CONTEXT": "client"}) def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection): """Test the get_connection_from_secrets method with Task SDK not found path.""" mock_task_runner = mock.MagicMock() - mock_task_runner.SUPERVISOR_COMMS = True mock_task_sdk_connection.get.side_effect = AirflowRuntimeError( error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND) @@ -383,7 +384,6 @@ def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_conn with pytest.raises(AirflowNotFoundException): Connection.get_connection_from_secrets("test_conn") - @mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None}) @mock.patch("airflow.sdk.Connection") @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection") @mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection") diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index 9dd6e782193f1..f129a871859e6 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2227,7 +2227,7 @@ def override_caplog(request): @pytest.fixture def mock_supervisor_comms(monkeypatch): # for back-compat - from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS if not AIRFLOW_V_3_0_PLUS: yield None @@ -2239,13 +2239,42 @@ def mock_supervisor_comms(monkeypatch): # core and TaskSDK is finished if CommsDecoder := getattr(comms, "CommsDecoder", None): comms = mock.create_autospec(CommsDecoder) - monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) else: CommsDecoder = getattr(task_runner, "CommsDecoder") comms = mock.create_autospec(CommsDecoder) comms.send = comms.get_message + + if AIRFLOW_V_3_2_PLUS: + svcomms = task_runner.SupervisorComms() + old = svcomms.get_comms() + svcomms.set_comms(comms) + yield comms + svcomms.set_comms(old) + else: monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", comms, raising=False) - yield comms + yield comms + + +@pytest.fixture +def mock_unset_supervisor_comms(monkeypatch): + # for back-compat + from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS + + if not AIRFLOW_V_3_0_PLUS: + yield None + return + + from airflow.sdk.execution_time import comms, task_runner + + if AIRFLOW_V_3_2_PLUS: + svcomms = 
task_runner.SupervisorComms() + old = svcomms.get_comms() + svcomms.reset_comms() + yield comms + svcomms.set_comms(old) + else: + monkeypatch.setattr(task_runner, "SUPERVISOR_COMMS", None, raising=False) + yield comms @pytest.fixture diff --git a/task-sdk/src/airflow/sdk/bases/xcom.py b/task-sdk/src/airflow/sdk/bases/xcom.py index 81cbfbed478e6..7b76ef85b59b3 100644 --- a/task-sdk/src/airflow/sdk/bases/xcom.py +++ b/task-sdk/src/airflow/sdk/bases/xcom.py @@ -72,18 +72,13 @@ def set( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send value = cls.serialize_value( - value=value, - key=key, - task_id=task_id, - dag_id=dag_id, - run_id=run_id, - map_index=map_index, + value=value, key=key, task_id=task_id, dag_id=dag_id, run_id=run_id, map_index=map_index ) - SUPERVISOR_COMMS.send( + supervisor_send( SetXCom( key=key, value=value, @@ -97,14 +92,7 @@ def set( @classmethod def _set_xcom_in_db( - cls, - key: str, - value: Any, - *, - dag_id: str, - task_id: str, - run_id: str, - map_index: int = -1, + cls, key: str, value: Any, *, dag_id: str, task_id: str, run_id: str, map_index: int = -1 ) -> None: """ Store an XCom value directly in the metadata database. @@ -117,17 +105,10 @@ def _set_xcom_in_db( :param map_index: Optional map index to assign XCom for a mapped task. The default is ``-1`` (set for a non-mapped task). """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - SUPERVISOR_COMMS.send( - SetXCom( - key=key, - value=value, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + supervisor_send( + SetXCom(key=key, value=value, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index), ) @classmethod @@ -160,13 +141,7 @@ def get_value( @classmethod def _get_xcom_db_ref( - cls, - *, - key: str, - dag_id: str, - task_id: str, - run_id: str, - map_index: int | None = None, + cls, *, key: str, dag_id: str, task_id: str, run_id: str, map_index: int | None = None ) -> XComResult: """ Retrieve an XCom value, optionally meeting certain criteria. @@ -190,16 +165,10 @@ def _get_xcom_db_ref( :param key: A key for the XCom. If provided, only XCom with matching keys will be returned. Pass *None* (default) to remove the filter. """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = SUPERVISOR_COMMS.send( - GetXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + msg = supervisor_send( + GetXCom(key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index), ) if not isinstance(msg, XComResult): @@ -243,9 +212,9 @@ def get_one( specified Dag run is returned. If *True*, the latest matching XCom is returned regardless of the run it belongs to. """ - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = SUPERVISOR_COMMS.send( + msg = supervisor_send( GetXCom( key=key, dag_id=dag_id, @@ -299,9 +268,9 @@ def get_all( returned regardless of the run they belong to. :return: List of all XCom values if found. 
""" - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = SUPERVISOR_COMMS.send( + msg = supervisor_send( msg=GetXComSequenceSlice( key=key, dag_id=dag_id, @@ -351,31 +320,14 @@ def purge(cls, xcom: XComResult, *args) -> None: pass @classmethod - def delete( - cls, - key: str, - task_id: str, - dag_id: str, - run_id: str, - map_index: int | None = None, - ) -> None: + def delete(cls, key: str, task_id: str, dag_id: str, run_id: str, map_index: int | None = None) -> None: """Delete an Xcom entry, for custom xcom backends, it gets the path associated with the data on the backend and purges it.""" - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send xcom_result = cls._get_xcom_db_ref( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, + key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index ) cls.purge(xcom_result) - SUPERVISOR_COMMS.send( - DeleteXCom( - key=key, - dag_id=dag_id, - task_id=task_id, - run_id=run_id, - map_index=map_index, - ), + supervisor_send( + DeleteXCom(key=key, dag_id=dag_id, task_id=task_id, run_id=run_id, map_index=map_index) ) diff --git a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py index 26205f0c58335..8facc2df99910 100644 --- a/task-sdk/src/airflow/sdk/definitions/asset/decorators.py +++ b/task-sdk/src/airflow/sdk/definitions/asset/decorators.py @@ -71,10 +71,10 @@ def from_definition(cls, definition: AssetDefinition | MultiAssetDefinition) -> def _iter_kwargs(self, context: Mapping[str, Any]) -> Iterator[tuple[str, Any]]: from airflow.sdk.execution_time.comms import ErrorResponse, GetAssetByName - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send def _fetch_asset(name: str) -> Asset: - resp = SUPERVISOR_COMMS.send(GetAssetByName(name=name)) + resp = supervisor_send(GetAssetByName(name=name)) if resp is None: raise RuntimeError("Empty non-error response received") if isinstance(resp, ErrorResponse): diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 15755e640d97e..7043127e92203 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -368,9 +368,9 @@ def xcom_pull(self, *, key: str = "return_value", default: Any = None) -> Any: def _fetch_dag_run(*, dag_id: str, run_id: str) -> DagRun: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - response = SUPERVISOR_COMMS.send(GetDagRun(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(GetDagRun(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunResult) return response diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py index db5a75e10c18d..63b9947f43218 100644 --- a/task-sdk/src/airflow/sdk/execution_time/context.py +++ b/task-sdk/src/airflow/sdk/execution_time/context.py @@ -286,7 +286,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ from airflow.sdk.execution_time.comms import PutVariable from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend 
from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send # check for write conflicts on the worker for secrets_backend in ensure_secrets_backend_loaded(): @@ -317,7 +317,7 @@ def _set_variable(key: str, value: Any, description: str | None = None, serializ except Exception as e: log.exception(e) - SUPERVISOR_COMMS.send(PutVariable(key=key, value=value, description=description)) + supervisor_send(PutVariable(key=key, value=value, description=description)) # Invalidate cache after setting the variable SecretCache.invalidate_variable(key) @@ -331,9 +331,9 @@ def _delete_variable(key: str) -> None: # keep Task SDK as a separate package than execution time mods. from airflow.sdk.execution_time.cache import SecretCache from airflow.sdk.execution_time.comms import DeleteVariable - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - msg = SUPERVISOR_COMMS.send(DeleteVariable(key=key)) + msg = supervisor_send(DeleteVariable(key=key)) if TYPE_CHECKING: assert isinstance(msg, OKResponse) @@ -458,7 +458,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset GetAssetByUri, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send msg: ToSupervisor if name: @@ -468,7 +468,7 @@ def _get_asset_from_db(name: str | None = None, uri: str | None = None) -> Asset else: raise ValueError("Either name or uri must be provided") - resp = SUPERVISOR_COMMS.send(msg) + resp = supervisor_send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -619,7 +619,7 @@ def _asset_events(self) -> list[AssetEventResult]: GetAssetEventByAssetAlias, ToSupervisor, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send query_dict: dict[str, Any] = { "after": self._after, @@ -635,7 +635,7 @@ def _asset_events(self) -> list[AssetEventResult]: if self._asset_name is None and self._asset_uri is None: raise ValueError("Either asset_name or asset_uri must be provided") msg = GetAssetEventByAsset(name=self._asset_name, uri=self._asset_uri, **query_dict) - resp = SUPERVISOR_COMMS.send(msg) + resp = supervisor_send(msg) if isinstance(resp, ErrorResponse): raise AirflowRuntimeError(resp) @@ -788,7 +788,7 @@ def get_previous_dagrun_success(ti_id: UUID) -> PrevSuccessfulDagRunResponse: PrevSuccessfulDagRunResult, ) - msg = task_runner.SUPERVISOR_COMMS.send(GetPrevSuccessfulDagRun(ti_id=ti_id)) + msg = task_runner.supervisor_send(GetPrevSuccessfulDagRun(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(msg, PrevSuccessfulDagRunResult) diff --git a/task-sdk/src/airflow/sdk/execution_time/hitl.py b/task-sdk/src/airflow/sdk/execution_time/hitl.py index 07f94e63c0529..ce0cc53030fb7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/hitl.py +++ b/task-sdk/src/airflow/sdk/execution_time/hitl.py @@ -46,9 +46,9 @@ def upsert_hitl_detail( params: dict[str, Any] | None = None, assigned_users: list[HITLUser] | None = None, ) -> None: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - SUPERVISOR_COMMS.send( + supervisor_send( msg=CreateHITLDetailPayload( ti_id=ti_id, options=options, @@ 
-71,9 +71,9 @@ def update_hitl_detail_response( chosen_options: list[str], params_input: dict[str, Any], ) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - response = SUPERVISOR_COMMS.send( + response = supervisor_send( msg=UpdateHITLDetail( ti_id=ti_id, chosen_options=chosen_options, @@ -86,9 +86,9 @@ def update_hitl_detail_response( def get_hitl_detail_content_detail(ti_id: UUID) -> HITLDetailResponse: - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send - response = SUPERVISOR_COMMS.send(msg=GetHITLDetailResponse(ti_id=ti_id)) + response = supervisor_send(msg=GetHITLDetailResponse(ti_id=ti_id)) if TYPE_CHECKING: assert isinstance(response, HITLDetailResponse) diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py index 4efb0b71368ca..5639fa596d5e7 100644 --- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py +++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py @@ -90,11 +90,11 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: if self._len is None: from airflow.sdk.execution_time.comms import ErrorResponse, GetXComCount, XComCountResponse - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send task = self._xcom_arg.operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_send( GetXComCount( key=self._xcom_arg.key, dag_id=task.dag_id, @@ -123,13 +123,13 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: XComSequenceIndexResult, XComSequenceSliceResult, ) - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send from airflow.sdk.execution_time.xcom import XCom if isinstance(key, slice): start, stop, step = _coerce_slice(key) source = (xcom_arg := self._xcom_arg).operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_send( GetXComSequenceSlice( key=xcom_arg.key, dag_id=source.dag_id, @@ -150,7 +150,7 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]: raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}") source = (xcom_arg := self._xcom_arg).operator - msg = SUPERVISOR_COMMS.send( + msg = supervisor_send( GetXComSequenceItem( key=xcom_arg.key, dag_id=source.dag_id, diff --git a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py index a44b23d06dc6d..0140671083196 100644 --- a/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py +++ b/task-sdk/src/airflow/sdk/execution_time/secrets/execution_api.py @@ -30,14 +30,14 @@ class ExecutionAPISecretsBackend(BaseSecretsBackend): """ Secrets backend for client contexts (workers, DAG processors, triggerers). - Routes connection and variable requests through SUPERVISOR_COMMS to the + Routes connection and variable requests through supervisor-comms to the Execution API server. This backend should only be registered in client processes, not in API server/scheduler processes. """ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None: """ - Get connection URI via SUPERVISOR_COMMS. + Get connection URI via supervisor-comms. Not used since we override get_connection directly. 
""" @@ -45,7 +45,7 @@ def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | No def get_connection(self, conn_id: str, team_name: str | None = None) -> Connection | None: # type: ignore[override] """ - Return connection object by routing through SUPERVISOR_COMMS. + Return connection object by routing through supervisor-comms. :param conn_id: connection id :param team_name: Name of the team associated to the task trying to access the connection. @@ -54,10 +54,10 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send try: - msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id)) + msg = supervisor_send(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -86,13 +86,13 @@ def get_connection(self, conn_id: str, team_name: str | None = None) -> Connecti # Fall through to the general exception handler for other RuntimeErrors return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None def get_variable(self, key: str, team_name: str | None = None) -> str | None: """ - Return variable value by routing through SUPERVISOR_COMMS. + Return variable value by routing through supervisor-comms. :param key: Variable key :param team_name: Name of the team associated to the task trying to access the variable. @@ -100,10 +100,10 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_send try: - msg = SUPERVISOR_COMMS.send(GetVariable(key=key)) + msg = supervisor_send(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred @@ -114,23 +114,23 @@ def get_variable(self, key: str, team_name: str | None = None) -> str | None: return msg.value # Already a string | None return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None async def aget_connection(self, conn_id: str) -> Connection | None: # type: ignore[override] """ - Return connection object asynchronously via SUPERVISOR_COMMS. + Return connection object asynchronously via supervisor-comms. 
:param conn_id: connection id :return: Connection object or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection from airflow.sdk.execution_time.context import _process_connection_result_conn - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_asend try: - msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id)) + msg = await supervisor_asend(GetConnection(conn_id=conn_id)) if isinstance(msg, ErrorResponse): # Connection not found or error occurred @@ -139,22 +139,22 @@ async def aget_connection(self, conn_id: str) -> Connection | None: # type: ign # Convert ExecutionAPI response to SDK Connection return _process_connection_result_conn(msg) except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None async def aget_variable(self, key: str) -> str | None: """ - Return variable value asynchronously via SUPERVISOR_COMMS. + Return variable value asynchronously via supervisor-comms. :param key: Variable key :return: Variable value or None if not found """ from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable, VariableResult - from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS + from airflow.sdk.execution_time.task_runner import supervisor_asend try: - msg = await SUPERVISOR_COMMS.asend(GetVariable(key=key)) + msg = await supervisor_asend(GetVariable(key=key)) if isinstance(msg, ErrorResponse): # Variable not found or error occurred @@ -165,6 +165,6 @@ async def aget_variable(self, key: str) -> str | None: return msg.value # Already a string | None return None except Exception: - # If SUPERVISOR_COMMS fails for any reason, return None + # If supervisor-comms fails for any reason, return None # to allow fallback to other backends return None diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index b87131aa7336d..6f041f08cb153 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -64,6 +64,7 @@ from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, + CommsDecoder, ConnectionResult, CreateHITLDetailPayload, DagRunResult, @@ -134,6 +135,7 @@ from airflow.executors.workloads import BundleInfo from airflow.sdk.bases.secrets_backend import BaseSecretsBackend from airflow.sdk.definitions.connection import Connection + from airflow.sdk.execution_time.task_runner import ToTask from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI @@ -845,7 +847,7 @@ def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | Non Connection object or None if not found. """ # Since we need to use the API Client directly, we can't use Connection.get as that would try to use - # SUPERVISOR_COMMS + # supervisor-comms # TODO: Store in the SecretsCache if its enabled - see #48858 @@ -888,7 +890,7 @@ def _remote_logging_conn(client: Client): connection it needs, now, directly from the API client, and store it in an env var, so that when the logging hook tries to get the connection it can find it easily from the env vars. 
- This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the + This is needed as the BaseHook.get_connection looks for supervisor-comms, but we are still in the supervisor process when this is needed, so that doesn't exist yet. The connection details are fetched eagerly on every invocation to avoid retaining @@ -1621,7 +1623,7 @@ def start( # type: ignore[override] from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, finalize, run supervisor.comms = InProcessSupervisorComms(supervisor=supervisor) - with set_supervisor_comms(supervisor.comms): + with set_supervisor_comms(supervisor.comms): # type: ignore[arg-type] supervisor.ti = what # type: ignore[assignment] # We avoid calling `task_runner.startup()` because we are already inside a @@ -1689,7 +1691,7 @@ def run_trigger_in_process(cls, *, trigger, ti): Run a trigger in-process for testing, similar to how we run tasks. This creates a minimal supervisor instance specifically for trigger execution - and ensures the trigger has access to SUPERVISOR_COMMS for connection access. + and ensures the trigger has access to supervisor-comms for connection access. """ # Create a minimal supervisor instance for trigger execution supervisor = cls( @@ -1721,36 +1723,31 @@ def final_state(self): @contextmanager -def set_supervisor_comms(temp_comms): +def set_supervisor_comms(temp_comms: CommsDecoder[ToTask, ToSupervisor] | None): """ - Temporarily override `SUPERVISOR_COMMS` in the `task_runner` module. + Temporarily override the comms instance held by `SupervisorComms` in the `task_runner` module. This is used to simulate task-runner ↔ supervisor communication in-process, by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`) in place of the real inter-process communication layer. Some parts of the code (e.g. models.Variable.get) check for the presence - of `task_runner.SUPERVISOR_COMMS` to determine if the code is running in a Task SDK execution context. + of an initialized `SupervisorComms()` to determine if the code is running in a Task SDK execution context. This override ensures those code paths behave correctly during in-process tests. """ - from airflow.sdk.execution_time import task_runner - - sentinel = object() - old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + from airflow.sdk.execution_time.task_runner import SupervisorComms - if temp_comms is not None: - task_runner.SUPERVISOR_COMMS = temp_comms - elif old is not sentinel: - delattr(task_runner, "SUPERVISOR_COMMS") + svcomms = SupervisorComms() + old = svcomms.get_comms() + if temp_comms is not None: + svcomms.set_comms(temp_comms) + else: + svcomms.reset_comms() try: yield finally: - if old is sentinel: - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") - else: - task_runner.SUPERVISOR_COMMS = old + svcomms.set_comms(old) def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: @@ -1910,13 +1907,13 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: Initialize secrets backend with auto-detected context. Detection strategy: - 1. SUPERVISOR_COMMS exists and is set → client chain (ExecutionAPISecretsBackend) + 1. supervisor-comms exists and is set → client chain (ExecutionAPISecretsBackend) 2. _AIRFLOW_PROCESS_CONTEXT=server env var → server chain (MetastoreBackend) 3.
Neither → fallback chain (only env vars + external backends, no MetastoreBackend) - Client contexts: task runner in worker (has SUPERVISOR_COMMS) + Client contexts: task runner in worker (has supervisor-comms) Server contexts: API server, scheduler (set _AIRFLOW_PROCESS_CONTEXT=server) - Fallback contexts: supervisor, unknown contexts (no SUPERVISOR_COMMS, no env var) + Fallback contexts: supervisor, unknown contexts (no supervisor-comms, no env var) The fallback chain ensures supervisor can use external secrets (AWS Secrets Manager, Vault, etc.) while falling back to API client, without trying MetastoreBackend. @@ -1926,12 +1923,12 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: from airflow.sdk.configuration import ensure_secrets_loaded from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS - # 1. Check for client context (SUPERVISOR_COMMS) + # 1. Check for client context (supervisor-comms) try: - from airflow.sdk.execution_time import task_runner + from airflow.sdk.execution_time.task_runner import SupervisorComms - if hasattr(task_runner, "SUPERVISOR_COMMS") and task_runner.SUPERVISOR_COMMS is not None: - # Client context: task runner with SUPERVISOR_COMMS + if SupervisorComms().is_initialized(): + # Client context: task runner with supervisor-comms return ensure_secrets_loaded(default_backends=DEFAULT_SECRETS_SEARCH_PATH_WORKERS) except (ImportError, AttributeError): pass diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 6057706699889..4741de18272d9 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -453,9 +453,7 @@ def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None: log.debug("Requesting first reschedule date from supervisor") - response = SUPERVISOR_COMMS.send( - msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number) - ) + response = supervisor_send(msg=GetTaskRescheduleStartDate(ti_id=self.id, try_number=first_try_number)) if TYPE_CHECKING: assert isinstance(response, TaskRescheduleStartDate) @@ -477,7 +475,7 @@ def get_previous_dagrun(self, state: str | None = None) -> DagRun | None: if dag_run.logical_date is None: return None - response = SUPERVISOR_COMMS.send( + response = supervisor_send( msg=GetPreviousDagRun(dag_id=self.dag_id, logical_date=dag_run.logical_date, state=state) ) @@ -511,7 +509,7 @@ def get_previous_ti( if effective_logical_date is None and dag_run and dag_run.logical_date: effective_logical_date = dag_run.logical_date - response = SUPERVISOR_COMMS.send( + response = supervisor_send( msg=GetPreviousTI( dag_id=self.dag_id, task_id=self.task_id, @@ -537,7 +535,7 @@ def get_ti_count( states: list[str] | None = None, ) -> int: """Return the number of task instances matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_send( GetTICount( dag_id=dag_id, map_index=map_index, @@ -564,7 +562,7 @@ def get_task_states( run_ids: list[str] | None = None, ) -> dict[str, Any]: """Return the task states matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_send( GetTaskStates( dag_id=dag_id, map_index=map_index, @@ -583,7 +581,7 @@ def get_task_states( @staticmethod def get_task_breadcrumbs(dag_id: str, run_id: str) -> Iterable[dict[str, Any]]: """Return task breadcrumbs for the given dag run.""" - response = 
SUPERVISOR_COMMS.send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(GetTaskBreadcrumbs(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, TaskBreadcrumbsResult) return response.breadcrumbs @@ -596,7 +594,7 @@ def get_dr_count( states: list[str] | None = None, ) -> int: """Return the number of Dag runs matching the given criteria.""" - response = SUPERVISOR_COMMS.send( + response = supervisor_send( GetDRCount( dag_id=dag_id, logical_dates=logical_dates, @@ -613,7 +611,7 @@ def get_dr_count( @staticmethod def get_dagrun_state(dag_id: str, run_id: str) -> str: """Return the state of the Dag run with the given Run ID.""" - response = SUPERVISOR_COMMS.send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) + response = supervisor_send(msg=GetDagRunState(dag_id=dag_id, run_id=run_id)) if TYPE_CHECKING: assert isinstance(response, DagRunStateResult) @@ -732,17 +730,74 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: ) -# This global variable will be used by Connection/Variable/XCom classes, or other parts of the task's execution, +# This global class will be used by Connection/Variable/XCom classes, or other parts of the task's execution, # to send requests back to the supervisor process. # -# Why it needs to be a global: +# Why it needs to be a global singleton class: # - Many parts of Airflow's codebase (e.g., connections, variables, and XComs) may rely on making dynamic requests # to the parent process during task execution. # - These calls occur in various locations and cannot easily pass the `CommsDecoder` instance through the # deeply nested execution stack. -# - By defining `SUPERVISOR_COMMS` as a global, it ensures that this communication mechanism is readily +# - Defining this as a singleton class with accessors ensures that this communication mechanism is readily # accessible wherever needed during task execution without modifying every layer of the call stack. -SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] +# Not perfect, but better than a global variable. +class _UnsetComms(CommsDecoder[ToTask, ToSupervisor]): + def __init__(self): + self.id_counter = self.socket = None # type: ignore[assignment] + + def send(self, msg: ToSupervisor) -> None: + raise RuntimeError( + "Supervisor comms not initialized yet. Call SupervisorComms().set_comms() before using." + ) + + async def asend(self, msg: ToSupervisor) -> None: + raise RuntimeError( + "Supervisor comms not initialized yet. Call SupervisorComms().set_comms() before using."
+ ) + + +class SupervisorComms: + _instance = None + + def __new__(cls, *args, **kwargs): + if not cls._instance: + cls._instance = super().__new__(cls) + cls._instance._comms = _UnsetComms() + return cls._instance + + def get_comms(self) -> CommsDecoder: + """Get the global supervisor comms instance.""" + return self._comms + + def set_comms(self, comms: CommsDecoder) -> None: + """Set the global supervisor comms instance.""" + self._comms = comms + + def reset_comms(self) -> None: + """Reset the global supervisor comms instance to initial state.""" + self._comms = _UnsetComms() + + def is_initialized(self) -> bool: + """Check if the global supervisor comms instance is initialized.""" + return type(self._comms) is not _UnsetComms + + def send(self, msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor.""" + return self._comms.send(msg) + + async def asend(self, msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor asynchronously.""" + return await self._comms.asend(msg) + + +def supervisor_send(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as a convenience for SupervisorComms().send().""" + return SupervisorComms().send(msg) + + +async def supervisor_asend(msg: ToSupervisor) -> ToTask | None: + """Send a message to the supervisor as a convenience for SupervisorComms().asend().""" + return await SupervisorComms().asend(msg) # State machine! @@ -807,7 +862,7 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: log.debug("Using serialized startup message from environment", msg=msg) else: # normal entry point - msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] + msg = SupervisorComms().get_comms()._get_response() # type: ignore[assignment] if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") @@ -839,12 +894,12 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" # store startup message in environment for re-exec process os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() - os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) + os.set_inheritable(SupervisorComms().get_comms().socket.fileno(), True) # Import main directly from the module instead of re-executing the file. # This ensures that when other parts modules import # airflow.sdk.execution_time.task_runner, they get the same module instance - # with the properly initialized SUPERVISOR_COMMS global variable. + # with the properly initialized SupervisorComms singleton. # If we re-executed the module with `python -m`, it would load as __main__ and future # imports would get a fresh copy without the initialized globals.
rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" @@ -987,7 +1042,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_fields := _serialize_rendered_fields(ti.task): # so that we do not call the API unnecessarily - SUPERVISOR_COMMS.send(msg=SetRenderedFields(rendered_fields=rendered_fields)) + supervisor_send(msg=SetRenderedFields(rendered_fields=rendered_fields)) # Try to render map_index_template early with available context (will be re-rendered after execution) # This provides a partial label during task execution for templates using pre-execution context @@ -996,7 +1051,7 @@ def _prepare(ti: RuntimeTaskInstance, log: Logger, context: Context) -> ToSuperv if rendered_map_index := _render_map_index(context, ti=ti, log=log): ti.rendered_map_index = rendered_map_index log.debug("Sending early rendered map index", length=len(rendered_map_index)) - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=rendered_map_index)) except Exception: log.debug( "Early rendering of map_index_template failed, will retry after task execution", exc_info=True @@ -1020,7 +1075,7 @@ def _validate_task_inlets_and_outlets(*, ti: RuntimeTaskInstance, log: Logger) - if not ti.task.inlets and not ti.task.outlets: return - inactive_assets_resp = SUPERVISOR_COMMS.send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) + inactive_assets_resp = supervisor_send(msg=ValidateInletsAndOutlets(ti_id=ti.id)) if TYPE_CHECKING: assert isinstance(inactive_assets_resp, InactiveAssetsResult) if inactive_assets := inactive_assets_resp.inactive_assets: @@ -1139,16 +1194,14 @@ def _on_term(signum, frame): ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - SUPERVISOR_COMMS.send( - msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index) - ) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) raise else: # If the task succeeded, render normally to let rendering error bubble up. 
previous_rendered_map_index = ti.rendered_map_index ti.rendered_map_index = _render_map_index(context, ti=ti, log=log) # Send update only if value changed (e.g., user set context variables during execution) if ti.rendered_map_index and ti.rendered_map_index != previous_rendered_map_index: - SUPERVISOR_COMMS.send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) + supervisor_send(msg=SetRenderedMapIndex(rendered_map_index=ti.rendered_map_index)) _push_xcom_if_needed(result, ti, log) @@ -1156,7 +1209,7 @@ def _on_term(signum, frame): except DownstreamTasksSkipped as skip: log.info("Skipping downstream tasks.") tasks_to_skip = skip.tasks if isinstance(skip.tasks, list) else [skip.tasks] - SUPERVISOR_COMMS.send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) + supervisor_send(msg=SkipDownstreamTasks(tasks=tasks_to_skip)) msg, state = _handle_current_task_success(context, ti) except DagRunTriggerException as drte: msg, state = _handle_trigger_dag_run(drte, context, ti, log) @@ -1218,7 +1271,7 @@ def _on_term(signum, frame): error = e finally: if msg: - SUPERVISOR_COMMS.send(msg=msg) + supervisor_send(msg=msg) # Return the message to make unit tests easier too ti.state = state @@ -1282,7 +1335,7 @@ def _handle_trigger_dag_run( ) -> tuple[ToSupervisor, TaskInstanceState]: """Handle exception from TriggerDagRunOperator.""" log.info("Triggering Dag Run.", trigger_dag_id=drte.trigger_dag_id) - comms_msg = SUPERVISOR_COMMS.send( + comms_msg = supervisor_send( TriggerDagRun( dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id, @@ -1349,9 +1402,7 @@ def _handle_trigger_dag_run( ) time.sleep(drte.poke_interval) - comms_msg = SUPERVISOR_COMMS.send( - GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id) - ) + comms_msg = supervisor_send(GetDagRunState(dag_id=drte.trigger_dag_id, run_id=drte.dag_run_id)) if TYPE_CHECKING: assert isinstance(comms_msg, DagRunStateResult) if comms_msg.state in drte.failed_states: @@ -1635,7 +1686,7 @@ def finalize( if getattr(ti.task, "overwrite_rtif_after_execution", False): log.debug("Overwriting Rendered template fields.") if ti.task.template_fields: - SUPERVISOR_COMMS.send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) + supervisor_send(SetRenderedFields(rendered_fields=_serialize_rendered_fields(ti.task))) log.debug("Running finalizers", ti=ti) if state == TaskInstanceState.SUCCESS: @@ -1684,8 +1735,7 @@ def finalize( def main(): log = structlog.get_logger(logger_name="task") - global SUPERVISOR_COMMS - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log) + SupervisorComms().set_comms(CommsDecoder[ToTask, ToSupervisor](log=log)) Stats.initialize( is_statsd_datadog_enabled=conf.getboolean("metrics", "statsd_datadog_enabled"), @@ -1711,9 +1761,10 @@ def main(): finally: # Ensure the request socket is closed on the child side in all circumstances # before the process fully terminates. 
- if SUPERVISOR_COMMS and SUPERVISOR_COMMS.socket: + svcomms = SupervisorComms() + if svcomms.is_initialized() and svcomms.get_comms().socket: with suppress(Exception): - SUPERVISOR_COMMS.socket.close() + svcomms.get_comms().socket.close() def reinit_supervisor_comms() -> None: @@ -1726,15 +1777,16 @@ def reinit_supervisor_comms() -> None: """ import socket - if "SUPERVISOR_COMMS" not in globals(): - global SUPERVISOR_COMMS + if not SupervisorComms().is_initialized(): log = structlog.get_logger(logger_name="task") fd = int(os.environ.get("__AIRFLOW_SUPERVISOR_FD", "0")) - SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) + SupervisorComms().set_comms( + CommsDecoder[ToTask, ToSupervisor](log=log, socket=socket.socket(fileno=fd)) + ) - logs = SUPERVISOR_COMMS.send(ResendLoggingFD()) + logs = supervisor_send(ResendLoggingFD()) if isinstance(logs, SentFDs): from airflow.sdk.log import configure_logging diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py index 9e0835a1efb44..ee2125d0ca8d1 100644 --- a/task-sdk/src/airflow/sdk/log.py +++ b/task-sdk/src/airflow/sdk/log.py @@ -258,13 +258,12 @@ def mask_secret(secret: JsonValue, name: str | None = None) -> None: _secrets_masker().add_mask(secret, name) - with suppress(Exception): + with suppress(Exception, RuntimeError): # Try to tell supervisor (only if in task execution context) from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import MaskSecret - if comms := getattr(task_runner, "SUPERVISOR_COMMS", None): - comms.send(MaskSecret(value=secret, name=name)) + task_runner.supervisor_send(MaskSecret(value=secret, name=name)) def reset_logging(): diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py index a512062ff05cc..1cd7790c10569 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_context.py +++ b/task-sdk/tests/task_sdk/execution_time/test_context.py @@ -891,7 +891,7 @@ def get_connection(self, conn_id: str) -> Connection | None: result = await _async_get_connection("test_conn") assert result == sample_connection - # Should not have tried SUPERVISOR_COMMS since secrets backend had the connection + # Should not have tried supervisor-comms since secrets backend had the connection mock_supervisor_comms.send.assert_not_called() mock_supervisor_comms.asend.assert_not_called() diff --git a/task-sdk/tests/task_sdk/execution_time/test_secrets.py b/task-sdk/tests/task_sdk/execution_time/test_secrets.py index 8f9745b0ffeca..7b12e74d8c5d0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_secrets.py +++ b/task-sdk/tests/task_sdk/execution_time/test_secrets.py @@ -27,7 +27,7 @@ class TestExecutionAPISecretsBackend: """Test ExecutionAPISecretsBackend.""" def test_get_connection_via_supervisor_comms(self, mock_supervisor_comms): - """Test that connection is retrieved via SUPERVISOR_COMMS.""" + """Test that connection is retrieved via supervisor-comms.""" from airflow.sdk.api.datamodels._generated import ConnectionResponse from airflow.sdk.execution_time.comms import ConnectionResult @@ -67,7 +67,7 @@ def test_get_connection_not_found(self, mock_supervisor_comms): mock_supervisor_comms.send.assert_called_once() def test_get_variable_via_supervisor_comms(self, mock_supervisor_comms): - """Test that variable is retrieved via SUPERVISOR_COMMS.""" + """Test that variable is retrieved via supervisor-comms.""" from airflow.sdk.execution_time.comms import VariableResult # 
Mock variable response @@ -125,7 +125,7 @@ def test_runtime_error_triggers_greenback_fallback(self, mocker, mock_supervisor """ Test that RuntimeError from async_to_sync triggers greenback fallback. - This test verifies the fix for issue #57145: when SUPERVISOR_COMMS.send() + This test verifies the fix for issue #57145: when supervisor_send() raises the specific RuntimeError about async_to_sync in an event loop, the backend catches it and uses greenback to call aget_connection(). """ @@ -161,7 +161,7 @@ def greenback_await_side_effect(coro): # Mock aget_connection to return the expected connection directly. # We need to mock this because the real aget_connection would try to - # use SUPERVISOR_COMMS.asend which is not set up for this test. + # use supervisor_asend which is not set up for this test. async def mock_aget_connection(self, conn_id): return expected_conn @@ -183,7 +183,7 @@ class TestContextDetection: """Test context detection in ensure_secrets_backend_loaded.""" def test_client_context_with_supervisor_comms(self, mock_supervisor_comms): - """Client context: SUPERVISOR_COMMS set → uses worker chain.""" + """Client context: supervisor comms set → uses worker chain.""" from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded backends = ensure_secrets_backend_loaded() @@ -191,32 +191,21 @@ def test_client_context_with_supervisor_comms(self, mock_supervisor_comms): assert "ExecutionAPISecretsBackend" in backend_classes assert "MetastoreBackend" not in backend_classes - def test_server_context_with_env_var(self, monkeypatch): + def test_server_context_with_env_var(self, monkeypatch, mock_unset_supervisor_comms): """Server context: env var set → uses server chain.""" - import sys - from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded monkeypatch.setenv("_AIRFLOW_PROCESS_CONTEXT", "server") - # Ensure SUPERVISOR_COMMS is not available - if "airflow.sdk.execution_time.task_runner" in sys.modules: - monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") backends = ensure_secrets_backend_loaded() backend_classes = [type(b).__name__ for b in backends] assert "MetastoreBackend" in backend_classes assert "ExecutionAPISecretsBackend" not in backend_classes - def test_fallback_context_no_markers(self, monkeypatch): - """Fallback context: no SUPERVISOR_COMMS, no env var → only env vars + external.""" - import sys - + def test_fallback_context_no_markers(self, monkeypatch, mock_unset_supervisor_comms): + """Fallback context: no supervisor-comms, no env var → only env vars + external.""" from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded - # Ensure no SUPERVISOR_COMMS - if "airflow.sdk.execution_time.task_runner" in sys.modules: - monkeypatch.delitem(sys.modules, "airflow.sdk.execution_time.task_runner") - # Ensure no env var monkeypatch.delenv("_AIRFLOW_PROCESS_CONTEXT", raising=False) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b45f57cd4256e..72945e67bf0df 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -293,6 +293,14 @@ def subprocess_main(): instant = timezone.datetime(2024, 11, 7, 12, 34, 56, 78901) time_machine.move_to(instant, tick=False) + # Depending on whether another test left the environment dirty, we may need to suppress the fork() warning + import warnings + + warnings.filterwarnings( + "ignore", + message="This 
process .* is multi-threaded, use of fork.* may lead to deadlocks in the child.", + ) + proc = ActivitySubprocess.start( dag_rel_path=os.devnull, bundle_info=FAKE_BUNDLE, @@ -2636,40 +2644,24 @@ class TestSetSupervisorComms: class DummyComms: pass - @pytest.fixture(autouse=True) - def cleanup_supervisor_comms(self): - # Ensure clean state before/after test - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") - yield - if hasattr(task_runner, "SUPERVISOR_COMMS"): - delattr(task_runner, "SUPERVISOR_COMMS") - - def test_set_supervisor_comms_overrides_and_restores(self): - task_runner.SUPERVISOR_COMMS = self.DummyComms() - original = task_runner.SUPERVISOR_COMMS + def test_set_supervisor_comms_overrides_and_restores(self, mock_unset_supervisor_comms): + svcomm = task_runner.SupervisorComms() + original = svcomm.get_comms() replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner.SUPERVISOR_COMMS is replacement - assert task_runner.SUPERVISOR_COMMS is original + assert svcomm.get_comms() is replacement + assert svcomm.get_comms() is original - def test_set_supervisor_comms_sets_temporarily_when_not_set(self): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + def test_set_supervisor_comms_sets_temporarily_when_not_set(self, mock_unset_supervisor_comms): + svcomm = task_runner.SupervisorComms() + assert svcomm.is_initialized() is False replacement = self.DummyComms() with set_supervisor_comms(replacement): - assert task_runner.SUPERVISOR_COMMS is replacement - assert not hasattr(task_runner, "SUPERVISOR_COMMS") - - def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") - - # This will delete an attribute that isn't set, and restore it likewise - with set_supervisor_comms(None): - assert not hasattr(task_runner, "SUPERVISOR_COMMS") - - assert not hasattr(task_runner, "SUPERVISOR_COMMS") + assert svcomm.get_comms() is replacement + assert svcomm.is_initialized() is True + assert svcomm.is_initialized() is False class TestInProcessTestSupervisor: diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 9cf63c654aaf5..bb33e2ffef2fa 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -2563,7 +2563,7 @@ def execute(self, context): assert ( call( - msg=SetRenderedFields( + SetRenderedFields( rendered_fields={ "env_vars": "Truncated. You can change this behaviour in [core]max_templated_field_length. 
\"{'TEST_URL_0': '***', 'TEST_URL_1': '***', 'TEST_URL_10': '***', 'TEST_URL_11': '***', 'TEST_URL_12': '***', 'TEST_URL_13': '***', 'TEST_URL_14': '***', 'TEST_URL_15': '***', 'TEST_URL_16': '***', 'TEST_URL_17': '***', 'TEST_URL_18': '***', 'TEST_URL_19': '***', 'TEST_URL_2': '***', 'TEST_URL_20': '***', 'TEST_URL_21': '***', 'TEST_URL_22': '***', 'TEST_URL_23': '***', 'TEST_URL_24': '***', 'TEST_URL_25': '***', 'TEST_URL_26': '***', 'TEST_URL_27': '***', 'TEST_URL_28': '***', 'TEST_URL_29': '***', 'TEST_URL_3': '***', 'TEST_URL_30': '***', 'TEST_URL_31': '***', 'TEST_URL_32': '***', 'TEST_URL_33': '***', 'TEST_URL_34': '***', 'TEST_URL_35': '***', 'TEST_URL_36': '***', 'TEST_URL_37': '***', 'TEST_URL_38': '***', 'TEST_URL_39': '***', 'TEST_URL_4': '***', 'TEST_URL_40': '***', 'TEST_URL_41': '***', 'TEST_URL_42': '***', 'TEST_URL_43': '***', 'TEST_URL_44': '***', 'TEST_URL_45': '***', 'TEST_URL_46': '***', 'TEST_URL_47': '***', 'TEST_URL_48': '***', 'TEST_URL_49': '***', 'TEST_URL_5': '***', 'TEST_URL_6': '***', 'TEST_URL_7': '***', 'TEST_URL_8': '***', 'TEST_URL_9': '***'}\"... ", "region": "us-west-2", @@ -3073,7 +3073,7 @@ def execute(self, context): assert ( call( - msg=SetRenderedFields( + SetRenderedFields( rendered_fields={"username": "***", "region": "us-west-2"}, type="SetRenderedFields", )