diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py index c137c80b56fba..08d7ac7af5ba7 100644 --- a/airflow-core/src/airflow/settings.py +++ b/airflow-core/src/airflow/settings.py @@ -616,7 +616,9 @@ def initialize(): # The webservers import this file from models.py with the default settings. if not os.environ.get("PYTHON_OPERATORS_VIRTUAL_ENV_MODE", None): - configure_orm() + is_worker = os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" + if not is_worker: + configure_orm() configure_action_logging() # mask the sensitive_config_values diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index 5ad5d4df45feb..e4292deca2a10 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -35,8 +35,9 @@ import attrs import lazy_object_proxy import structlog -from pydantic import AwareDatetime, ConfigDict, Field, JsonValue +from pydantic import AwareDatetime, ConfigDict, Field, JsonValue, TypeAdapter +from airflow.configuration import conf from airflow.dag_processing.bundles.base import BaseDagBundle, BundleVersionLock from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.exceptions import AirflowInactiveAssetInInletOrOutletException @@ -97,6 +98,7 @@ ) from airflow.sdk.execution_time.xcom import XCom from airflow.utils.net import get_hostname +from airflow.utils.platform import getuser from airflow.utils.timezone import coerce_datetime if TYPE_CHECKING: @@ -642,6 +644,7 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: # accessible wherever needed during task execution without modifying every layer of the call stack. SUPERVISOR_COMMS: CommsDecoder[ToTask, ToSupervisor] + # State machine! # 1. Start up (receive details from supervisor) # 2. 
Execution (run task code, possibly send requests) @@ -651,13 +654,18 @@ def parse(what: StartupDetails, log: Logger) -> RuntimeTaskInstance: def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: # The parent sends us a StartupDetails message un-prompted. After this, every single message is only sent # in response to us sending a request. - msg = SUPERVISOR_COMMS._get_response() + log = structlog.get_logger(logger_name="task") + + if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") == "1" and os.environ.get("_AIRFLOW__STARTUP_MSG"): + # entrypoint of re-exec process + msg = TypeAdapter(StartupDetails).validate_json(os.environ["_AIRFLOW__STARTUP_MSG"]) + log.debug("Using serialized startup message from environment", msg=msg) + else: + # normal entry point + msg = SUPERVISOR_COMMS._get_response() # type: ignore[assignment] if not isinstance(msg, StartupDetails): raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}") - - log = structlog.get_logger(logger_name="task") - # setproctitle causes issue on Mac OS: https://github.com/benoitc/gunicorn/issues/3021 os_type = sys.platform if os_type == "darwin": @@ -677,6 +685,34 @@ def startup() -> tuple[RuntimeTaskInstance, Context, Logger]: ti.log_url = get_log_url_from_ti(ti) log.debug("DAG file parsed", file=msg.dag_rel_path) + run_as_user = getattr(ti.task, "run_as_user", None) or conf.get( + "core", "default_impersonation", fallback=None + ) + + if os.environ.get("_AIRFLOW__REEXECUTED_PROCESS") != "1" and run_as_user and run_as_user != getuser(): + # original process only: prepare environment, then re-exec this process as run_as_user + os.environ["_AIRFLOW__REEXECUTED_PROCESS"] = "1" + # store startup message in environment for re-exec process + os.environ["_AIRFLOW__STARTUP_MSG"] = msg.model_dump_json() + os.set_inheritable(SUPERVISOR_COMMS.socket.fileno(), True) + + # Import main directly from the module instead of re-executing the file. 
+ # This ensures that when other parts of the codebase import + # airflow.sdk.execution_time.task_runner, they get the same module instance + # with the properly initialized SUPERVISOR_COMMS global variable. + # If we re-executed the module with `python -m`, it would load as __main__ and future + # imports would get a fresh copy without the initialized globals. + rexec_python_code = "from airflow.sdk.execution_time.task_runner import main; main()" + cmd = ["sudo", "-E", "-H", "-u", run_as_user, sys.executable, "-c", rexec_python_code] + log.info( + "Running command", + command=cmd, + ) + os.execvp("sudo", cmd) + + # ideally, we should never reach here, but if we do, we should return None, None, None + return None, None, None + return ti, ti.get_template_context(), log diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index 5ce460c455bf4..9745297f8c7f0 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -644,6 +644,91 @@ def execute(self, context): mock_supervisor_comms.assert_has_calls(expected_calls) +@patch("os.execvp") +@patch("os.set_inheritable") +def test_task_run_with_user_impersonation( + mock_set_inheritable, mock_execvp, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + class CustomOperator(BaseOperator): + def execute(self, context): + print("Hi from CustomOperator!") + + task = CustomOperator(task_id="impersonation_task", run_as_user="airflowuser") + instant = timezone.datetime(2024, 12, 3, 10, 0) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="impersonation_task", + dag_id="basic_dag", + run_id="c", + try_number=1, + ), + dag_rel_path="", + bundle_info=FAKE_BUNDLE, + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + + mocked_parse(what, "basic_dag", task) + time_machine.move_to(instant, tick=False) + + 
mock_supervisor_comms._get_response.return_value = what + mock_supervisor_comms.socket.fileno.return_value = 42 + + with mock.patch.dict(os.environ, {}, clear=True): + startup() + + assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1" + assert "_AIRFLOW__STARTUP_MSG" in os.environ + + mock_set_inheritable.assert_called_once_with(42, True) + actual_cmd = mock_execvp.call_args.args[1] + + assert actual_cmd[:5] == ["sudo", "-E", "-H", "-u", "airflowuser"] + assert "python -c" in actual_cmd[5] + " " + actual_cmd[6] + assert actual_cmd[7] == "from airflow.sdk.execution_time.task_runner import main; main()" + + +@patch("airflow.sdk.execution_time.task_runner.getuser") +def test_task_run_with_user_impersonation_default_user( + mock_get_user, mocked_parse, make_ti_context, time_machine, mock_supervisor_comms +): + class CustomOperator(BaseOperator): + def execute(self, context): + print("Hi from CustomOperator!") + + task = CustomOperator(task_id="impersonation_task", run_as_user="default_user") + instant = timezone.datetime(2024, 12, 3, 10, 0) + + what = StartupDetails( + ti=TaskInstance( + id=uuid7(), + task_id="impersonation_task", + dag_id="basic_dag", + run_id="c", + try_number=1, + ), + dag_rel_path="", + bundle_info=FAKE_BUNDLE, + ti_context=make_ti_context(), + start_date=timezone.utcnow(), + ) + + mocked_parse(what, "basic_dag", task) + time_machine.move_to(instant, tick=False) + + mock_supervisor_comms._get_response.return_value = what + mock_supervisor_comms.socket.fileno.return_value = 42 + mock_get_user.return_value = "default_user" + + with mock.patch.dict(os.environ, {}, clear=True): + startup() + + assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ + assert "_AIRFLOW__STARTUP_MSG" not in os.environ + + @pytest.mark.parametrize( ["command", "rendered_command"], [