@@ -33,6 +33,7 @@
from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile

from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS
@@ -305,3 +306,93 @@ def stop_worker(args):

# Remove pid file
remove_existing_pidfile(pid_file_path)


@_providers_configuration_loaded
def _check_if_active_celery_worker(hostname: str):
"""Check if celery worker is active before executing dependent cli commands."""
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

inspect = celery_app.control.inspect()
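# active_queues() returns a mapping of worker hostname -> consumed queues, or None when no workers reply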
active_workers = inspect.active_queues()
if not active_workers:
raise SystemExit("Error: No active Celery workers found!")
if hostname not in active_workers:
raise SystemExit(f"Error: {hostname} is unknown!")


@cli_utils.action_cli
@_providers_configuration_loaded
def list_workers(args):
"""List all active celery workers."""
workers = []
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

inspect = celery_app.control.inspect()
active_workers = inspect.active_queues()
if active_workers:
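# Flatten the inspect payload into {worker_name, queues} rows for AirflowConsole rendering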
workers = [
{
"worker_name": worker,
"queues": [queue["name"] for queue in active_workers[worker] if "name" in queue],
}
for worker in active_workers
]
AirflowConsole().print_as(data=workers, output=args.output)


@cli_utils.action_cli
@_providers_configuration_loaded
def shutdown_worker(args):
"""Request graceful shutdown of a celery worker."""
_check_if_active_celery_worker(hostname=args.celery_hostname)
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

celery_app.control.shutdown(destination=[args.celery_hostname])
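# Note: this is a warm shutdown - the worker finishes its currently running tasks before exiting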


@cli_utils.action_cli
@_providers_configuration_loaded
def shutdown_all_workers(args):
"""Request graceful shutdown all celery workers."""
if not (
args.yes
or input(
"This will shutdown all active celery workers connected to the celery broker, this cannot be undone! Proceed? (y/n)"
).upper()
== "Y"
):
raise SystemExit("Cancelled")
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

celery_app.control.broadcast("shutdown")


@cli_utils.action_cli
@_providers_configuration_loaded
def add_queue(args):
"""Subscribe a Celery worker to specified queues."""
_check_if_active_celery_worker(hostname=args.celery_hostname)
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

queues = args.queues.split(",")
for queue in queues:
celery_app.control.add_consumer(queue, destination=[args.celery_hostname])


@cli_utils.action_cli
@_providers_configuration_loaded
def remove_queue(args):
"""Unsubscribe a Celery worker from specified queues."""
_check_if_active_celery_worker(hostname=args.celery_hostname)
# This needs to be imported locally to not trigger Providers Manager initialization
from airflow.providers.celery.executors.celery_executor import app as celery_app

queues = args.queues.split(",")
for queue in queues:
celery_app.control.cancel_consumer(queue, destination=[args.celery_hostname])
@@ -157,6 +157,32 @@ def __getattr__(name):
help="Don't subscribe to other workers events",
action="store_true",
)
ARG_OUTPUT = Arg(
(
"-o",
"--output",
),
help="Output format. Allowed values: json, yaml, plain, table (default: table)",
metavar="(table, json, yaml, plain)",
choices=("table", "json", "yaml", "plain"),
default="table",
)
ARG_FULL_CELERY_HOSTNAME = Arg(
("-H", "--celery-hostname"),
required=True,
help="Specify the full celery hostname. example: celery@hostname",
)
ARG_REQUIRED_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve",
required=True,
)
ARG_YES = Arg(
("-y", "--yes"),
help="Do not prompt to confirm. Use with care!",
action="store_true",
default=False,
)

CELERY_CLI_COMMAND_PATH = "airflow.providers.celery.cli.celery_command"

@@ -207,6 +233,42 @@ def __getattr__(name):
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.stop_worker"),
args=(ARG_PID, ARG_VERBOSE),
),
ActionCommand(
name="list-workers",
help="List active celery workers",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.list_workers"),
args=(ARG_OUTPUT,),
),
ActionCommand(
name="shutdown-worker",
help="Request graceful shutdown of celery workers",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_worker"),
args=(ARG_FULL_CELERY_HOSTNAME,),
),
ActionCommand(
name="shutdown-all-workers",
help="Request graceful shutdown of all active celery workers",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.shutdown_all_workers"),
args=(ARG_YES,),
),
ActionCommand(
name="add-queue",
help="Subscribe Celery worker to specified queues",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.add_queue"),
args=(
ARG_REQUIRED_QUEUES,
ARG_FULL_CELERY_HOSTNAME,
),
),
ActionCommand(
name="remove-queue",
help="Unsubscribe Celery worker from specified queues",
func=lazy_load_command(f"{CELERY_CLI_COMMAND_PATH}.remove_queue"),
args=(
ARG_REQUIRED_QUEUES,
ARG_FULL_CELERY_HOSTNAME,
),
),
)
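# Illustrative invocations of the new subcommands (hostnames are placeholders):
#   airflow celery list-workers --output json
#   airflow celery shutdown-worker -H celery@worker-hostname
#   airflow celery shutdown-all-workers -y
#   airflow celery add-queue -q queue1,queue2 -H celery@worker-hostname
#   airflow celery remove-queue -q queue1 -H celery@worker-hostname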


72 changes: 71 additions & 1 deletion providers/celery/tests/unit/celery/cli/test_celery_command.py
@@ -17,10 +17,13 @@
# under the License.
from __future__ import annotations

import contextlib
import importlib
import json
import os
from io import StringIO
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest

@@ -347,6 +350,73 @@ def test_run_command_daemon_v3_above(
self._test_run_command_daemon(mock_celery_app, mock_daemon, mock_setup_locations, mock_pid_file)


class TestRemoteCeleryControlCommands:
@classmethod
def setup_class(cls):
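# Reload the executor loader and CLI parser with CeleryExecutor configured so the celery subcommands are registered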
with conf_vars({("core", "executor"): "CeleryExecutor"}):
importlib.reload(executor_loader)
importlib.reload(cli_parser)
cls.parser = cli_parser.get_parser()

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.inspect")
def test_list_celery_workers(self, mock_inspect):
args = self.parser.parse_args(["celery", "list-workers", "--output", "json"])
mock_instance = MagicMock()
mock_instance.active_queues.return_value = {
"celery@host_1": [{"name": "queue1"}, {"name": "queue2"}],
"celery@host_2": [{"name": "queue3"}],
}
mock_inspect.return_value = mock_instance
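# AirflowConsole prints to stdout, so capture it to parse the JSON rows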
with contextlib.redirect_stdout(StringIO()) as temp_stdout:
celery_command.list_workers(args)
out = temp_stdout.getvalue()
celery_workers = json.loads(out)
for key in ["worker_name", "queues"]:
assert key in celery_workers[0]
assert any("celery@host_1" in h["worker_name"] for h in celery_workers)

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.shutdown")
def test_shutdown_worker(self, mock_shutdown):
args = self.parser.parse_args(["celery", "shutdown-worker", "-H", "celery@host_1"])
with patch(
"airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
):
celery_command.shutdown_worker(args)
mock_shutdown.assert_called_once_with(destination=["celery@host_1"])

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.broadcast")
def test_shutdown_all_workers(self, mock_broadcast):
args = self.parser.parse_args(["celery", "shutdown-all-workers", "-y"])
with patch(
"airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
):
celery_command.shutdown_all_workers(args)
mock_broadcast.assert_called_once_with("shutdown")

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.add_consumer")
def test_add_queue(self, mock_add_consumer):
args = self.parser.parse_args(["celery", "add-queue", "-q", "test1", "-H", "celery@host_1"])
with patch(
"airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
):
celery_command.add_queue(args)
mock_add_consumer.assert_called_once_with("test1", destination=["celery@host_1"])

@pytest.mark.db_test
@mock.patch("airflow.providers.celery.executors.celery_executor.app.control.cancel_consumer")
def test_remove_queue(self, mock_cancel_consumer):
args = self.parser.parse_args(["celery", "remove-queue", "-q", "test1", "-H", "celery@host_1"])
with patch(
"airflow.providers.celery.cli.celery_command._check_if_active_celery_worker", return_value=None
):
celery_command.remove_queue(args)
mock_cancel_consumer.assert_called_once_with("test1", destination=["celery@host_1"])


@patch("airflow.providers.celery.cli.celery_command.Process")
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Doesn't apply to pre-3.0")
def test_stale_bundle_cleanup(mock_process):