diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 0e7b7f54cd1f5..90ab13305bc72 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -1320,15 +1320,22 @@ def set_supervisor_comms(temp_comms): """ from airflow.sdk.execution_time import task_runner - old = getattr(task_runner, "SUPERVISOR_COMMS", None) - task_runner.SUPERVISOR_COMMS = temp_comms + sentinel = object() + old = getattr(task_runner, "SUPERVISOR_COMMS", sentinel) + + if temp_comms is not None: + task_runner.SUPERVISOR_COMMS = temp_comms + elif old is not sentinel: + delattr(task_runner, "SUPERVISOR_COMMS") + try: yield finally: - if old is not None: - task_runner.SUPERVISOR_COMMS = old + if old is sentinel: + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") else: - delattr(task_runner, "SUPERVISOR_COMMS") + task_runner.SUPERVISOR_COMMS = old def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult: diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 5690f9b418ec6..1e6aec6bed06d 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -51,6 +51,7 @@ TaskInstanceState, ) from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType +from airflow.sdk.execution_time import task_runner from airflow.sdk.execution_time.comms import ( AssetEventsResult, AssetResult, @@ -92,7 +93,15 @@ XComResult, ) from airflow.sdk.execution_time.secrets_masker import SecretsMasker -from airflow.sdk.execution_time.supervisor import BUFFER_SIZE, ActivitySubprocess, mkpipe, supervise +from airflow.sdk.execution_time.supervisor import ( + BUFFER_SIZE, + ActivitySubprocess, + InProcessSupervisorComms, + InProcessTestSupervisor, + mkpipe, + set_supervisor_comms, + supervise, +) from airflow.sdk.execution_time.task_runner import CommsDecoder from airflow.utils import timezone, timezone as tz @@ -1600,3 +1609,84 @@ def test_handle_requests_api_server_error(self, watched_subprocess, mocker): "message": str(error), "detail": error.response.json(), } + + +class TestSetSupervisorComms: + class DummyComms: + pass + + @pytest.fixture(autouse=True) + def cleanup_supervisor_comms(self): + # Ensure clean state before/after test + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") + yield + if hasattr(task_runner, "SUPERVISOR_COMMS"): + delattr(task_runner, "SUPERVISOR_COMMS") + + def test_set_supervisor_comms_overrides_and_restores(self): + task_runner.SUPERVISOR_COMMS = self.DummyComms() + original = task_runner.SUPERVISOR_COMMS + replacement = self.DummyComms() + + with set_supervisor_comms(replacement): + assert task_runner.SUPERVISOR_COMMS is replacement + assert task_runner.SUPERVISOR_COMMS is original + + def test_set_supervisor_comms_sets_temporarily_when_not_set(self): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + replacement = self.DummyComms() + + with set_supervisor_comms(replacement): + assert task_runner.SUPERVISOR_COMMS is replacement + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + def test_set_supervisor_comms_unsets_temporarily_when_not_set(self): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + # This will delete an attribute that isn't set, and restore it likewise + with set_supervisor_comms(None): + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + assert not hasattr(task_runner, "SUPERVISOR_COMMS") + + +class TestInProcessTestSupervisor: + def test_inprocess_supervisor_comms_roundtrip(self): + """ + Test that InProcessSupervisorComms correctly sends a message to the supervisor, + and that the supervisor's response is received via the message queue. + + This verifies the end-to-end communication flow: + - send_request() dispatches a message to the supervisor + - the supervisor handles the request and appends a response via send_msg() + - get_message() returns the enqueued response + + This test mocks the supervisor's `_handle_request()` method to simulate + a simple echo-style response, avoiding full task execution. + """ + + class MinimalSupervisor(InProcessTestSupervisor): + def _handle_request(self, msg, log): + resp = VariableResult(key=msg.key, value="value") + self.send_msg(resp) + + supervisor = MinimalSupervisor( + id="test", + pid=123, + requests_fd=-1, + process=MagicMock(), + process_log=MagicMock(), + client=MagicMock(), + ) + comms = InProcessSupervisorComms(supervisor=supervisor) + supervisor.comms = comms + + test_msg = GetVariable(key="test_key") + + comms.send_request(log=MagicMock(), msg=test_msg) + + # Ensure we got back what we expect + response = comms.get_message() + assert isinstance(response, VariableResult) + assert response.value == "value"