diff --git a/providers/celery/src/airflow/providers/celery/cli/celery_command.py b/providers/celery/src/airflow/providers/celery/cli/celery_command.py index 3022e7b474be1..5a79d260e868b 100644 --- a/providers/celery/src/airflow/providers/celery/cli/celery_command.py +++ b/providers/celery/src/airflow/providers/celery/cli/celery_command.py @@ -396,3 +396,35 @@ def remove_queue(args): queues = args.queues.split(",") for queue in queues: celery_app.control.cancel_consumer(queue, destination=[args.celery_hostname]) + + +@cli_utils.action_cli(check_db=False) +@_providers_configuration_loaded +def remove_all_queues(args): + """Unsubscribe a Celery worker from all its active queues.""" + _check_if_active_celery_worker(hostname=args.celery_hostname) + # This needs to be imported locally to not trigger Providers Manager initialization + from airflow.providers.celery.executors.celery_executor import app as celery_app + + inspect = celery_app.control.inspect() + active_workers = inspect.active_queues() + + if not active_workers or args.celery_hostname not in active_workers: + print(f"No active queues found for worker: {args.celery_hostname}") + return + + worker_queues = active_workers[args.celery_hostname] + queue_names = [queue["name"] for queue in worker_queues if "name" in queue] + + if not queue_names: + print(f"No queues to remove for worker: {args.celery_hostname}") + return + + print( + f"Removing {len(queue_names)} queue(s) from worker {args.celery_hostname}: {', '.join(queue_names)}" + ) + + for queue_name in queue_names: + celery_app.control.cancel_consumer(queue_name, destination=[args.celery_hostname]) + + print(f"Successfully removed all queues from worker: {args.celery_hostname}") diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 6091dc5c124d3..13b5ce3237f4b 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -268,6 +268,12 @@ def __getattr__(name): ARG_FULL_CELERY_HOSTNAME, ), ), + ActionCommand( + name="remove-all-queues", + help="Unsubscribe Celery worker from all its active queues", + func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_all_queues"), + args=(ARG_FULL_CELERY_HOSTNAME,), + ), ) diff --git a/providers/celery/tests/unit/celery/cli/test_celery_command.py b/providers/celery/tests/unit/celery/cli/test_celery_command.py index 2bd0755c32c42..e4aedfe45e431 100644 --- a/providers/celery/tests/unit/celery/cli/test_celery_command.py +++ b/providers/celery/tests/unit/celery/cli/test_celery_command.py @@ -439,6 +439,29 @@ def test_remove_queue(self, mock_cancel_consumer): celery_command.remove_queue(args) mock_cancel_consumer.assert_called_once_with("test1", destination=["celery@host_1"]) + @pytest.mark.db_test + @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.cancel_consumer") + @mock.patch("airflow.providers.celery.executors.celery_executor.app.control.inspect") + def test_remove_all_queues(self, mock_inspect, mock_cancel_consumer): + args = self.parser.parse_args(["celery", "remove-all-queues", "-H", "celery@host_1"]) + mock_instance = MagicMock() + mock_instance.active_queues.return_value = { + "celery@host_1": [{"name": "queue1"}, {"name": "queue2"}], + "celery@host_2": [{"name": "queue3"}], + } + mock_inspect.return_value = mock_instance + with patch( + "airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None + ): + celery_command.remove_all_queues(args) + # Verify cancel_consumer was called for each queue + expected_calls = [ + mock.call("queue1", destination=["celery@host_1"]), + mock.call("queue2", destination=["celery@host_1"]), + ] + mock_cancel_consumer.assert_has_calls(expected_calls, any_order=True) + assert mock_cancel_consumer.call_count == 2 + @patch("airflow.providers.celery.cli.celery_command.Process") @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Doesn't apply to pre-3.0")