diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 2d083150920ab..20b1f6e0ce657 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -33,6 +33,7 @@
 from collections.abc import Generator
 from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
+from functools import lru_cache
 from http import HTTPStatus
 from socket import socket, socketpair
 from typing import (
@@ -815,6 +816,82 @@ def _check_subprocess_exit(
         return self._exit_code
 
 
+@lru_cache
+def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
+    """
+    Fetch and cache connection for remote logging.
+
+    Args:
+        conn_id: Connection ID to fetch
+        client: API client for making requests
+
+    Returns:
+        Connection object or None if not found
+    """
+    # Since we need to use the API Client directly, we can't use Connection.get as that would try to use
+    # SUPERVISOR_COMMS
+
+    # TODO: Store in the SecretsCache if its enabled - see #48858
+
+    backends = ensure_secrets_backend_loaded()
+    for secrets_backend in backends:
+        try:
+            conn = secrets_backend.get_connection(conn_id=conn_id)
+            if conn:
+                return conn
+        except Exception:
+            log.exception(
+                "Unable to retrieve connection from secrets backend (%s). "
+                "Checking subsequent secrets backend.",
+                type(secrets_backend).__name__,
+            )
+
+    conn = client.connections.get(conn_id)
+    if isinstance(conn, ConnectionResponse):
+        conn_result = ConnectionResult.from_conn_response(conn)
+        from airflow.sdk.definitions.connection import Connection
+
+        return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
+    return None
+
+
+@contextlib.contextmanager
+def _remote_logging_conn(client: Client):
+    """
+    Pre-fetch the needed remote logging connection with caching.
+
+    If a remote logger is in use, and has the logging/remote_logging option set, we try to fetch the
+    connection it needs, now, directly from the API client, and store it in an env var, so that when the logging
+    hook tries to get the connection it can find it easily from the env vars.
+
+    This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
+    supervisor process when this is needed, so that doesn't exist yet.
+
+    This function uses @lru_cache for connection caching to avoid repeated API calls.
+    """
+    from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler
+
+    if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
+        # Nothing to do
+        yield
+        return
+
+    # Use cached connection fetcher
+    conn = _get_remote_logging_conn(conn_id, client)
+
+    if conn:
+        key = f"AIRFLOW_CONN_{conn_id.upper()}"
+        old = os.getenv(key)
+        os.environ[key] = conn.get_uri()
+        try:
+            yield
+        finally:
+            if old is None:
+                del os.environ[key]
+            else:
+                os.environ[key] = old
+
+
 @attrs.define(kw_only=True)
 class ActivitySubprocess(WatchedSubprocess):
     client: Client
@@ -931,7 +1008,8 @@ def _upload_logs(self):
         """
         from airflow.sdk.log import upload_to_remote
 
-        upload_to_remote(self.process_log, self.ti)
+        with _remote_logging_conn(self.client):
+            upload_to_remote(self.process_log, self.ti)
 
     def _monitor_subprocess(self):
         """
@@ -1637,68 +1715,6 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
     return backends
 
 
-@contextlib.contextmanager
-def _remote_logging_conn(client: Client):
-    """
-    Pre-fetch the needed remote logging connection.
-
-    If a remote logger is in use, and has the logging/remote_logging option set, we try to fetch the
-    connection it needs, now, directly from the API client, and store it in an env var, so that when the logging
-    hook tries to get the connection it
-    can find it easily from the env vars
-
-    This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
-    supervisor process when this is needed, so that doesn't exist yet.
-    """
-    from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler
-
-    if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
-        # Nothing to do
-        yield
-        return
-
-    # Since we need to use the API Client directly, we can't use Connection.get as that would try to use
-    # SUPERVISOR_COMMS
-
-    # TODO: Store in the SecretsCache if its enabled - see #48858
-
-    def _get_conn() -> Connection | None:
-        backends = ensure_secrets_backend_loaded()
-        for secrets_backend in backends:
-            try:
-                conn = secrets_backend.get_connection(conn_id=conn_id)
-                if conn:
-                    return conn
-            except Exception:
-                log.exception(
-                    "Unable to retrieve connection from secrets backend (%s). "
-                    "Checking subsequent secrets backend.",
-                    type(secrets_backend).__name__,
-                )
-
-        conn = client.connections.get(conn_id)
-        if isinstance(conn, ConnectionResponse):
-            conn_result = ConnectionResult.from_conn_response(conn)
-            from airflow.sdk.definitions.connection import Connection
-
-            return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
-        return None
-
-    if conn := _get_conn():
-        key = f"AIRFLOW_CONN_{conn_id.upper()}"
-        old = os.getenv(key)
-
-        os.environ[key] = conn.get_uri()
-
-        try:
-            yield
-        finally:
-            if old is None:
-                del os.environ[key]
-            else:
-                os.environ[key] = old
-
-
 def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]:
     # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
     # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 196494c258ddf..ace3bb5c35bff 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2176,7 +2176,7 @@ def _handle_request(self, msg, log, req_id):
         pytest.param(False, "", "", id="no-remote-logging"),
     ),
 )
-def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch):
+def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch, mocker):
     # This doesn't strictly need the AWS provider, but it does need something that
     # airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about
     pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed")
@@ -2195,6 +2195,9 @@ def handle_request(request: httpx.Request) -> httpx.Response:
         },
     )
 
+    mock_masker = mocker.Mock()
+    mocker.patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=mock_masker)
+
     with conf_vars(
         {
             ("logging", "remote_logging"): str(remote_logging),
@@ -2211,3 +2214,29 @@ def handle_request(request: httpx.Request) -> httpx.Response:
         assert new_keys == {expected_env}
     else:
         assert not new_keys
+
+    if remote_logging and expected_env:
+        connection_available = {"available": False, "conn_uri": None}
+
+        def mock_upload_to_remote(process_log, ti):
+            connection_available["available"] = expected_env in os.environ
+            connection_available["conn_uri"] = os.environ.get(expected_env)
+
+        mocker.patch("airflow.sdk.log.upload_to_remote", side_effect=mock_upload_to_remote)
+
+        activity_subprocess = ActivitySubprocess(
+            process_log=mocker.MagicMock(),
+            id=TI_ID,
+            pid=12345,
+            stdin=mocker.MagicMock(),
+            client=client,
+            process=mocker.MagicMock(),
+        )
+        activity_subprocess.ti = mocker.MagicMock()
+
+        activity_subprocess._upload_logs()
+
+        assert connection_available["available"], (
+            f"Connection {expected_env} was not available during upload_to_remote call"
+        )
+        assert connection_available["conn_uri"] is not None, "Connection URI was None during upload"
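
For reviewers, here is a minimal standalone sketch of the pattern the supervisor change relies on: memoising an expensive connection lookup with functools.lru_cache and temporarily exposing the result as an AIRFLOW_CONN_<ID> environment variable around the upload, restoring the previous value afterwards. The helper names (fetch_conn_uri, expose_conn_as_env) and the hard-coded URI are illustrative stand-ins, not Airflow APIs.

```python
from __future__ import annotations

import os
from contextlib import contextmanager
from functools import lru_cache


@lru_cache
def fetch_conn_uri(conn_id: str) -> str | None:
    # Stand-in for the secrets-backend/API lookup done by _get_remote_logging_conn;
    # lru_cache means repeated calls with the same conn_id return the cached result.
    print(f"fetching {conn_id!r} (printed only once per conn_id)")
    return "aws://access_key:secret@"


@contextmanager
def expose_conn_as_env(conn_id: str):
    """Publish the connection as AIRFLOW_CONN_<ID> only for the duration of the block."""
    uri = fetch_conn_uri(conn_id)
    if not uri:
        yield
        return
    key = f"AIRFLOW_CONN_{conn_id.upper()}"
    old = os.getenv(key)
    os.environ[key] = uri
    try:
        yield
    finally:
        # Restore whatever was there before so nothing leaks outside the block.
        if old is None:
            del os.environ[key]
        else:
            os.environ[key] = old


if __name__ == "__main__":
    for _ in range(2):  # the second pass hits the lru_cache, so "fetching" prints once
        with expose_conn_as_env("aws_default"):
            assert os.getenv("AIRFLOW_CONN_AWS_DEFAULT") is not None
    # Assumes the variable was not already set before this script ran.
    assert os.getenv("AIRFLOW_CONN_AWS_DEFAULT") is None
```

The try/finally restoration mirrors the patch's _remote_logging_conn context manager: the connection URI is only visible in the environment while upload_to_remote runs, so the supervisor process's environment is left unchanged afterwards.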