diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
index c3cdc793ace9e..997cbb4c8a56d 100644
--- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py
+++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py
@@ -33,6 +33,7 @@ from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile
 
 from airflow import settings
+from airflow.cli.simple_table import AirflowConsole
 from airflow.configuration import conf
 from airflow.exceptions import AirflowConfigException
 from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
 
@@ -305,3 +306,93 @@ def stop_worker(args):
     # Remove pid file
     remove_existing_pidfile(pid_file_path)
 
+
+@_providers_configuration_loaded
+def _check_if_active_celery_worker(hostname: str):
+    """Check if celery worker is active before executing dependent cli commands."""
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    inspect = celery_app.control.inspect()
+    active_workers = inspect.active_queues()
+    if not active_workers:
+        raise SystemExit("Error: No active Celery workers found!")
+    if hostname not in active_workers:
+        raise SystemExit(f"Error: {hostname} is unknown!")
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def list_workers(args):
+    """List all active celery workers."""
+    workers = []
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    inspect = celery_app.control.inspect()
+    active_workers = inspect.active_queues()
+    if active_workers:
+        workers = [
+            {
+                "worker_name": worker,
+                "queues": [queue["name"] for queue in active_workers[worker] if "name" in queue],
+            }
+            for worker in active_workers
+        ]
+    AirflowConsole().print_as(data=workers, output=args.output)
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def shutdown_worker(args):
+    """Request graceful shutdown of a celery worker."""
+    _check_if_active_celery_worker(hostname=args.celery_hostname)
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    celery_app.control.shutdown(destination=[args.celery_hostname])
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def shutdown_all_workers(args):
+    """Request graceful shutdown of all celery workers."""
+    if not (
+        args.yes
+        or input(
+            "This will shutdown all active celery workers connected to the celery broker, this cannot be undone! Proceed? (y/n)"
+        ).upper()
+        == "Y"
+    ):
+        raise SystemExit("Cancelled")
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    celery_app.control.broadcast("shutdown")
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def add_queue(args):
+    """Subscribe a Celery worker to specified queues."""
+    _check_if_active_celery_worker(hostname=args.celery_hostname)
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    queues = args.queues.split(",")
+    for queue in queues:
+        celery_app.control.add_consumer(queue, destination=[args.celery_hostname])
+
+
+@cli_utils.action_cli
+@_providers_configuration_loaded
+def remove_queue(args):
+    """Unsubscribe a Celery worker from specified queues."""
+    _check_if_active_celery_worker(hostname=args.celery_hostname)
+    # This needs to be imported locally to not trigger Providers Manager initialization
+    from airflow.providers.celery.executors.celery_executor import app as celery_app
+
+    queues = args.queues.split(",")
+    for queue in queues:
+        celery_app.control.cancel_consumer(queue, destination=[args.celery_hostname])
diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
index 0a5aee9c0308b..3b97475bab2c2 100644
--- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
+++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py
@@ -157,6 +157,32 @@ def __getattr__(name):
     help="Don't subscribe to other workers events",
     action="store_true",
 )
+ARG_OUTPUT = Arg(
+    (
+        "-o",
+        "--output",
+    ),
+    help="Output format. Allowed values: json, yaml, plain, table (default: table)",
+    metavar="(table, json, yaml, plain)",
+    choices=("table", "json", "yaml", "plain"),
+    default="table",
+)
+ARG_FULL_CELERY_HOSTNAME = Arg(
+    ("-H", "--celery-hostname"),
+    required=True,
+    help="Specify the full celery hostname. example: celery@hostname",
+)
+ARG_REQUIRED_QUEUES = Arg(
+    ("-q", "--queues"),
+    help="Comma delimited list of queues to serve",
+    required=True,
+)
+ARG_YES = Arg(
+    ("-y", "--yes"),
+    help="Do not prompt to confirm. Use with care!",
+    action="store_true",
+    default=False,
+)
 
 CELERY_CLI_COMMAND_PATH = "airflow.providers.celery.cli.celery_command"
 
@@ -207,6 +233,42 @@ def __getattr__(name):
         func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.stop_worker"),
         args=(ARG_PID, ARG_VERBOSE),
     ),
+    ActionCommand(
+        name="list-workers",
+        help="List active celery workers",
+        func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.list_workers"),
+        args=(ARG_OUTPUT,),
+    ),
+    ActionCommand(
+        name="shutdown-worker",
+        help="Request graceful shutdown of celery workers",
+        func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_worker"),
+        args=(ARG_FULL_CELERY_HOSTNAME,),
+    ),
+    ActionCommand(
+        name="shutdown-all-workers",
+        help="Request graceful shutdown of all active celery workers",
+        func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_all_workers"),
+        args=(ARG_YES,),
+    ),
+    ActionCommand(
+        name="add-queue",
+        help="Subscribe Celery worker to specified queues",
+        func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.add_queue"),
+        args=(
+            ARG_REQUIRED_QUEUES,
+            ARG_FULL_CELERY_HOSTNAME,
+        ),
+    ),
+    ActionCommand(
+        name="remove-queue",
+        help="Unsubscribe Celery worker from specified queues",
+        func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_queue"),
+        args=(
+            ARG_REQUIRED_QUEUES,
+            ARG_FULL_CELERY_HOSTNAME,
+        ),
+    ),
 )
 
 
diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py
index cae229fd3b5f2..ae124f56fb63d 100644
--- a/providers/celery/tests/unit/celery/cli/test_celery_command.py
+++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -17,10 +17,13 @@
 # under the License.
 
 from __future__ import annotations
 
+import contextlib
 import importlib
+import json
 import os
+from io import StringIO
 from unittest import mock
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 
 import pytest
@@ -347,6 +350,73 @@
         self._test_run_command_daemon(mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file)
 
+
+class TestRemoteCeleryControlCommands:
+    @classmethod
+    def setup_class(cls):
+        with conf_vars({("core", "executor"): "CeleryExecutor"}):
+            importlib.reload(executor_loader)
+            importlib.reload(cli_parser)
+            cls.parser = cli_parser.get_parser()
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.inspect")
+    def test_list_celery_workers(self, mock_inspect):
+        args = self.parser.parse_args(["celery", "list-workers", "--output", "json"])
+        mock_instance = MagicMock()
+        mock_instance.active_queues.return_value = {
+            "celery@host_1": [{"name": "queue1"}, {"name": "queue2"}],
+            "celery@host_2": [{"name": "queue3"}],
+        }
+        mock_inspect.return_value = mock_instance
+        with contextlib.redirect_stdout(StringIO()) as temp_stdout:
+            celery_command.list_workers(args)
+        out = temp_stdout.getvalue()
+        celery_workers = json.loads(out)
+        for key in ["worker_name", "queues"]:
+            assert key in celery_workers[0]
+        assert any("celery@host_1" in h["worker_name"] for h in celery_workers)
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.shutdown")
+    def test_shutdown_worker(self, mock_shutdown):
+        args = self.parser.parse_args(["celery", "shutdown-worker", "-H", "celery@host_1"])
+        with patch(
+            "airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
+        ):
+            celery_command.shutdown_worker(args)
+        mock_shutdown.assert_called_once_with(destination=["celery@host_1"])
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.broadcast")
+    def test_shutdown_all_workers(self, mock_broadcast):
+        args = self.parser.parse_args(["celery", "shutdown-all-workers", "-y"])
+        with patch(
+            "airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
+        ):
+            celery_command.shutdown_all_workers(args)
+        mock_broadcast.assert_called_once_with("shutdown")
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.add_consumer")
+    def test_add_queue(self, mock_add_consumer):
+        args = self.parser.parse_args(["celery", "add-queue", "-q", "test1", "-H", "celery@host_1"])
+        with patch(
+            "airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
+        ):
+            celery_command.add_queue(args)
+        mock_add_consumer.assert_called_once_with("test1", destination=["celery@host_1"])
+
+    @pytest.mark.db_test
+    @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.cancel_consumer")
+    def test_remove_queue(self, mock_cancel_consumer):
+        args = self.parser.parse_args(["celery", "remove-queue", "-q", "test1", "-H", "celery@host_1"])
+        with patch(
+            "airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
+        ):
+            celery_command.remove_queue(args)
+        mock_cancel_consumer.assert_called_once_with("test1", destination=["celery@host_1"])
+
+
 @patch("airflow.providers.celery.cli.celery_command.Process")
 @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Doesn't apply to pre-3.0")
 def test_stale_bundle_cleanup(mock_process):