diff --git a/airflow-core/src/airflow/config_templates/airflow_local_settings.py b/airflow-core/src/airflow/config_templates/airflow_local_settings.py
index cb1cc364e2ca6..7c84ba3f3b133 100644
--- a/airflow-core/src/airflow/config_templates/airflow_local_settings.py
+++ b/airflow-core/src/airflow/config_templates/airflow_local_settings.py
@@ -128,6 +128,27 @@
 REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging")
 
 REMOTE_TASK_LOG: RemoteLogIO | None = None
+DEFAULT_REMOTE_CONN_ID: str | None = None
+
+
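+# Copies a hook's default_conn_name (e.g. S3Hook -> "aws_default") into DEFAULT_REMOTE_CONN_ID
+# so the task SDK can fall back to it when [logging] remote_log_conn_id is not set.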
"airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG" logging_class_path = conf.get("logging", "logging_config_class", fallback=fallback) @@ -70,10 +71,11 @@ def load_logging_config() -> tuple[dict[str, Any], str]: f"to: {type(err).__name__}:{err}" ) else: - mod = logging_class_path.rsplit(".", 1)[0] + modpath = logging_class_path.rsplit(".", 1)[0] try: - remote_task_log = import_string(f"{mod}.REMOTE_TASK_LOG") - REMOTE_TASK_LOG = remote_task_log + mod = import_string(modpath) + REMOTE_TASK_LOG = getattr(mod, "REMOTE_TASK_LOG") + DEFAULT_REMOTE_CONN_ID = getattr(mod, "DEFAULT_REMOTE_CONN_ID", None) except Exception as err: log.info("Remote task logs will not be available due to an error: %s", err) diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index d5498b9b0789c..0f85176de5a1c 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -20,6 +20,7 @@ from __future__ import annotations import atexit +import contextlib import io import logging import os @@ -128,6 +129,7 @@ from structlog.typing import FilteringBoundLogger, WrappedLogger from airflow.executors.workloads import BundleInfo + from airflow.sdk.definitions.connection import Connection from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI from airflow.secrets import BaseSecretsBackend from airflow.typing_compat import Self @@ -1630,6 +1632,93 @@ def ensure_secrets_backend_loaded() -> list[BaseSecretsBackend]: return backends +@contextlib.contextmanager +def _remote_logging_conn(client: Client): + """ + Pre-fetch the needed remote logging connection. + + If a remote logger is in use, and has the logging/remote_logging option set, we try to fetch the + connection it needs, now, directly from the API client, and store it in an env var, so that when the logging + hook tries to get the connection it + can find it easily from the env vars + + This is needed as the BaseHook.get_connection looks for SUPERVISOR_COMMS, but we are still in the + supervisor process when this is needed, so that doesn't exist yet. + """ + from airflow.sdk.log import load_remote_conn_id, load_remote_log_handler + + if load_remote_log_handler() is None or not (conn_id := load_remote_conn_id()): + # Nothing to do + yield + return + + # Since we need to use the API Client directly, we can't use Connection.get as that would try to use + # SUPERVISOR_COMMS + + # TODO: Store in the SecretsCache if its enabled - see #48858 + + def _get_conn() -> Connection | None: + backends = ensure_secrets_backend_loaded() + for secrets_backend in backends: + try: + conn = secrets_backend.get_connection(conn_id=conn_id) + if conn: + return conn + except Exception: + log.exception( + "Unable to retrieve connection from secrets backend (%s). 
" + "Checking subsequent secrets backend.", + type(secrets_backend).__name__, + ) + + conn = client.connections.get(conn_id) + if isinstance(conn, ConnectionResponse): + conn_result = ConnectionResult.from_conn_response(conn) + from airflow.sdk.definitions.connection import Connection + + return Connection(**conn_result.model_dump(exclude={"type"}, by_alias=True)) + return None + + if conn := _get_conn(): + key = f"AIRFLOW_CONN_{conn_id.upper()}" + old = os.getenv(key) + + os.environ[key] = conn.get_uri() + + try: + yield + finally: + if old is None: + del os.environ[key] + else: + os.environ[key] = old + + +def _configure_logging(log_path: str, client: Client) -> tuple[FilteringBoundLogger, BinaryIO | TextIO]: + # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the + # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it + # lands on the same node as before. + from airflow.sdk.log import init_log_file, logging_processors + + log_file_descriptor: BinaryIO | TextIO | None = None + + log_file = init_log_file(log_path) + + pretty_logs = False + if pretty_logs: + log_file_descriptor = log_file.open("a", buffering=1) + underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor)) + else: + log_file_descriptor = log_file.open("ab") + underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor)) + + with _remote_logging_conn(client): + processors = logging_processors(enable_pretty_log=pretty_logs)[0] + logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind() + + return logger, log_file_descriptor + + def supervise( *, ti: TaskInstance, @@ -1705,22 +1794,7 @@ def supervise( logger: FilteringBoundLogger | None = None log_file_descriptor: BinaryIO | TextIO | None = None if log_path: - # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the - # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it - # lands on the same node as before. 
+    with _remote_logging_conn(client):
+        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
+        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+
+    return logger, log_file_descriptor
+
+
 def supervise(
     *,
     ti: TaskInstance,
@@ -1705,22 +1794,7 @@
     logger: FilteringBoundLogger | None = None
     log_file_descriptor: BinaryIO | TextIO | None = None
     if log_path:
-        # If we are told to write logs to a file, redirect the task logger to it. Make sure we append to the
-        # file though, otherwise when we resume we would lose the logs from the start->deferral segment if it
-        # lands on the same node as before.
-        from airflow.sdk.log import init_log_file, logging_processors
-
-        log_file = init_log_file(log_path)
-
-        pretty_logs = False
-        if pretty_logs:
-            log_file_descriptor = log_file.open("a", buffering=1)
-            underlying_logger: WrappedLogger = structlog.WriteLogger(cast("TextIO", log_file_descriptor))
-        else:
-            log_file_descriptor = log_file.open("ab")
-            underlying_logger = structlog.BytesLogger(cast("BinaryIO", log_file_descriptor))
-        processors = logging_processors(enable_pretty_log=pretty_logs)[0]
-        logger = structlog.wrap_logger(underlying_logger, processors=processors, logger_name="task").bind()
+        logger, log_file_descriptor = _configure_logging(log_path, client)
 
     backends = ensure_secrets_backend_loaded()
     log.info(
diff --git a/task-sdk/src/airflow/sdk/log.py b/task-sdk/src/airflow/sdk/log.py
index 46efd6bf448ed..ec0c75fa5307d 100644
--- a/task-sdk/src/airflow/sdk/log.py
+++ b/task-sdk/src/airflow/sdk/log.py
@@ -524,6 +524,16 @@
     return airflow.logging_config.REMOTE_TASK_LOG
 
 
+def load_remote_conn_id() -> str | None:
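+    # An explicitly configured [logging] remote_log_conn_id always wins; otherwise fall back
+    # to the provider hook's default captured in DEFAULT_REMOTE_CONN_ID (which may be None).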
"remote_log_conn_id"): remote_conn, + } + ): + env = os.environ.copy() + client = make_client(transport=httpx.MockTransport(handle_request)) + + with _remote_logging_conn(client): + new_keys = os.environ.keys() - env.keys() + if remote_logging: + assert new_keys == {expected_env} + else: + assert not new_keys