Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,35 @@ def remove_queue(args):
queues = args.queues.split(",")
for queue in queues:
celery_app.control.cancel_consumer(queue, destination=[args.celery_hostname])


@cli_utils.action_cli(check_db=False)
@_providers_configuration_loaded
def remove_all_queues(args):
"""Unsubscribe a Celery worker from all its active queues."""
_check_if_active_celery_worker(hostname=args.celery_hostname)
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

inspect = celery_app.control.inspect()
active_workers = inspect.active_queues()

if not active_workers or args.celery_hostname not in active_workers:
print(f"No active queues found for worker: {args.celery_hostname}")
return

worker_queues = active_workers[args.celery_hostname]
queue_names = [queue["name"] for queue in worker_queues if "name" in queue]

if not queue_names:
print(f"No queues to remove for worker: {args.celery_hostname}")
return

print(
f"Removing {len(queue_names)} queue(s) from worker {args.celery_hostname}: {', '.join(queue_names)}"
)

for queue_name in queue_names:
celery_app.control.cancel_consumer(queue_name, destination=[args.celery_hostname])

print(f"Successfully removed all queues from worker: {args.celery_hostname}")
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,12 @@ def __getattr__(name):
ARG_FULL_CELERY_HOSTNAME,
),
),
ActionCommand(
name="remove-all-queues",
help="Unsubscribe Celery worker from all its active queues",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_all_queues"),
args=(ARG_FULL_CELERY_HOSTNAME,),
),
)


Expand Down
23 changes: 23 additions & 0 deletions providers/celery/tests/unit/celery/cli/test_celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,29 @@ def test_remove_queue(self, mock_cancel_consumer):
celery_command.remove_queue(args)
mock_cancel_consumer.assert_called_once_with("test1", destination=["celery@host_1"])

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.cancel_consumer")
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.inspect")
def test_remove_all_queues(self, mock_inspect, mock_cancel_consumer):
args = self.parser.parse_args(["celery", "remove-all-queues", "-H", "celery@host_1"])
mock_instance = MagicMock()
mock_instance.active_queues.return_value = {
"celery@host_1": [{"name": "queue1"}, {"name": "queue2"}],
"celery@host_2": [{"name": "queue3"}],
}
mock_inspect.return_value = mock_instance
with patch(
"airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
):
celery_command.remove_all_queues(args)
# Verify cancel_consumer was called for each queue
expected_calls = [
mock.call("queue1", destination=["celery@host_1"]),
mock.call("queue2", destination=["celery@host_1"]),
]
mock_cancel_consumer.assert_has_calls(expected_calls, any_order=True)
assert mock_cancel_consumer.call_count == 2


@patch("airflow.providers.celery.cli.celery_command.Process")
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Doesn't apply to pre-3.0")
Expand Down
Loading