diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 7a6ea889c892a..428c50aa29d6a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -33,7 +33,6 @@
 from collections.abc import Callable, Generator
 from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
-from functools import lru_cache
 from http import HTTPStatus
 from socket import socket, socketpair
 from typing import (
@@ -827,8 +826,10 @@ def _check_subprocess_exit(
         return self._exit_code
 
 
-@lru_cache
-def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
+_REMOTE_LOGGING_CONN_CACHE: dict[str, Connection | None] = {}
+
+
+def _fetch_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
     """
     Fetch and cache connection for remote logging.
 
@@ -837,18 +838,22 @@ def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
         client: API client for making requests
 
     Returns:
-        Connection object or None if not found
+        Connection object or None if not found.
     """
     # Since we need to use the API Client directly, we can't use Connection.get as that would try to use
     # SUPERVISOR_COMMS
     # TODO: Store in the SecretsCache if its enabled - see #48858
+    if conn_id in _REMOTE_LOGGING_CONN_CACHE:
+        return _REMOTE_LOGGING_CONN_CACHE[conn_id]
+
     backends = ensure_secrets_backend_loaded()
     for secrets_backend in backends:
         try:
             conn = secrets_backend.get_connection(conn_id=conn_id)
             if conn:
+                _REMOTE_LOGGING_CONN_CACHE[conn_id] = conn
                 return conn
         except Exception:
             log.exception(
@@ -862,8 +867,12 @@ def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
         conn_result = ConnectionResult.from_conn_response(conn)
         from airflow.sdk.definitions.connection import Connection
 
-        return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
-    return None
+        result: Connection | None = Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
+    else:
+        result = None
+
+    _REMOTE_LOGGING_CONN_CACHE[conn_id] = result
+    return result
 
 
 @contextlib.contextmanager
@@ -878,7 +887,8 @@ def _remote_logging_conn(client: Client):
     This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
     supervisor process when this is needed, so that doesn't exist yet.
 
-    This function uses @lru_cache for connection caching to avoid repeated API calls.
+    Connection lookups are cached by conn_id in a module-level dict to avoid repeated
+    API calls; the per-task API client is only used for fetching and is never stored.
""" from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler @@ -887,8 +897,8 @@ def _remote_logging_conn(client: Client): yield return - # Use cached connection fetcher - conn = _get_remote_logging_conn(conn_id, client) + # Fetch connection details on-demand without caching the entire API client instance + conn = _fetch_remote_logging_conn(conn_id, client) if conn: key = f"AIRFLOW_CONN_{conn_id.upper()}" @@ -1899,9 +1909,11 @@ def supervise( if not dag_rel_path: raise ValueError("dag_path is required") + close_client = False if not client: limits = httpx.Limits(max_keepalive_connections=1, max_connections=10) client = Client(base_url=server or "", limits=limits, dry_run=dry_run, token=token) + close_client = True start = time.monotonic() @@ -1920,24 +1932,29 @@ def supervise( reset_secrets_masker() - process = ActivitySubprocess.start( - dag_rel_path=dag_rel_path, - what=ti, - client=client, - logger=logger, - bundle_info=bundle_info, - subprocess_logs_to_stdout=subprocess_logs_to_stdout, - ) + try: + process = ActivitySubprocess.start( + dag_rel_path=dag_rel_path, + what=ti, + client=client, + logger=logger, + bundle_info=bundle_info, + subprocess_logs_to_stdout=subprocess_logs_to_stdout, + ) - exit_code = process.wait() - end = time.monotonic() - log.info( - "Task finished", - task_instance_id=str(ti.id), - exit_code=exit_code, - duration=end - start, - final_state=process.final_state, - ) - if log_path and log_file_descriptor: - log_file_descriptor.close() - return exit_code + exit_code = process.wait() + end = time.monotonic() + log.info( + "Task finished", + task_instance_id=str(ti.id), + exit_code=exit_code, + duration=end - start, + final_state=process.final_state, + ) + return exit_code + finally: + if log_path and log_file_descriptor: + log_file_descriptor.close() + if close_client and client: + with suppress(Exception): + client.close() diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index 49eaac96fa204..b0d0054cbfc52 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -2432,6 +2432,47 @@ def mock_upload_to_remote(process_log, ti): assert connection_available["conn_uri"] is not None, "Connection URI was None during upload" +def test_remote_logging_conn_caches_connection_not_client(monkeypatch): + """Test that connection caching doesn't retain API client references.""" + import gc + import weakref + + from airflow.sdk import log as sdk_log + from airflow.sdk.execution_time import supervisor + + class ExampleBackend: + def __init__(self): + self.calls = 0 + + def get_connection(self, conn_id: str): + self.calls += 1 + from airflow.sdk.definitions.connection import Connection + + return Connection(conn_id=conn_id, conn_type="example") + + backend = ExampleBackend() + monkeypatch.setattr(supervisor, "ensure_secrets_backend_loaded", lambda: [backend]) + monkeypatch.setattr(sdk_log, "load_remote_log_handler", lambda: object()) + monkeypatch.setattr(sdk_log, "load_remote_conn_id", lambda: "test_conn") + monkeypatch.delenv("AIRFLOW_CONN_TEST_CONN", raising=False) + + def noop_request(request: httpx.Request) -> httpx.Response: + return httpx.Response(200) + + clients = [] + for _ in range(3): + client = make_client(transport=httpx.MockTransport(noop_request)) + clients.append(weakref.ref(client)) + with _remote_logging_conn(client): + pass + client.close() + del client + + gc.collect() + assert 
+    assert backend.calls == 1, "Connection should be cached, not fetched multiple times"
+    assert all(ref() is None for ref in clients), "Client instances should be garbage collected"
+
+
 def test_process_log_messages_from_subprocess(monkeypatch, caplog):
     from airflow.sdk._shared.logging.structlog import PER_LOGGER_LEVELS
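
For context on the retention hazard this patch removes: `functools.lru_cache` keys its cache on every argument, so the old `_get_remote_logging_conn(conn_id, client)` kept a strong reference to each distinct per-task `Client` for the life of the supervisor process. The new module-level dict is keyed by `conn_id` alone, so the client participates in the fetch but is never retained. Below is a minimal standalone sketch of that difference; `FakeClient`, `fetch_with_lru`, and `fetch_with_dict` are hypothetical names for illustration, not the patched SDK code.

```python
import gc
import weakref
from functools import lru_cache


class FakeClient:
    """Hypothetical stand-in for the per-task API client."""


@lru_cache  # old pattern: cache key is (conn_id, client), pinning the client
def fetch_with_lru(conn_id: str, client: FakeClient) -> str:
    return f"conn:{conn_id}"


_CONN_CACHE: dict[str, str] = {}  # new pattern: keyed by conn_id only


def fetch_with_dict(conn_id: str, client: FakeClient) -> str:
    # The client is available for the fetch but never stored in the cache.
    if conn_id not in _CONN_CACHE:
        _CONN_CACHE[conn_id] = f"conn:{conn_id}"
    return _CONN_CACHE[conn_id]


def client_survives_gc(fetch) -> bool:
    client = FakeClient()
    fetch("remote_logs", client)
    ref = weakref.ref(client)
    del client
    gc.collect()
    return ref() is not None


print(client_survives_gc(fetch_with_lru))   # True: lru_cache still references the client
print(client_survives_gc(fetch_with_dict))  # False: the client is garbage collected
```

This is the same property the new `test_remote_logging_conn_caches_connection_not_client` asserts via `weakref.ref`: once `_remote_logging_conn` exits and the caller drops the client, nothing in the connection cache should keep it alive.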