Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 164 additions & 3 deletions providers/edge3/src/airflow/providers/edge3/cli/edge_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
from dataclasses import asdict
from datetime import datetime
from getpass import getuser
from http import HTTPStatus
from multiprocessing import Process
from pathlib import Path
Expand All @@ -36,6 +37,7 @@

from airflow import __version__ as airflow_version, settings
from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
from airflow.providers.edge3 import __version__ as edge_provider_version
from airflow.providers.edge3.cli.api_client import (
Expand Down Expand Up @@ -508,7 +510,7 @@ def worker(args):
@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def status(args):
"""Check for Airflow Edge Worker status."""
"""Check for Airflow Local Edge Worker status."""
pid = _get_pid(args.pid)

# Send Signal as notification to drop status JSON
Expand All @@ -534,7 +536,7 @@ def status(args):
@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def maintenance(args):
"""Set or Unset maintenance mode of worker."""
"""Set or Unset maintenance mode of local edge worker."""
if args.maintenance == "on" and not args.comments:
logger.error("Comments are required when setting maintenance mode.")
sys.exit(4)
Expand Down Expand Up @@ -605,7 +607,7 @@ def maintenance(args):
@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def stop(args):
"""Stop a running Airflow Edge Worker."""
"""Stop a running local Airflow Edge Worker."""
pid = _get_pid(args.pid)
# Send SIGINT
logger.info("Sending SIGINT to worker pid %i.", pid)
Expand All @@ -619,6 +621,97 @@ def stop(args):
logger.info("Worker has been shut down.")


def _check_valid_db_connection():
"""Check for a valid db connection before executing db dependent cli commands."""
db_conn = conf.get("database", "sql_alchemy_conn")
db_default = conf.get_default_value("database", "sql_alchemy_conn")
if db_conn == db_default:
raise SystemExit(
"Error: The database connection is not set. Please set the connection in the configuration file."
)


def _check_if_registered_edge_host(hostname: str):
"""Check if edge worker is registered with the db before executing dependent cli commands."""
from airflow.providers.edge3.models.edge_worker import _fetch_edge_hosts_from_db

if not _fetch_edge_hosts_from_db(hostname=hostname):
raise SystemExit(f"Error: Edge Worker {hostname} is unknown!")


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def list_edge_workers(args) -> None:
"""Query the db to list all registered edge workers."""
_check_valid_db_connection()
from airflow.providers.edge3.models.edge_worker import get_registered_edge_hosts

all_hosts_iter = get_registered_edge_hosts(states=args.state)
# Format and print worker info on the screen
fields = [
"worker_name",
"state",
"queues",
"maintenance_comment",
]
all_hosts = [{f: host.__getattribute__(f) for f in fields} for host in all_hosts_iter]
AirflowConsole().print_as(data=all_hosts, output=args.output)


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def put_remote_worker_on_maintenance(args) -> None:
"""Put remote edge worker on maintenance."""
_check_valid_db_connection()
_check_if_registered_edge_host(hostname=args.edge_hostname)
from airflow.providers.edge3.models.edge_worker import request_maintenance

request_maintenance(args.edge_hostname, args.comments)
logger.info("%s has been put on maintenance by %s.", args.edge_hostname, getuser())


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def remove_remote_worker_from_maintenance(args) -> None:
"""Remove remote edge worker from maintenance."""
_check_valid_db_connection()
_check_if_registered_edge_host(hostname=args.edge_hostname)
from airflow.providers.edge3.models.edge_worker import exit_maintenance

exit_maintenance(args.edge_hostname)
logger.info("%s has been removed from maintenance by %s.", args.edge_hostname, getuser())


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def remote_worker_update_maintenance_comment(args) -> None:
"""Update maintenance comments of the remote edge worker."""
_check_valid_db_connection()
_check_if_registered_edge_host(hostname=args.edge_hostname)
from airflow.providers.edge3.models.edge_worker import change_maintenance_comment

try:
change_maintenance_comment(args.edge_hostname, args.comments)
logger.info("Maintenance comments updated for %s by %s.", args.edge_hostname, getuser())
except TypeError:
raise SystemExit


@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def remove_remote_worker(args) -> None:
"""Remove remote edge worker entry from db."""
_check_valid_db_connection()
_check_if_registered_edge_host(hostname=args.edge_hostname)
from airflow.providers.edge3.models.edge_worker import remove_worker

try:
remove_worker(args.edge_hostname)
logger.info("Edge Worker host %s removed by %s.", args.edge_hostname, getuser())
except TypeError:
raise SystemExit


ARG_CONCURRENCY = Arg(
("-c", "--concurrency"),
type=int,
Expand All @@ -633,11 +726,21 @@ def stop(args):
("-H", "--edge-hostname"),
help="Set the hostname of worker if you have multiple workers on a single machine",
)
ARG_REQUIRED_EDGE_HOSTNAME = Arg(
("-H", "--edge-hostname"),
help="Set the hostname of worker if you have multiple workers on a single machine",
required=True,
)
ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off"))
ARG_MAINTENANCE_COMMENT = Arg(
("-c", "--comments"),
help="Maintenance comments to report reason. Required if maintenance is turned on.",
)
ARG_REQUIRED_MAINTENANCE_COMMENT = Arg(
("-c", "--comments"),
help="Maintenance comments to report reason. Required if enabling maintenance",
required=True,
)
ARG_WAIT_MAINT = Arg(
("-w", "--wait"),
default=False,
Expand All @@ -650,6 +753,25 @@ def stop(args):
help="Wait until edge worker is shut down.",
action="store_true",
)
ARG_OUTPUT = Arg(
(
"-o",
"--output",
),
help="Output format. Allowed values: json, yaml, plain, table (default: table)",
metavar="(table, json, yaml, plain)",
choices=("table", "json", "yaml", "plain"),
default="table",
)
ARG_STATE = Arg(
(
"-s",
"--state",
),
nargs="+",
help="State of the edge worker",
)

EDGE_COMMANDS: list[ActionCommand] = [
ActionCommand(
name=worker.__name__,
Expand Down Expand Up @@ -694,4 +816,43 @@ def stop(args):
ARG_VERBOSE,
),
),
ActionCommand(
name="list-workers",
help=list_edge_workers.__doc__,
func=list_edge_workers,
args=(
ARG_OUTPUT,
ARG_STATE,
),
),
ActionCommand(
name="remote-edge-worker-request-maintenance",
help=put_remote_worker_on_maintenance.__doc__,
func=put_remote_worker_on_maintenance,
args=(
ARG_REQUIRED_EDGE_HOSTNAME,
ARG_REQUIRED_MAINTENANCE_COMMENT,
),
),
ActionCommand(
name="remote-edge-worker-exit-maintenance",
help=remove_remote_worker_from_maintenance.__doc__,
func=remove_remote_worker_from_maintenance,
args=(ARG_REQUIRED_EDGE_HOSTNAME,),
),
ActionCommand(
name="remote-edge-worker-update-maintenance-comment",
help=remote_worker_update_maintenance_comment.__doc__,
func=remote_worker_update_maintenance_comment,
args=(
ARG_REQUIRED_EDGE_HOSTNAME,
ARG_REQUIRED_MAINTENANCE_COMMENT,
),
),
ActionCommand(
name="remove-remote-edge-worker",
help=remove_remote_worker.__doc__,
func=remove_remote_worker,
args=(ARG_REQUIRED_EDGE_HOSTNAME,),
),
]
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import ast
import json
import logging
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING
Expand All @@ -29,12 +30,15 @@
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime

if TYPE_CHECKING:
from sqlalchemy.orm import Session

logger = logging.getLogger(__name__)


class EdgeWorkerVersionException(AirflowException):
"""Signal a version mismatch between core and Edge Site."""
Expand Down Expand Up @@ -194,6 +198,26 @@ def reset_metrics(worker_name: str) -> None:
)


@providers_configuration_loaded
@provide_session
def _fetch_edge_hosts_from_db(
hostname: str | None = None, states: list | None = None, session: Session = NEW_SESSION
) -> list:
query = select(EdgeWorkerModel)
if states:
query = query.where(EdgeWorkerModel.state.in_(states))
if hostname:
query = query.where(EdgeWorkerModel.worker_name == hostname)
query = query.order_by(EdgeWorkerModel.worker_name)
return session.scalars(query).all()


@providers_configuration_loaded
@provide_session
def get_registered_edge_hosts(states: list | None = None, session: Session = NEW_SESSION):
return _fetch_edge_hosts_from_db(states=states, session=session)


@provide_session
def request_maintenance(
worker_name: str, maintenance_comment: str | None, session: Session = NEW_SESSION
Expand All @@ -217,7 +241,14 @@ def exit_maintenance(worker_name: str, session: Session = NEW_SESSION) -> None:
@provide_session
def remove_worker(worker_name: str, session: Session = NEW_SESSION) -> None:
"""Remove a worker that is offline or just gone from DB."""
session.execute(delete(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name))
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
if worker.state == EdgeWorkerState.OFFLINE or worker.state == EdgeWorkerState.OFFLINE_MAINTENANCE:
session.execute(delete(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name))
else:
error_message = f"Cannot remove edge worker {worker_name} as it is in {worker.state} state!"
logger.error(error_message)
raise TypeError(error_message)


@provide_session
Expand All @@ -227,4 +258,12 @@ def change_maintenance_comment(
"""Write maintenance comment in the db."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
worker.maintenance_comment = maintenance_comment
if (
worker.state == EdgeWorkerState.MAINTENANCE_MODE
or worker.state == EdgeWorkerState.OFFLINE_MAINTENANCE
):
worker.maintenance_comment = maintenance_comment
else:
error_message = f"Cannot change maintenance comment as {worker_name} is not in maintenance!"
logger.error(error_message)
raise TypeError(error_message)
Loading