diff --git a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py index 40f8a4301320d..8ff5e953dd0cb 100644 --- a/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py +++ b/providers/edge3/src/airflow/providers/edge3/cli/edge_command.py @@ -23,6 +23,7 @@ import sys from dataclasses import asdict from datetime import datetime +from getpass import getuser from http import HTTPStatus from multiprocessing import Process from pathlib import Path @@ -36,6 +37,7 @@ from airflow import __version__ as airflow_version, settings from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg +from airflow.cli.simple_table import AirflowConsole from airflow.configuration import conf from airflow.providers.edge3 import __version__ as edge_provider_version from airflow.providers.edge3.cli.api_client import ( @@ -508,7 +510,7 @@ def worker(args): @cli_utils.action_cli(check_db=False) @providers_configuration_loaded def status(args): - """Check for Airflow Edge Worker status.""" + """Check for Airflow Local Edge Worker status.""" pid = _get_pid(args.pid) # Send Signal as notification to drop status JSON @@ -534,7 +536,7 @@ def status(args): @cli_utils.action_cli(check_db=False) @providers_configuration_loaded def maintenance(args): - """Set or Unset maintenance mode of worker.""" + """Set or Unset maintenance mode of local edge worker.""" if args.maintenance == "on" and not args.comments: logger.error("Comments are required when setting maintenance mode.") sys.exit(4) @@ -605,7 +607,7 @@ def maintenance(args): @cli_utils.action_cli(check_db=False) @providers_configuration_loaded def stop(args): - """Stop a running Airflow Edge Worker.""" + """Stop a running local Airflow Edge Worker.""" pid = _get_pid(args.pid) # Send SIGINT logger.info("Sending SIGINT to worker pid %i.", pid) @@ -619,6 +621,97 @@ def stop(args): logger.info("Worker has been shut down.") +def _check_valid_db_connection(): + """Check for a valid db connection before executing db dependent cli commands.""" + db_conn = conf.get("database", "sql_alchemy_conn") + db_default = conf.get_default_value("database", "sql_alchemy_conn") + if db_conn == db_default: + raise SystemExit( + "Error: The database connection is not set. Please set the connection in the configuration file." + ) + + +def _check_if_registered_edge_host(hostname: str): + """Check if edge worker is registered with the db before executing dependent cli commands.""" + from airflow.providers.edge3.models.edge_worker import _fetch_edge_hosts_from_db + + if not _fetch_edge_hosts_from_db(hostname=hostname): + raise SystemExit(f"Error: Edge Worker {hostname} is unknown!") + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def list_edge_workers(args) -> None: + """Query the db to list all registered edge workers.""" + _check_valid_db_connection() + from airflow.providers.edge3.models.edge_worker import get_registered_edge_hosts + + all_hosts_iter = get_registered_edge_hosts(states=args.state) + # Format and print worker info on the screen + fields = [ + "worker_name", + "state", + "queues", + "maintenance_comment", + ] + all_hosts = [{f: host.__getattribute__(f) for f in fields} for host in all_hosts_iter] + AirflowConsole().print_as(data=all_hosts, output=args.output) + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def put_remote_worker_on_maintenance(args) -> None: + """Put remote edge worker on maintenance.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import request_maintenance + + request_maintenance(args.edge_hostname, args.comments) + logger.info("%s has been put on maintenance by %s.", args.edge_hostname, getuser()) + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def remove_remote_worker_from_maintenance(args) -> None: + """Remove remote edge worker from maintenance.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import exit_maintenance + + exit_maintenance(args.edge_hostname) + logger.info("%s has been removed from maintenance by %s.", args.edge_hostname, getuser()) + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def remote_worker_update_maintenance_comment(args) -> None: + """Update maintenance comments of the remote edge worker.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import change_maintenance_comment + + try: + change_maintenance_comment(args.edge_hostname, args.comments) + logger.info("Maintenance comments updated for %s by %s.", args.edge_hostname, getuser()) + except TypeError: + raise SystemExit + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def remove_remote_worker(args) -> None: + """Remove remote edge worker entry from db.""" + _check_valid_db_connection() + _check_if_registered_edge_host(hostname=args.edge_hostname) + from airflow.providers.edge3.models.edge_worker import remove_worker + + try: + remove_worker(args.edge_hostname) + logger.info("Edge Worker host %s removed by %s.", args.edge_hostname, getuser()) + except TypeError: + raise SystemExit + + ARG_CONCURRENCY = Arg( ("-c", "--concurrency"), type=int, @@ -633,11 +726,21 @@ def stop(args): ("-H", "--edge-hostname"), help="Set the hostname of worker if you have multiple workers on a single machine", ) +ARG_REQUIRED_EDGE_HOSTNAME = Arg( + ("-H", "--edge-hostname"), + help="Set the hostname of worker if you have multiple workers on a single machine", + required=True, +) ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off")) ARG_MAINTENANCE_COMMENT = Arg( ("-c", "--comments"), help="Maintenance comments to report reason. Required if maintenance is turned on.", ) +ARG_REQUIRED_MAINTENANCE_COMMENT = Arg( + ("-c", "--comments"), + help="Maintenance comments to report reason. Required if enabling maintenance", + required=True, +) ARG_WAIT_MAINT = Arg( ("-w", "--wait"), default=False, @@ -650,6 +753,25 @@ def stop(args): help="Wait until edge worker is shut down.", action="store_true", ) +ARG_OUTPUT = Arg( + ( + "-o", + "--output", + ), + help="Output format. Allowed values: json, yaml, plain, table (default: table)", + metavar="(table, json, yaml, plain)", + choices=("table", "json", "yaml", "plain"), + default="table", +) +ARG_STATE = Arg( + ( + "-s", + "--state", + ), + nargs="+", + help="State of the edge worker", +) + EDGE_COMMANDS: list[ActionCommand] = [ ActionCommand( name=worker.__name__, @@ -694,4 +816,43 @@ def stop(args): ARG_VERBOSE, ), ), + ActionCommand( + name="list-workers", + help=list_edge_workers.__doc__, + func=list_edge_workers, + args=( + ARG_OUTPUT, + ARG_STATE, + ), + ), + ActionCommand( + name="remote-edge-worker-request-maintenance", + help=put_remote_worker_on_maintenance.__doc__, + func=put_remote_worker_on_maintenance, + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_REQUIRED_MAINTENANCE_COMMENT, + ), + ), + ActionCommand( + name="remote-edge-worker-exit-maintenance", + help=remove_remote_worker_from_maintenance.__doc__, + func=remove_remote_worker_from_maintenance, + args=(ARG_REQUIRED_EDGE_HOSTNAME,), + ), + ActionCommand( + name="remote-edge-worker-update-maintenance-comment", + help=remote_worker_update_maintenance_comment.__doc__, + func=remote_worker_update_maintenance_comment, + args=( + ARG_REQUIRED_EDGE_HOSTNAME, + ARG_REQUIRED_MAINTENANCE_COMMENT, + ), + ), + ActionCommand( + name="remove-remote-edge-worker", + help=remove_remote_worker.__doc__, + func=remove_remote_worker, + args=(ARG_REQUIRED_EDGE_HOSTNAME,), + ), ] diff --git a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py index f8b29a7863092..75a1eac3cb87d 100644 --- a/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py +++ b/providers/edge3/src/airflow/providers/edge3/models/edge_worker.py @@ -18,6 +18,7 @@ import ast import json +import logging from datetime import datetime from enum import Enum from typing import TYPE_CHECKING @@ -29,12 +30,15 @@ from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.providers_configuration_loader import providers_configuration_loaded from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.sqlalchemy import UtcDateTime if TYPE_CHECKING: from sqlalchemy.orm import Session +logger = logging.getLogger(__name__) + class EdgeWorkerVersionException(AirflowException): """Signal a version mismatch between core and Edge Site.""" @@ -194,6 +198,26 @@ def reset_metrics(worker_name: str) -> None: ) +@providers_configuration_loaded +@provide_session +def _fetch_edge_hosts_from_db( + hostname: str | None = None, states: list | None = None, session: Session = NEW_SESSION +) -> list: + query = select(EdgeWorkerModel) + if states: + query = query.where(EdgeWorkerModel.state.in_(states)) + if hostname: + query = query.where(EdgeWorkerModel.worker_name == hostname) + query = query.order_by(EdgeWorkerModel.worker_name) + return session.scalars(query).all() + + +@providers_configuration_loaded +@provide_session +def get_registered_edge_hosts(states: list | None = None, session: Session = NEW_SESSION): + return _fetch_edge_hosts_from_db(states=states, session=session) + + @provide_session def request_maintenance( worker_name: str, maintenance_comment: str | None, session: Session = NEW_SESSION @@ -217,7 +241,14 @@ def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None: @provide_session def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None: """Remove a worker that is offline or just gone from DB.""" - session.execute(delete(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)) + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel = session.scalar(query) + if worker.state == EdgeWorkerState.OFFLINE or worker.state == EdgeWorkerState.OFFLINE_MAINTENANCE: + session.execute(delete(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)) + else: + error_message = f"Cannot remove edge worker {worker_name} as it is in {worker.state} state!" + logger.error(error_message) + raise TypeError(error_message) @provide_session @@ -227,4 +258,12 @@ def change_maintenance_comment( """Write maintenance comment in the db.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel = session.scalar(query) - worker.maintenance_comment = maintenance_comment + if ( + worker.state == EdgeWorkerState.MAINTENANCE_MODE + or worker.state == EdgeWorkerState.OFFLINE_MAINTENANCE + ): + worker.maintenance_comment = maintenance_comment + else: + error_message = f"Cannot change maintenance comment as {worker_name} is not in maintenance!" + logger.error(error_message) + raise TypeError(error_message) diff --git a/providers/edge3/tests/unit/edge3/cli/test_edge_command.py b/providers/edge3/tests/unit/edge3/cli/test_edge_command.py index 78ff6daee0049..bcf5e5d12f271 100644 --- a/providers/edge3/tests/unit/edge3/cli/test_edge_command.py +++ b/providers/edge3/tests/unit/edge3/cli/test_edge_command.py @@ -16,9 +16,14 @@ # under the License. from __future__ import annotations +import argparse +import contextlib +import importlib +import json import logging import os from datetime import datetime +from io import StringIO from pathlib import Path from subprocess import Popen from unittest.mock import MagicMock, call, patch @@ -27,9 +32,19 @@ import time_machine from requests import HTTPError, Response +from airflow.cli import cli_parser +from airflow.executors import executor_loader +from airflow.providers.edge3.cli import edge_command from airflow.providers.edge3.cli.dataclasses import Job -from airflow.providers.edge3.cli.edge_command import _EdgeWorkerCli, _write_pid_to_pidfile -from airflow.providers.edge3.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException +from airflow.providers.edge3.cli.edge_command import ( + _EdgeWorkerCli, + _write_pid_to_pidfile, +) +from airflow.providers.edge3.models.edge_worker import ( + EdgeWorkerModel, + EdgeWorkerState, + EdgeWorkerVersionException, +) from airflow.providers.edge3.worker_api.datamodels import ( EdgeJobFetched, WorkerRegistrationReturn, @@ -118,6 +133,17 @@ def returncode(self): class TestEdgeWorkerCli: + parser: argparse.ArgumentParser + + @classmethod + def setup_class(cls): + with conf_vars( + {("core", "executor"): "airflow.providers.edge3.executors.edge_executor.EdgeExecutor"} + ): + importlib.reload(executor_loader) + importlib.reload(cli_parser) + cls.parser = cli_parser.get_parser() + @pytest.fixture def mock_joblist(self, tmp_path: Path) -> list[Job]: logfile = tmp_path / "file.log" @@ -146,6 +172,15 @@ def worker_with_job(self, tmp_path: Path, mock_joblist: list[Job]) -> _EdgeWorke _EdgeWorkerCli.jobs = mock_joblist return test_worker + @pytest.fixture + def mock_edgeworker(self) -> EdgeWorkerModel: + test_edgeworker = EdgeWorkerModel( + worker_name="test_edge_worker", + state="idle", + queues=["default"], + ) + return test_edgeworker + @patch("airflow.providers.edge3.cli.edge_command.Process") @patch("airflow.providers.edge3.cli.edge_command.logs_logfile_path") @patch("airflow.providers.edge3.cli.edge_command.Popen") @@ -395,3 +430,23 @@ def test_get_sysinfo(self, worker_with_job: _EdgeWorkerCli): assert "edge_provider_version" in sysinfo assert "concurrency" in sysinfo assert sysinfo["concurrency"] == concurrency + + @pytest.mark.db_test + def test_list_edge_workers(self, mock_edgeworker: EdgeWorkerModel): + args = self.parser.parse_args(["edge", "list-workers", "--output", "json"]) + with contextlib.redirect_stdout(StringIO()) as temp_stdout: + with patch( + "airflow.providers.edge3.models.edge_worker.get_registered_edge_hosts", + return_value=[mock_edgeworker], + ): + edge_command.list_edge_workers(args) + out = temp_stdout.getvalue() + edge_workers = json.loads(out) + for key in [ + "worker_name", + "state", + "queues", + "maintenance_comment", + ]: + assert key in edge_workers[0] + assert any("test_edge_worker" in h["worker_name"] for h in edge_workers)