15 changes: 7 additions & 8 deletions airflow-core/src/airflow/settings.py
@@ -700,17 +700,16 @@ def initialize():
     configure_adapters()
     # The webservers import this file from models.py with the default settings.
 
-    if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None):
-        is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
-        if not is_worker:
-            configure_orm()
-    configure_action_logging()
-
     # Configure secrets masker before masking secrets
     _configure_secrets_masker()
 
-    # mask the sensitive_config_values
-    conf.mask_secrets()
+    is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1"
+    if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None) and not is_worker:
+        configure_orm()
+
+    # mask the sensitive_config_values
+    conf.mask_secrets()
+    configure_action_logging()
 
     # Run any custom runtime checks that needs to be executed for providers
     run_providers_custom_runtime_checks()
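
The effect of this change is that a re-executed worker process never touches the ORM: the flag is checked in initialize(), and the _AIRFLOW__REEXECUTED_PROCESS variable is set in the child's environment before it starts. A minimal sketch of that hand-off, assuming only the variable name visible in this diff (the child command here is illustrative):

    import os
    import subprocess
    import sys

    # Mark the child as a re-executed worker; settings.initialize() sees the
    # flag and skips configure_orm() in that process.
    env = {**os.environ, "_AIRFLOW__REEXECUTED_PROCESS": "1"}
    subprocess.check_call(
        [sys.executable, "-c", "import os; print(os.environ['_AIRFLOW__REEXECUTED_PROCESS'])"],
        env=env,
    )
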
45 changes: 32 additions & 13 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -694,20 +694,15 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
     # in response to us sending a request.
     log = structlog.get_logger(logger_name="task")
 
-    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"):
+    if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and (
+        msgjson := os.environ.get("_AIRFLOW__STARTUP_MSG")
+    ):
         # Clear any Kerberos replace cache if there is one, so new process can't reuse it.
         os.environ.pop("KRB5CCNAME", None)
         # entrypoint of re-exec process
-        msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"])
 
-        logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
-        if isinstance(logs, SentFDs):
-            from airflow.sdk.log import configure_logging
-
-            log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
-            configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
-        else:
-            print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
+        msg: StartupDetails = TypeAdapter(StartupDetails).validate_json(msgjson)
+        reinit_supervisor_comms()
 
         # We delay this message until _after_ we've got the logging re-configured, otherwise it will show up
         # on stdout
@@ -716,8 +711,9 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
         # normal entry point
         msg = SUPERVISOR_COMMS._get_response()  # type: ignore[assignment]
 
-        if not isinstance(msg, StartupDetails):
-            raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
+    if not isinstance(msg, StartupDetails):
+        raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
+
     # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021
     os_type = sys.platform
     if os_type == "darwin":
@@ -1443,7 +1439,6 @@ def finalize(
 
 
 def main():
-    # TODO: add an exception here, it causes an oof of a stack trace if it happens to early!
     log = structlog.get_logger(logger_name="task")
 
     global SUPERVISOR_COMMS
@@ -1472,5 +1467,29 @@ def main():
         SUPERVISOR_COMMS.socket.close()
 
 
+def reinit_supervisor_comms() -> None:
+    """
+    Re-initialize supervisor comms and logging channel in subprocess.
+
+    This is not needed for most cases, but is used when either we re-launch the process via sudo for
+    run_as_user, or from inside the python code in a virtualenv (et al.) operator to re-connect so those tasks
+    can continue to access variables etc.
+    """
+    if "SUPERVISOR_COMMS" not in globals():
+        global SUPERVISOR_COMMS
+        log = structlog.get_logger(logger_name="task")
+
+        SUPERVISOR_COMMS = CommsDecoder[ToTask, ToSupervisor](log=log)
+
+    logs = SUPERVISOR_COMMS.send(ResendLoggingFD())
+    if isinstance(logs, SentFDs):
+        from airflow.sdk.log import configure_logging
+
+        log_io = os.fdopen(logs.fds[0], "wb", buffering=0)
+        configure_logging(json_output=True, output=log_io, sending_to_supervisor=True)
+    else:
+        print("Unable to re-configure logging after sudo, we didn't get an FD", file=sys.stderr)
+
+
 if __name__ == "__main__":
     main()
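
Per the docstring, the intended call pattern inside a re-launched interpreter (a sudo/run_as_user re-exec, or the script a virtualenv operator generates) is to reconnect before using any SDK accessors. A distilled sketch, assuming the process inherits the supervisor's comms socket; the connection id is illustrative:

    from airflow.sdk import Connection
    from airflow.sdk.execution_time.task_runner import reinit_supervisor_comms

    # Rebuild SUPERVISOR_COMMS if this interpreter doesn't have one yet, and
    # re-acquire the logging FD; SDK lookups then flow through the supervisor.
    reinit_supervisor_comms()
    conn = Connection.get("my_conn_id")  # illustrative id
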
67 changes: 67 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -25,12 +25,14 @@
 import selectors
 import signal
 import socket
+import subprocess
 import sys
 import time
 from contextlib import nullcontext
 from dataclasses import dataclass, field
 from operator import attrgetter
 from random import randint
+from textwrap import dedent
 from time import sleep
 from typing import TYPE_CHECKING, Any
 from unittest import mock
@@ -2539,3 +2541,68 @@ def noop_request(request: httpx.Request) -> httpx.Response:
     gc.collect()
     assert backend.calls == 1, "Connection should be cached, not fetched multiple times"
     assert all(ref() is None for ref in clients), "Client instances should be garbage collected"
+
+
+def test_reinit_supervisor_comms(monkeypatch, client_with_ti_start, caplog):
+    def subprocess_main():
+        # This is run in the subprocess!
+
+        # Ensure we follow the "protocol" and get the startup message before we do anything else
+        c = CommsDecoder()
+        c._get_response()
+
+        # This mirrors what the VirtualEnvProvider puts in its script
+        script = """
+        import os
+        import sys
+        import structlog
+
+        from airflow.sdk import Connection
+        from airflow.sdk.execution_time.task_runner import reinit_supervisor_comms
+
+        reinit_supervisor_comms()
+
+        Connection.get("a")
+        print("ok")
+        sys.stdout.flush()
+
+        structlog.get_logger().info("is connected")
+        """
+        # Now we launch a new process, as VirtualEnvOperator will do
+        subprocess.check_call([sys.executable, "-c", dedent(script)])
+
+    client_with_ti_start.connections.get.return_value = ConnectionResult(
+        conn_id="test_conn", conn_type="mysql", login="a", password="password1"
+    )
+    proc = ActivitySubprocess.start(
+        dag_rel_path=os.devnull,
+        bundle_info=FAKE_BUNDLE,
+        what=TaskInstance(
+            id="4d828a62-a417-4936-a7a6-2b3fabacecab",
+            task_id="b",
+            dag_id="c",
+            run_id="d",
+            try_number=1,
+            dag_version_id=uuid7(),
+        ),
+        client=client_with_ti_start,
+        target=subprocess_main,
+    )
+
+    rc = proc.wait()
+
+    assert rc == 0, caplog.text
+    # Check that the log messages are right. We should expect stdout to appear right, and crucially, we
+    # should expect logs from the venv process to appear without extra "wrapping"
+    assert {
+        "logger": "task.stdout",
+        "event": "ok",
+        "log_level": "info",
+        "timestamp": mock.ANY,
+    } in caplog, caplog.text
+    assert {
+        "logger_name": "task",
+        "log_level": "info",
+        "event": "is connected",
+        "timestamp": mock.ANY,
+    } in caplog, caplog.text
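
The ResendLoggingFD/SentFDs round-trip that both reinit_supervisor_comms() and this test depend on carries live file descriptors across the supervisor's Unix socket — at the OS level this is the standard SCM_RIGHTS ancillary-data mechanism. A self-contained sketch of just that mechanism using only the standard library (Unix only; the helper names are illustrative, not the Task SDK API):

    import array
    import os
    import socket

    def send_fd(sock: socket.socket, fd: int) -> None:
        # Ship one descriptor as SCM_RIGHTS ancillary data beside a 1-byte payload.
        sock.sendmsg([b"x"], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [fd]))])

    def recv_fd(sock: socket.socket) -> int:
        # The kernel duplicates the descriptor into the receiving process.
        _, ancdata, _, _ = sock.recvmsg(1, socket.CMSG_SPACE(array.array("i", [0]).itemsize))
        fds = array.array("i")
        fds.frombytes(ancdata[0][2])
        return fds[0]

    supervisor_end, task_end = socket.socketpair()
    read_end, write_end = os.pipe()
    send_fd(supervisor_end, write_end)  # "supervisor" side: hand over the log pipe
    fd = recv_fd(task_end)              # "task" side: same pipe, new descriptor number
    os.write(fd, b"hello")
    assert os.read(read_end, 5) == b"hello"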