From f3a3328690b37bf994f1327caca6fc62741013e5 Mon Sep 17 00:00:00 2001
From: Ash Berlin-Taylor
Date: Thu, 24 Jul 2025 15:02:38 +0100
Subject: [PATCH] Allow Remote logging providers to load connections from the
 API Server

Remote logging is often configured using automatic instance profiles, but not
always. If you tried to configure a logger with a connection defined in the
metadata DB it would not have worked: depending on the hook's behaviour, it
either caused the supervise job to fail early, or behaved as if the connection
didn't exist.

Unfortunately, the default connection ID used by the various hooks is not
easily discoverable from the outside (we can't look at `remote.hook`, as for
most log providers that would try to load the connection, failing in exactly
the way we are trying to fix), so I updated the log config module to keep
track of the default conn id for the modern log providers.

Once we have the connection ID (or at least a good idea that we've got the
right one) we pre-emptively check the secrets backends for it, and if it is
not found there we load it from the API server. Either way, if we find a
connection we put it in the env variable so that it is available.

The reason we use this approach is that we are running in the supervisor
process itself, so SUPERVISOR_COMMS is not, and cannot be, set yet.
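To illustrate the mechanism (a sketch only, not part of the change; "my_aws"
is a made-up conn id, the real one comes from logging/remote_log_conn_id or
the hook's default_conn_name), the supervisor ends up doing the equivalent of:

    import os

    from airflow.sdk.definitions.connection import Connection

    conn = Connection(conn_id="my_aws", conn_type="aws", login="key", password="secret")
    # AIRFLOW_CONN_<CONN_ID> is the standard env-var form that connection
    # lookups already understand, so the hook finds it with no extra plumbing.
    os.environ["AIRFLOW_CONN_MY_AWS"] = conn.get_uri()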
+ if "PYTEST_CURRENT_TEST" in os.environ: + raise + return None + if REMOTE_LOGGING: ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST") @@ -151,6 +172,7 @@ if remote_base_log_folder.startswith("s3://"): from airflow.providers.amazon.aws.log.s3_task_handler import S3RemoteLogIO + _default_conn_name_from("airflow.providers.amazon.aws.hooks.s3", "S3Hook") REMOTE_TASK_LOG = S3RemoteLogIO( **( { @@ -166,6 +188,7 @@ elif remote_base_log_folder.startswith("cloudwatch://"): from airflow.providers.amazon.aws.log.cloudwatch_task_handler import CloudWatchRemoteLogIO + _default_conn_name_from("airflow.providers.amazon.aws.hooks.logs", "AwsLogsHook") url_parts = urlsplit(remote_base_log_folder) REMOTE_TASK_LOG = CloudWatchRemoteLogIO( **( @@ -182,6 +205,7 @@ elif remote_base_log_folder.startswith("gs://"): from airflow.providers.google.cloud.log.gcs_task_handler import GCSRemoteLogIO + _default_conn_name_from("airflow.providers.google.cloud.hooks.gcs", "GCSHook") key_path = conf.get_mandatory_value("logging", "google_key_path", fallback=None) REMOTE_TASK_LOG = GCSRemoteLogIO( @@ -199,6 +223,7 @@ elif remote_base_log_folder.startswith("wasb"): from airflow.providers.microsoft.azure.log.wasb_task_handler import WasbRemoteLogIO + _default_conn_name_from("airflow.providers.microsoft.azure.hooks.wasb", "WasbHook") wasb_log_container = conf.get_mandatory_value( "azure_remote_logging", "remote_wasb_log_container", fallback="airflow-logs" ) @@ -232,6 +257,8 @@ elif remote_base_log_folder.startswith("oss://"): from airflow.providers.alibaba.cloud.log.oss_task_handler import OSSRemoteLogIO + _default_conn_name_from("airflow.providers.alibaba.cloud.hooks.oss", "OSSHook") + REMOTE_TASK_LOG = OSSRemoteLogIO( **( { @@ -246,6 +273,8 @@ elif remote_base_log_folder.startswith("hdfs://"): from airflow.providers.apache.hdfs.log.hdfs_task_handler import HdfsRemoteLogIO + _default_conn_name_from("airflow.providers.apache.hdfs.hooks.webhdfs", "WebHDFSHook") + REMOTE_TASK_LOG = HdfsRemoteLogIO( **( { diff --git a/airflow-core/src/airflow/logging_config.py b/airflow-core/src/airflow/logging_config.py index b0c0b35515599..e6d837bc22077 100644 --- a/airflow-core/src/airflow/logging_config.py +++ b/airflow-core/src/airflow/logging_config.py @@ -33,6 +33,7 @@ REMOTE_TASK_LOG: RemoteLogIO | None +DEFAULT_REMOTE_CONN_ID: str | None = None def __getattr__(name: str): @@ -44,7 +45,7 @@ def __getattr__(name: str): def load_logging_config() -> tuple[dict[str, Any], str]: """Configure & Validate Airflow Logging.""" - global REMOTE_TASK_LOG + global REMOTE_TASK_LOG, DEFAULT_REMOTE_CONN_ID fallback = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG" logging_class_path = conf.get("logging", "logging_config_class", fallback=fallback) @@ -70,10 +71,11 @@ def load_logging_config() -> tuple[dict[str, Any], str]: f"to: {type(err).__name__}:{err}" ) else: - mod = logging_class_path.rsplit(".", 1)[0] + modpath = logging_class_path.rsplit(".", 1)[0] try: - remote_task_log = import_string(f"{mod}.REMOTE_TASK_LOG") - REMOTE_TASK_LOG = remote_task_log + mod = import_string(modpath) + REMOTE_TASK_LOG = getattr(mod, "REMOTE_TASK_LOG") + DEFAULT_REMOTE_CONN_ID = getattr(mod, "DEFAULT_REMOTE_CONN_ID", None) except Exception as err: log.info("Remote task logs will not be available due to an error: %s", err) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d5498b9b0789c..0f85176de5a1c 100644 --- 
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 import atexit
+import contextlib
 import io
 import logging
 import os
@@ -128,6 +129,7 @@
     from structlog.typing import FilteringBoundLogger, WrappedLogger
 
     from airflow.executors.workloads import BundleInfo
+    from airflow.sdk.definitions.connection import Connection
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
     from airflow.secrets import BaseSecretsBackend
     from airflow.typing_compat import Self
@@ -1630,6 +1632,93 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]:
     return backends
 
 
+@contextlib.contextmanager
+def _remote_logging_conn(client: Client):
+    """
+    Pre-fetch the needed remote logging connection.
+
+    If a remote logger is in use and the logging/remote_logging option is set, we try to fetch the
+    connection it needs now, directly from the API client, and store it in an env var so that when
+    the logging hook tries to get the connection it can find it easily from the env vars.
+
+    This is needed because BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in
+    the supervisor process when this is needed, so that doesn't exist yet.
+    """
+    from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler
+
+    if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()):
+        # Nothing to do
+        yield
+        return
+
+    # Since we need to use the API Client directly, we can't use Connection.get as that would try
+    # to use SUPERVISOR_COMMS
+
+    # TODO: Store in the SecretsCache if it's enabled - see #48858
+
+    def _get_conn() -> Connection | None:
+        backends = ensure_secrets_backend_loaded()
+        for secrets_backend in backends:
+            try:
+                conn = secrets_backend.get_connection(conn_id=conn_id)
+                if conn:
+                    return conn
+            except Exception:
+                log.exception(
+                    "Unable to retrieve connection from secrets backend (%s). "
+                    "Checking subsequent secrets backend.",
+                    type(secrets_backend).__name__,
+                )
+
+        conn = client.connections.get(conn_id)
+        if isinstance(conn, ConnectionResponse):
+            conn_result = ConnectionResult.from_conn_response(conn)
+            from airflow.sdk.definitions.connection import Connection
+
+            return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True))
+        return None
+
+    if conn := _get_conn():
+        key = f"AIRFLOW_CONN_{conn_id.upper()}"
+        old = os.getenv(key)
+
+        os.environ[key] = conn.get_uri()
+
+        try:
+            yield
+        finally:
+            if old is None:
+                del os.environ[key]
+            else:
+                os.environ[key] = old
+    else:
+        yield
+
+
+def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]:
+    # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
+    # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
+    # lands on the same node as before.
+    from airflow.sdk.log import init_log_file, logging_processors
+
+    log_file_descriptor: BinaryIO | TextIO | None = None
+
+    log_file = init_log_file(log_path)
+
+    pretty_logs = False
+    if pretty_logs:
+        log_file_descriptor = log_file.open("a", buffering=1)
+        underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
+    else:
+        log_file_descriptor = log_file.open("ab")
+        underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
+
+    with _remote_logging_conn(client):
+        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
+        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+
+    return logger, log_file_descriptor
+
+
 def supervise(
     *,
     ti: TaskInstance,
@@ -1705,22 +1794,7 @@ def supervise(
     logger: FilteringBoundLogger | None = None
     log_file_descriptor: BinaryIO | TextIO | None = None
     if log_path:
-        # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
-        # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
-        # lands on the same node as before.
-        from airflow.sdk.log import init_log_file, logging_processors
-
-        log_file = init_log_file(log_path)
-
-        pretty_logs = False
-        if pretty_logs:
-            log_file_descriptor = log_file.open("a", buffering=1)
-            underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
-        else:
-            log_file_descriptor = log_file.open("ab")
-            underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
-        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
-        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+        logger, log_file_descriptor = _configure_logging(log_path, client)
 
     backends = ensure_secrets_backend_loaded()
     log.info(
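One subtlety in _remote_logging_conn above: a generator decorated with
@contextlib.contextmanager must yield exactly once on every path, including
when no connection is found, otherwise entering the context raises
RuntimeError("generator didn't yield"). A minimal, self-contained sketch of
the same save-and-restore pattern, with illustrative names not taken from the
patch:

    import contextlib
    import os

    @contextlib.contextmanager
    def env_override(key: str, value: str):
        # Set key for the duration of the block, then restore the old value.
        old = os.getenv(key)
        os.environ[key] = value
        try:
            yield
        finally:
            if old is None:
                del os.environ[key]
            else:
                os.environ[key] = old

    with env_override("AIRFLOW_CONN_MY_AWS", "aws://"):
        assert os.environ["AIRFLOW_CONN_MY_AWS"] == "aws://"
    assert "AIRFLOW_CONN_MY_AWS" not in os.environ  # assuming it was unset before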
diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py
index 46efd6bf448ed..ec0c75fa5307d 100644
--- a/task-sdk/src/airflow/sdk/log.py
+++ b/task-sdk/src/airflow/sdk/log.py
@@ -524,6 +524,16 @@ def load_remote_log_handler() -> RemoteLogIO | None:
     return airflow.logging_config.REMOTE_TASK_LOG
 
 
+def load_remote_conn_id() -> str | None:
+    import airflow.logging_config
+    from airflow.configuration import conf
+
+    if conn_id := conf.get("logging", "remote_log_conn_id", fallback=None):
+        return conn_id
+
+    return airflow.logging_config.DEFAULT_REMOTE_CONN_ID
+
+
 def relative_path_from_logger(logger) -> Path | None:
     if not logger:
         return None
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 271f2b98c15d0..0331a7a2b9ca6 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -117,10 +117,13 @@
     ActivitySubprocess,
     InProcessSupervisorComms,
     InProcessTestSupervisor,
+    _remote_logging_conn,
     set_supervisor_comms,
     supervise,
 )
 
+from tests_common.test_utils.config import conf_vars
+
 if TYPE_CHECKING:
     import kgb
 
@@ -2143,3 +2146,49 @@ def _handle_request(self, msg, log, req_id):
     # Ensure we got back what we expect
     assert isinstance(response, VariableResult)
     assert response.value == "value"
+
+
+@pytest.mark.parametrize(
+    ("remote_logging", "remote_conn", "expected_env"),
+    (
+        pytest.param(True, "", "AIRFLOW_CONN_AWS_DEFAULT", id="no-conn-id"),
+        pytest.param(True, "aws_default", "AIRFLOW_CONN_AWS_DEFAULT", id="explicit-default"),
+        pytest.param(True, "my_aws", "AIRFLOW_CONN_MY_AWS", id="other"),
"my_aws", "AIRFLOW_CONN_MY_AWS", id="other"), + pytest.param(False, "", "", id="no-remote-logging"), + ), +) +def test_remote_logging_conn(remote_logging, remote_conn, expected_env, monkeypatch): + # This doesn't strictly need the AWS provider, but it does need something that + # airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG knows about + pytest.importorskip("airflow.providers.amazon", reason="'amazon' provider not installed") + + # This test is a little bit overly specific to how the logging is currently configured :/ + monkeypatch.delitem(sys.modules, "airflow.logging_config") + monkeypatch.delitem(sys.modules, "airflow.config_templates.airflow_local_settings", raising=False) + + def handle_request(request: httpx.Request) -> httpx.Response: + return httpx.Response( + status_code=200, + json={ + # Minimal enough to pass validation, we don't care what fields are in here for the tests + "conn_id": remote_conn, + "conn_type": "aws", + }, + ) + + with conf_vars( + { + ("logging", "remote_logging"): str(remote_logging), + ("logging", "remote_base_log_folder"): "cloudwatch://arn:aws:logs:::log-group:test", + ("logging", "remote_log_conn_id"): remote_conn, + } + ): + env = os.environ.copy() + client = make_client(transport=httpx.MockTransport(handle_request)) + + with _remote_logging_conn(client): + new_keys = os.environ.keys() - env.keys() + if remote_logging: + assert new_keys == {expected_env} + else: + assert not new_keys