142 changes: 79 additions & 63 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -33,6 +33,7 @@
from collections.abc import Generator
from contextlib import contextmanager, suppress
from datetime import datetime, timezone
from functools import lru_cache
from http import HTTPStatus
from socket import socket, socketpair
from typing import (
@@ -815,6 +816,82 @@ def _check_subprocess_exit(
return self._exit_code


@lru_cache
def _get_remote_logging_conn(conn_id: str, client: Client) -> Connection | None:
"""
Fetch and cache the connection used for remote logging.

Args:
conn_id: Connection ID to fetch
client: API client for making requests

Returns:
Connection object or None if not found
"""
# Since we need to use the API Client directly, we can't use Connection.get as that would try to use
# SUPERVISOR_COMMS

# TODO: Store in the SecretsCache if it's enabled - see #48858

backends = ensure_secrets_backend_loaded()
for secrets_backend in backends:
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
return conn
except Exception:
log.exception(
"Unable to retrieve connection from secrets backend (%s). "
"Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)

conn = client.connections.get(conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
from airflow.sdk.definitions.connection import Connection

return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
return None
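
Reviewer note: `functools.lru_cache` keys its cache on all arguments, so both `conn_id` and `client` must be hashable, and a different `Client` instance will miss the cache. A minimal sketch of the memoization behaviour this relies on (the `fetch_conn` helper and its call counter are hypothetical, for illustration only):

```python
from functools import lru_cache

calls = {"n": 0}  # hypothetical counter to observe cache misses

@lru_cache
def fetch_conn(conn_id: str) -> str:
    # Stand-in for the API round-trip; the body only runs on a cache miss.
    calls["n"] += 1
    return f"uri-for-{conn_id}"

fetch_conn("aws_default")  # miss: performs the "fetch"
fetch_conn("aws_default")  # hit: served from the cache
assert calls["n"] == 1
print(fetch_conn.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
fetch_conn.cache_clear()  # e.g. in test teardown, to avoid cross-test leakage
```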


@contextlib.contextmanager
def _remote_logging_conn(client: Client):
"""
Pre-fetch the needed remote logging connection with caching.

If a remote logger is in use and the logging/remote_logging option is set, we fetch the connection
it needs now, directly from the API client, and store it in an env var, so that when the logging
hook later looks up the connection it can find it in the environment.

This is needed because BaseHook.get_connection looks for SUPERVISOR_COMMS, which does not exist
yet while we are still in the supervisor process.

Connection fetching is delegated to the @lru_cache-decorated _get_remote_logging_conn, so repeated
calls avoid redundant API requests.
"""
from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler

if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
# Nothing to do
yield
return

# Use cached connection fetcher
conn = _get_remote_logging_conn(conn_id, client)

if conn:
key = f"AIRFLOW_CONN_{conn_id.upper()}"
old = os.getenv(key)
os.environ[key] = conn.get_uri()
try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old
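
The save-and-restore dance above is what keeps the environment clean if the supervisor later runs other code that inspects `AIRFLOW_CONN_*` variables. Distilled into a standalone, general-purpose sketch (the `temp_env` helper is hypothetical, not part of the SDK):

```python
import os
from contextlib import contextmanager

@contextmanager
def temp_env(key: str, value: str):
    """Set an environment variable for the duration of the block, then restore it."""
    old = os.getenv(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if old is None:
            del os.environ[key]
        else:
            os.environ[key] = old

# Usage mirroring the supervisor's pattern (assumes the var is not already set):
with temp_env("AIRFLOW_CONN_MY_S3", "aws://?region_name=eu-west-1"):
    assert os.environ["AIRFLOW_CONN_MY_S3"].startswith("aws://")
assert "AIRFLOW_CONN_MY_S3" not in os.environ  # restored on exit
```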


@attrs.define(kw_only=True)
class ActivitySubprocess(WatchedSubprocess):
client: Client
@@ -931,7 +1008,8 @@ def _upload_logs(self):
"""
from airflow.sdk.log import upload_to_remote

upload_to_remote(self.process_log, self.ti)
with _remote_logging_conn(self.client):
upload_to_remote(self.process_log, self.ti)
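
This is the only call-site change: the upload now runs inside the connection context, so any hook the remote log handler builds can resolve its connection from the environment, and the context manager's `finally` branch restores the variable even if the upload raises. Annotated restatement of the flow (comments are explanatory, not from the diff):

```python
# Inside the supervisor process, where SUPERVISOR_COMMS is not yet available:
with _remote_logging_conn(self.client):          # exports AIRFLOW_CONN_<CONN_ID>
    upload_to_remote(self.process_log, self.ti)  # handler finds the conn via env
# On exit the previous env value is restored, raised exception or not.
```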

def _monitor_subprocess(self):
"""
@@ -1637,68 +1715,6 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
return backends


@contextlib.contextmanager
def _remote_logging_conn(client: Client):
"""
Pre-fetch the needed remote logging connection.

If a remote logger is in use, and has the logging/remote_logging option set, we try to fetch the
connection it needs, now, directly from the API client, and store it in an env var, so that when the logging
hook tries to get the connection it
can find it easily from the env vars

This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the
supervisor process when this is needed, so that doesn't exist yet.
"""
from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler

if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
# Nothing to do
yield
return

# Since we need to use the API Client directly, we can't use Connection.get as that would try to use
# SUPERVISOR_COMMS

# TODO: Store in the SecretsCache if its enabled - see #48858

def _get_conn() -> Connection | None:
backends = ensure_secrets_backend_loaded()
for secrets_backend in backends:
try:
conn = secrets_backend.get_connection(conn_id=conn_id)
if conn:
return conn
except Exception:
log.exception(
"Unable to retrieve connection from secrets backend (%s). "
"Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)

conn = client.connections.get(conn_id)
if isinstance(conn, ConnectionResponse):
conn_result = ConnectionResult.from_conn_response(conn)
from airflow.sdk.definitions.connection import Connection

return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
return None

if conn := _get_conn():
key = f"AIRFLOW_CONN_{conn_id.upper()}"
old = os.getenv(key)

os.environ[key] = conn.get_uri()

try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old


def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]:
# If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
# file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
31 changes: 30 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2176,7 +2176,7 @@ def _handle_request(self, msg, log, req_id):
pytest.param(False, "", "", id="no-remote-logging"),
),
)
def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch):
def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch, mocker):
# This doesn't strictly need the AWS provider, but it does need something that
# airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about
pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed")
@@ -2195,6 +2195,9 @@ def handle_request(request: httpx.Request) -> httpx.Response:
},
)

mock_masker = mocker.Mock()
mocker.patch("airflow.sdk.execution_time.secrets_masker._secrets_masker", return_value=mock_masker)

with conf_vars(
{
("logging", "remote_logging"): str(remote_logging),
@@ -2211,3 +2214,29 @@
assert new_keys == {expected_env}
else:
assert not new_keys

if remote_logging and expected_env:
connection_available = {"available": False, "conn_uri": None}

def mock_upload_to_remote(process_log, ti):
connection_available["available"] = expected_env in os.environ
connection_available["conn_uri"] = os.environ.get(expected_env)

mocker.patch("airflow.sdk.log.upload_to_remote", side_effect=mock_upload_to_remote)

activity_subprocess = ActivitySubprocess(
process_log=mocker.MagicMock(),
id=TI_ID,
pid=12345,
stdin=mocker.MagicMock(),
client=client,
process=mocker.MagicMock(),
)
activity_subprocess.ti = mocker.MagicMock()

activity_subprocess._upload_logs()

assert connection_available["available"], (
f"Connection {expected_env} was not available during upload_to_remote call"
)
assert connection_available["conn_uri"] is not None, "Connection URI was None during upload"
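
The `side_effect` trick above is what lets the test assert environment state *during* the mocked call rather than after it. A self-contained sketch of the same technique (pure pytest + unittest.mock, names hypothetical):

```python
import os
from unittest.mock import Mock

def test_env_visible_during_call(monkeypatch):
    seen = {}

    def record(*args, **kwargs):
        # Runs at call time, so it observes the env var while it is still set.
        seen["uri"] = os.environ.get("AIRFLOW_CONN_MY_S3")

    upload = Mock(side_effect=record)  # stand-in for upload_to_remote
    monkeypatch.setenv("AIRFLOW_CONN_MY_S3", "aws://")
    upload("process_log", "ti")

    upload.assert_called_once()
    assert seen["uri"] == "aws://", "env var must be visible during the call"
```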