diff --git a/airflow-core/src/airflow/cli/cli_config.py b/airflow-core/src/airflow/cli/cli_config.py
index 31bfc6b90f0d9..0e0bb1f44088d 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -669,7 +669,7 @@ def string_lower_type(val):
     default=conf.get("api", "ssl_key"),
     help="Path to the key to use with the SSL certificate",
 )
-ARG_DEV = Arg(("-d", "--dev"), help="Start FastAPI in development mode", action="store_true")
+ARG_DEV = Arg(("-d", "--dev"), help="Start in development mode with hot-reload enabled", action="store_true")
 
 # scheduler
 ARG_NUM_RUNS = Arg(
@@ -1923,6 +1923,7 @@ class GroupCommand(NamedTuple):
             ARG_LOG_FILE,
             ARG_SKIP_SERVE_LOGS,
             ARG_VERBOSE,
+            ARG_DEV,
         ),
         epilog=(
             "Signals:\n"
@@ -1946,6 +1947,7 @@ class GroupCommand(NamedTuple):
             ARG_CAPACITY,
             ARG_VERBOSE,
             ARG_SKIP_SERVE_LOGS,
+            ARG_DEV,
         ),
     ),
     ActionCommand(
@@ -1961,6 +1963,7 @@ class GroupCommand(NamedTuple):
             ARG_STDERR,
             ARG_LOG_FILE,
             ARG_VERBOSE,
+            ARG_DEV,
         ),
     ),
     ActionCommand(
diff --git a/airflow-core/src/airflow/cli/commands/api_server_command.py b/airflow-core/src/airflow/cli/commands/api_server_command.py
index 01844398f79b2..cdecbea9b8ea0 100644
--- a/airflow-core/src/airflow/cli/commands/api_server_command.py
+++ b/airflow-core/src/airflow/cli/commands/api_server_command.py
@@ -139,7 +139,7 @@ def api_server(args: Namespace):
 
     get_signing_args()
 
-    if args.dev:
+    if cli_utils.should_enable_hot_reload(args):
         print(f"Starting the API server on port {args.port} and host {args.host} in development mode.")
         log.warning("Running in dev mode, ignoring uvicorn args")
         from fastapi_cli.cli import _run
diff --git a/airflow-core/src/airflow/cli/commands/dag_processor_command.py b/airflow-core/src/airflow/cli/commands/dag_processor_command.py
index 64e36a864fc3f..f4c303c278dbb 100644
--- a/airflow-core/src/airflow/cli/commands/dag_processor_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_processor_command.py
@@ -52,6 +52,15 @@ def dag_processor(args):
     """Start Airflow Dag Processor Job."""
     job_runner = _create_dag_processor_job_runner(args)
 
+    if cli_utils.should_enable_hot_reload(args):
+        from airflow.cli.hot_reload import run_with_reloader
+
+        run_with_reloader(
+            lambda: run_job(job=job_runner.job, execute_callable=job_runner._execute),
+            process_name="dag-processor",
+        )
+        return
+
     run_command_with_daemon_option(
         args=args,
         process_name="dag-processor",
diff --git a/airflow-core/src/airflow/cli/commands/scheduler_command.py b/airflow-core/src/airflow/cli/commands/scheduler_command.py
index fcdd9b2b78b4a..089d7dd51967d 100644
--- a/airflow-core/src/airflow/cli/commands/scheduler_command.py
+++ b/airflow-core/src/airflow/cli/commands/scheduler_command.py
@@ -51,6 +51,12 @@ def scheduler(args: Namespace):
     """Start Airflow Scheduler."""
     print(settings.HEADER)
 
+    if cli_utils.should_enable_hot_reload(args):
+        from airflow.cli.hot_reload import run_with_reloader
+
+        run_with_reloader(lambda: _run_scheduler_job(args), process_name="scheduler")
+        return
+
     run_command_with_daemon_option(
         args=args,
         process_name="scheduler",
diff --git a/airflow-core/src/airflow/cli/commands/triggerer_command.py b/airflow-core/src/airflow/cli/commands/triggerer_command.py
index eedc4a49e50c2..a293d44603051 100644
--- a/airflow-core/src/airflow/cli/commands/triggerer_command.py
+++ b/airflow-core/src/airflow/cli/commands/triggerer_command.py
@@ -66,6 +66,15 @@ def triggerer(args):
     print(settings.HEADER)
     triggerer_heartrate = conf.getfloat("triggerer", "JOB_HEARTBEAT_SEC")
 
+    if cli_utils.should_enable_hot_reload(args):
+        from airflow.cli.hot_reload import run_with_reloader
+
+        run_with_reloader(
+            lambda: triggerer_run(args.skip_serve_logs, args.capacity, triggerer_heartrate),
+            process_name="triggerer",
+        )
+        return
+
     run_command_with_daemon_option(
         args=args,
         process_name="triggerer",
diff --git a/airflow-core/src/airflow/cli/hot_reload.py b/airflow-core/src/airflow/cli/hot_reload.py
new file mode 100644
index 0000000000000..ecf64c2781dc2
--- /dev/null
+++ b/airflow-core/src/airflow/cli/hot_reload.py
@@ -0,0 +1,197 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hot reload utilities for development mode."""
+
+from __future__ import annotations
+
+import os
+import signal
+import sys
+from collections.abc import Callable, Sequence
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+import structlog
+
+if TYPE_CHECKING:
+    import subprocess
+
+log = structlog.getLogger(__name__)
+
+
+def run_with_reloader(
+    callback: Callable,
+    process_name: str = "process",
+) -> None:
+    """
+    Run a callback function with automatic reloading on file changes.
+
+    This function monitors specified paths for changes and restarts the process
+    when changes are detected. Useful for development mode hot-reloading.
+
+    :param callback: The function to run. This should be the main entry point
+        of the command that needs hot-reload support.
+    :param process_name: Name of the process being run (for logging purposes)
+    """
+    # Default watch paths - watch the airflow source directory
+    import airflow
+
+    airflow_root = Path(airflow.__file__).parent
+    watch_paths = [airflow_root]
+
+    log.info("Starting %s in development mode with hot-reload enabled", process_name)
+    log.info("Watching paths: %s", watch_paths)
+
+    # Check if we're the main process or a reloaded child
+    reloader_pid = os.environ.get("AIRFLOW_DEV_RELOADER_PID")
+    if reloader_pid is None:
+        # We're the main process - set up the reloader
+        os.environ["AIRFLOW_DEV_RELOADER_PID"] = str(os.getpid())
+        _run_reloader(watch_paths)
+    else:
+        # We're a child process - just run the callback
+        callback()
+
+
+def _terminate_process_tree(
+    process: subprocess.Popen[bytes],
+    timeout: int = 5,
+    force_kill_remaining: bool = True,
+) -> None:
+    """
+    Terminate a process and all its children recursively.
+
+    Uses psutil to ensure all child processes are properly terminated,
+    which is important for cleaning up subprocesses like serve-log servers.
+
+    :param process: The subprocess.Popen process to terminate
+    :param timeout: Timeout in seconds to wait for graceful termination
+    :param force_kill_remaining: If True, force kill processes that don't terminate gracefully
+    """
+    import subprocess
+
+    import psutil
+
+    try:
+        parent = psutil.Process(process.pid)
+        # Get all child processes recursively
+        children = parent.children(recursive=True)
+
+        # Terminate all children first
+        for child in children:
+            try:
+                child.terminate()
+            except (psutil.NoSuchProcess, psutil.AccessDenied):
+                pass
+
+        # Terminate the parent
+        parent.terminate()
+
+        # Wait for all processes to terminate
+        gone, alive = psutil.wait_procs(children + [parent], timeout=timeout)
+
+        # Force kill any remaining processes if requested
+        if force_kill_remaining:
+            for proc in alive:
+                try:
+                    log.warning("Force killing process %s", proc.pid)
+                    proc.kill()
+                except (psutil.NoSuchProcess, psutil.AccessDenied):
+                    pass
+
+    except (psutil.NoSuchProcess, psutil.AccessDenied):
+        # Process already terminated
+        pass
+    except Exception as e:
+        log.warning("Error terminating process tree: %s", e)
+        # Fallback to simple termination
+        try:
+            process.terminate()
+            process.wait(timeout=timeout)
+        except subprocess.TimeoutExpired:
+            if force_kill_remaining:
+                log.warning("Process did not terminate gracefully, killing...")
+                process.kill()
+                process.wait()
+
+
+def _run_reloader(watch_paths: Sequence[str | Path]) -> None:
+    """
+    Watch for changes and restart the process.
+
+    Watches the provided paths and restarts the process by re-executing the
+    Python interpreter with the same arguments.
+
+    :param watch_paths: List of paths to watch for changes.
+    """
+    import subprocess
+
+    from watchfiles import watch
+
+    process = None
+    should_exit = False
+
+    def start_process():
+        """Start or restart the subprocess."""
+        nonlocal process
+        if process is not None:
+            log.info("Stopping process and all its children...")
+            _terminate_process_tree(process, timeout=5, force_kill_remaining=True)
+
+        log.info("Starting process...")
+        # Restart the process by re-executing Python with the same arguments
+        # Note: sys.argv is safe here as it comes from the original CLI invocation
+        # and is only used in development mode for hot-reloading the same process
+        process = subprocess.Popen([sys.executable] + sys.argv)
+        return process
+
+    def signal_handler(signum, frame):
+        """Handle termination signals."""
+        nonlocal should_exit, process
+        should_exit = True
+        log.info("Received signal %s, shutting down...", signum)
+        if process:
+            _terminate_process_tree(process, timeout=5, force_kill_remaining=False)
+        sys.exit(0)
+
+    # Set up signal handlers
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTERM, signal_handler)
+
+    # Start the initial process
+    process = start_process()
+
+    log.info("Hot-reload enabled. Watching for file changes...")
+    log.info("Press Ctrl+C to stop")
+
+    try:
+        for changes in watch(*watch_paths):
+            if should_exit:
+                break
+
+            log.info("Detected changes: %s", changes)
+            log.info("Reloading...")
+
+            # Restart the process
+            process = start_process()
+
+    except KeyboardInterrupt:
+        log.info("Shutting down...")
+        if process:
+            process.terminate()
+            process.wait()
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index 17cad26771c1d..ed9eb7e84f633 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -472,3 +472,10 @@ def validate_dag_bundle_arg(bundle_names: list[str]) -> None:
     unknown_bundles: set[str] = set(bundle_names) - known_bundles
     if unknown_bundles:
         raise SystemExit(f"Bundles not found: {', '.join(unknown_bundles)}")
+
+
+def should_enable_hot_reload(args) -> bool:
+    """Check whether hot-reload should be enabled based on --dev flag or DEV_MODE env var."""
+    if getattr(args, "dev", False):
+        return True
+    return os.getenv("DEV_MODE", "false").lower() == "true"
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_processor_command.py b/airflow-core/tests/unit/cli/commands/test_dag_processor_command.py
index 3e0e931587cd3..6224be2c5d59a 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_processor_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_processor_command.py
@@ -56,3 +56,14 @@ def test_bundle_names_passed(self, mock_runner, configure_testing_dag_bundle):
         with configure_testing_dag_bundle(os.devnull):
             dag_processor_command.dag_processor(args)
         assert mock_runner.call_args.kwargs["processor"].bundle_names_to_parse == ["testing"]
+
+    @mock.patch("airflow.cli.hot_reload.run_with_reloader")
+    def test_dag_processor_with_dev_flag(self, mock_reloader):
+        """Ensure that dag-processor with --dev flag uses hot-reload"""
+        args = self.parser.parse_args(["dag-processor", "--dev"])
+        dag_processor_command.dag_processor(args)
+
+        # Verify that run_with_reloader was called
+        mock_reloader.assert_called_once()
+        # The callback function should be callable
+        assert callable(mock_reloader.call_args[0][0])
diff --git a/airflow-core/tests/unit/cli/commands/test_scheduler_command.py b/airflow-core/tests/unit/cli/commands/test_scheduler_command.py
index 1a9e7cdc24144..e381f9a754f0e 100644
--- a/airflow-core/tests/unit/cli/commands/test_scheduler_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_scheduler_command.py
@@ -163,3 +163,13 @@ def test_run_job_exception_handling(self, mock_run_job, mock_process, mock_sched
         )
         mock_process.assert_called_once_with(target=serve_logs)
         mock_process().terminate.assert_called_once_with()
+
+    @mock.patch("airflow.cli.hot_reload.run_with_reloader")
+    def test_scheduler_with_dev_flag(self, mock_reloader):
+        args = self.parser.parse_args(["scheduler", "--dev"])
+        scheduler_command.scheduler(args)
+
+        # Verify that run_with_reloader was called
+        mock_reloader.assert_called_once()
+        # The callback function should be callable
+        assert callable(mock_reloader.call_args[0][0])
diff --git a/airflow-core/tests/unit/cli/commands/test_triggerer_command.py b/airflow-core/tests/unit/cli/commands/test_triggerer_command.py
index b5222038f2de5..44120f27fc981 100644
--- a/airflow-core/tests/unit/cli/commands/test_triggerer_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_triggerer_command.py
@@ -63,3 +63,14 @@ def test_trigger_run_serve_logs(self, mock_process, mock_run_job, mock_trigger_j
             job=mock_trigger_job_runner.return_value.job,
             execute_callable=mock_trigger_job_runner.return_value._execute,
         )
+
+    @mock.patch("airflow.cli.hot_reload.run_with_reloader")
+    def test_triggerer_with_dev_flag(self, mock_reloader):
+        """Ensure that triggerer with --dev flag uses hot-reload"""
+        args = self.parser.parse_args(["triggerer", "--dev"])
+        triggerer_command.triggerer(args)
+
+        # Verify that run_with_reloader was called
+        mock_reloader.assert_called_once()
+        # The callback function should be callable
+        assert callable(mock_reloader.call_args[0][0])
diff --git a/airflow-core/tests/unit/cli/test_hot_reload.py b/airflow-core/tests/unit/cli/test_hot_reload.py
new file mode 100644
index 0000000000000..952671ab2367e
--- /dev/null
+++ b/airflow-core/tests/unit/cli/test_hot_reload.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+import sys
+from unittest import mock
+
+import pytest
+
+from airflow.cli import hot_reload
+
+
+class TestHotReload:
+    """Tests for hot reload utilities."""
+
+    @mock.patch("airflow.cli.hot_reload._run_reloader")
+    def test_run_with_reloader_missing_watchfiles(self, mock_run_reloader):
+        """Test that run_with_reloader handles missing watchfiles by raising ImportError."""
+        # Simulate watchfiles not being available when _run_reloader tries to import it
+        mock_run_reloader.side_effect = ImportError("No module named 'watchfiles'")
+
+        # Clear the reloader PID env var to simulate being the main process
+        with mock.patch.dict(os.environ, {}, clear=True):
+            with pytest.raises(ImportError):
+                hot_reload.run_with_reloader(lambda: None)
+
+    @mock.patch("airflow.cli.hot_reload._run_reloader")
+    def test_run_with_reloader_main_process(self, mock_run_reloader):
+        """Test run_with_reloader as the main process."""
+        # Clear the reloader PID env var to simulate being the main process
+        with mock.patch.dict(os.environ, {}, clear=True):
+            callback = mock.Mock()
+
+            hot_reload.run_with_reloader(callback)
+
+            # Should set the env var and call _run_reloader
+            assert "AIRFLOW_DEV_RELOADER_PID" in os.environ
+            mock_run_reloader.assert_called_once()
+
+    def test_run_with_reloader_child_process(self):
+        """Test run_with_reloader as a child process."""
+        # Set the reloader PID env var to simulate being a child process
+        with mock.patch.dict(os.environ, {"AIRFLOW_DEV_RELOADER_PID": "12345"}):
+            callback = mock.Mock()
+            hot_reload.run_with_reloader(callback)
+
+            # Should just call the callback directly
+            callback.assert_called_once()
+
+    @mock.patch("subprocess.Popen")
+    @mock.patch("watchfiles.watch")
+    def test_run_reloader_starts_process(self, mock_watch, mock_popen):
+        """Test that _run_reloader starts a subprocess."""
+        mock_process = mock.Mock()
+        mock_popen.return_value = mock_process
+        mock_watch.return_value = []  # Empty iterator, will exit immediately
+
+        watch_paths = ["/tmp/test"]
+
+        hot_reload._run_reloader(watch_paths)
+
+        # Should have started a process
+        mock_popen.assert_called_once()
+        assert mock_popen.call_args[0][0] == [sys.executable] + sys.argv
+
+    @mock.patch("airflow.cli.hot_reload._terminate_process_tree")
+    @mock.patch("subprocess.Popen")
+    @mock.patch("watchfiles.watch")
+    def test_run_reloader_restarts_on_changes(self, mock_watch, mock_popen, mock_terminate):
+        """Test that _run_reloader restarts the process on file changes."""
+        mock_process = mock.Mock()
+        mock_popen.return_value = mock_process
+
+        # Simulate one file change and then exit
+        mock_watch.return_value = iter([[("change", "/tmp/test/file.py")]])
+
+        watch_paths = ["/tmp/test"]
+
+        hot_reload._run_reloader(watch_paths)
+
+        # Should have started process twice (initial + restart)
+        assert mock_popen.call_count == 2
+        # Should have terminated the first process
+        mock_terminate.assert_called()
diff --git a/airflow-core/tests/unit/utils/test_cli_util.py b/airflow-core/tests/unit/utils/test_cli_util.py
index 40ab3080c52b6..49200510f6844 100644
--- a/airflow-core/tests/unit/utils/test_cli_util.py
+++ b/airflow-core/tests/unit/utils/test_cli_util.py
@@ -289,3 +289,31 @@ def test_validate_dag_bundle_arg():
 
     # doesn't raise
     cli.validate_dag_bundle_arg(["dags-folder"])
+
+
+@pytest.mark.parametrize(
+    ["dev_flag", "env_var", "expected"],
+    [
+        # --dev flag tests
+        (True, None, True),
+        (False, None, False),
+        (None, None, False),  # no dev flag attribute
+        # DEV_MODE env var tests
+        (False, "true", True),
+        (False, "false", False),
+        (False, "TRUE", True),
+        (False, "True", True),
+        # --dev flag takes precedence
+        (True, "false", True),
+        # Invalid env var values
+        (False, "yes", False),
+        (False, "1", False),
+    ],
+)
+def test_should_enable_hot_reload(dev_flag, env_var, expected):
+    """Test should_enable_hot_reload with various --dev flag and DEV_MODE env var combinations."""
+    args = Namespace() if dev_flag is None else Namespace(dev=dev_flag)
+    env = {} if env_var is None else {"DEV_MODE": env_var}
+
+    with mock.patch.dict(os.environ, env, clear=True):
+        assert cli.should_enable_hot_reload(args) is expected
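
Reviewer note (not part of the patch): with this change applied, hot-reload can be enabled either via the new --dev flag (airflow scheduler --dev, airflow triggerer --dev, airflow dag-processor --dev) or by exporting DEV_MODE=true before starting the command. The sketch below is a minimal, self-contained illustration of the parent/child re-exec pattern that run_with_reloader in airflow/cli/hot_reload.py relies on; my_workload is a hypothetical placeholder, and the real module additionally terminates the whole child process tree via psutil and installs signal handlers.

# Illustration only -- mirrors the pattern in airflow/cli/hot_reload.py, not part of this diff.
from __future__ import annotations

import os
import subprocess
import sys

from watchfiles import watch  # same third-party watcher the new module imports


def my_workload() -> None:
    # Placeholder for the real job callable, e.g. the scheduler job.
    print("doing work in the reloadable child process")


def run_with_reloader(callback) -> None:
    if os.environ.get("AIRFLOW_DEV_RELOADER_PID") is not None:
        # Re-executed child: the marker is inherited from the parent, so just run the workload.
        callback()
        return

    # Parent: mark ourselves, start the child, then re-exec it on every file change.
    os.environ["AIRFLOW_DEV_RELOADER_PID"] = str(os.getpid())
    child = subprocess.Popen([sys.executable] + sys.argv)
    for changes in watch("."):  # watchfiles yields sets of (change, path) tuples
        print(f"detected {changes}, restarting")
        child.terminate()
        child.wait()
        child = subprocess.Popen([sys.executable] + sys.argv)


if __name__ == "__main__":
    run_with_reloader(my_workload)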