diff --git a/providers/edge/README.rst b/providers/edge/README.rst
index d5bbae1780659..48031b6d2ffd6 100644
--- a/providers/edge/README.rst
+++ b/providers/edge/README.rst
@@ -23,7 +23,7 @@
Package ``apache-airflow-providers-edge``
-Release: ``0.18.1pre0``
+Release: ``0.19.0pre0``
Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites
@@ -36,7 +36,7 @@ This is a provider package for ``edge`` provider. All classes for this provider
are in ``airflow.providers.edge`` python package.
You can find package information and changelog for the provider
-in the `documentation `_.
+in the `documentation `_.
Installation
------------
@@ -59,4 +59,4 @@ PIP package Version required
================== ===================
The changelog for the provider package can be found in the
-`changelog `_.
+`changelog `_.
diff --git a/providers/edge/docs/changelog.rst b/providers/edge/docs/changelog.rst
index 4bfcfc7009f67..5b8bef99440ce 100644
--- a/providers/edge/docs/changelog.rst
+++ b/providers/edge/docs/changelog.rst
@@ -27,6 +27,15 @@
Changelog
---------
+0.19.0pre0
+..........
+
+Misc
+~~~~
+
+* ``Edge worker can be set to maintenance via CLI and also return to normal operation.``
+
+
0.18.1pre0
..........
diff --git a/providers/edge/provider.yaml b/providers/edge/provider.yaml
index 18f7d1fca6cc0..562583a9a3687 100644
--- a/providers/edge/provider.yaml
+++ b/providers/edge/provider.yaml
@@ -25,7 +25,7 @@ source-date-epoch: 1737371680
# note that those versions are maintained by release manager - do not update them manually
versions:
- - 0.18.1pre0
+ - 0.19.0pre0
plugins:
- name: edge_executor
diff --git a/providers/edge/pyproject.toml b/providers/edge/pyproject.toml
index 2619f89bec61e..ce404ea9543fa 100644
--- a/providers/edge/pyproject.toml
+++ b/providers/edge/pyproject.toml
@@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi"
[project]
name = "apache-airflow-providers-edge"
-version = "0.18.1pre0"
+version = "0.19.0pre0"
description = "Provider package apache-airflow-providers-edge for Apache Airflow"
readme = "README.rst"
authors = [
@@ -61,8 +61,8 @@ dependencies = [
]
[project.urls]
-"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0"
-"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0/changelog.html"
+"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0"
+"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0/changelog.html"
"Bug Tracker" = "https://github.com/apache/airflow/issues"
"Source Code" = "https://github.com/apache/airflow"
"Slack Chat" = "https://s.apache.org/airflow-slack"
diff --git a/providers/edge/src/airflow/providers/edge/__init__.py b/providers/edge/src/airflow/providers/edge/__init__.py
index 617fe191824cd..5b8039cec8dc1 100644
--- a/providers/edge/src/airflow/providers/edge/__init__.py
+++ b/providers/edge/src/airflow/providers/edge/__init__.py
@@ -29,7 +29,7 @@
__all__ = ["__version__"]
-__version__ = "0.18.1pre0"
+__version__ = "0.19.0pre0"
if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse(
"2.10.0"
diff --git a/providers/edge/src/airflow/providers/edge/cli/api_client.py b/providers/edge/src/airflow/providers/edge/cli/api_client.py
index c19504787d75c..6c9266890251c 100644
--- a/providers/edge/src/airflow/providers/edge/cli/api_client.py
+++ b/providers/edge/src/airflow/providers/edge/cli/api_client.py
@@ -109,7 +109,12 @@ def worker_register(
def worker_set_state(
- hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict
+ hostname: str,
+ state: EdgeWorkerState,
+ jobs_active: int,
+ queues: list[str] | None,
+ sysinfo: dict,
+ maintenance_comments: str | None = None,
) -> WorkerSetStateReturn:
"""Update the state of the worker in the central site and thereby implicitly heartbeat."""
try:
@@ -117,7 +122,11 @@ def worker_set_state(
"PATCH",
f"worker/{quote(hostname)}",
WorkerStateBody(
- state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo
+ state=state,
+ jobs_active=jobs_active,
+ queues=queues,
+ sysinfo=sysinfo,
+ maintenance_comments=maintenance_comments,
).model_dump_json(exclude_unset=True),
)
except requests.HTTPError as e:
diff --git a/providers/edge/src/airflow/providers/edge/cli/dataclasses.py b/providers/edge/src/airflow/providers/edge/cli/dataclasses.py
new file mode 100644
index 0000000000000..3572c550de84f
--- /dev/null
+++ b/providers/edge/src/airflow/providers/edge/cli/dataclasses.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from dataclasses import asdict, dataclass
+from multiprocessing import Process
+from pathlib import Path
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from psutil import Popen
+
+ from airflow.providers.edge.models.edge_worker import EdgeWorkerState
+ from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched
+
+
+@dataclass
+class MaintenanceMarker:
+ """Maintenance mode status."""
+
+ maintenance: str
+ comments: str | None
+
+ @property
+ def json(self) -> str:
+ """Get the maintenance status as JSON."""
+ return json.dumps(asdict(self))
+
+ @staticmethod
+ def from_json(json_str: str) -> MaintenanceMarker:
+ """Create a Maintenance object from JSON."""
+ return MaintenanceMarker(**json.loads(json_str))
+
+
+@dataclass
+class WorkerStatus:
+ """Status of the worker."""
+
+ job_count: int
+ jobs: list
+ state: EdgeWorkerState
+ maintenance: bool
+ maintenance_comments: str | None
+ drain: bool
+
+ @property
+ def json(self) -> str:
+ """Get the status as JSON."""
+ return json.dumps(asdict(self))
+
+ @staticmethod
+ def from_json(json_str: str) -> WorkerStatus:
+ """Create a WorkerStatus object from JSON."""
+ return WorkerStatus(**json.loads(json_str))
+
+
+@dataclass
+class Job:
+ """Holds all information for a task/job to be executed as bundle."""
+
+ edge_job: EdgeJobFetched
+ process: Popen | Process
+ logfile: Path
+ logsize: int
+ """Last size of log file, point of last chunk push."""
+
+ @property
+ def is_running(self) -> bool:
+ """Check if the job is still running."""
+ if hasattr(self.process, "returncode") and hasattr(self.process, "poll"):
+ self.process.poll()
+ return self.process.returncode is None
+ return self.process.exitcode is None
+
+ @property
+ def is_success(self) -> bool:
+ """Check if the job was successful."""
+ if hasattr(self.process, "returncode"):
+ return self.process.returncode == 0
+ return self.process.exitcode == 0
diff --git a/providers/edge/src/airflow/providers/edge/cli/edge_command.py b/providers/edge/src/airflow/providers/edge/cli/edge_command.py
index 9ef6ed755b3f7..f10f7daf8d678 100644
--- a/providers/edge/src/airflow/providers/edge/cli/edge_command.py
+++ b/providers/edge/src/airflow/providers/edge/cli/edge_command.py
@@ -22,7 +22,7 @@
import platform
import signal
import sys
-from dataclasses import dataclass
+from dataclasses import asdict
from datetime import datetime
from http import HTTPStatus
from multiprocessing import Process
@@ -47,6 +47,7 @@
worker_register,
worker_set_state,
)
+from airflow.providers.edge.cli.dataclasses import Job, MaintenanceMarker, WorkerStatus
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
from airflow.providers.edge.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import cli as cli_utils, timezone
@@ -115,10 +116,22 @@ def _pid_file_path(pid_file: str | None) -> str:
return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0]
+def _get_pid(pid_file: str | None) -> int:
+ pid = read_pid_from_pidfile(_pid_file_path(pid_file))
+ if not pid:
+ logger.warning("Could not find PID of worker.")
+ sys.exit(1)
+ return pid
+
+
def _status_file_path(pid_file: str | None) -> str:
return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[1]
+def _maintenance_marker_file_path(pid_file: str | None) -> str:
+ return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[1][:-4] + ".in"
+
+
def _write_pid_to_pidfile(pid_file_path: str):
"""Write PIDs for Edge Workers to disk, handling existing PID files."""
if Path(pid_file_path).exists():
@@ -143,36 +156,10 @@ def _write_pid_to_pidfile(pid_file_path: str):
write_pid_to_pidfile(pid_file_path)
-@dataclass
-class _Job:
- """Holds all information for a task/job to be executed as bundle."""
-
- edge_job: EdgeJobFetched
- process: Popen | Process
- logfile: Path
- logsize: int
- """Last size of log file, point of last chunk push."""
-
- @property
- def is_running(self) -> bool:
- """Check if the job is still running."""
- if isinstance(self.process, Popen):
- self.process.poll()
- return self.process.returncode is None
- return self.process.exitcode is None
-
- @property
- def is_success(self) -> bool:
- """Check if the job was successful."""
- if isinstance(self.process, Popen):
- return self.process.returncode == 0
- return self.process.exitcode == 0
-
-
class _EdgeWorkerCli:
"""Runner instance which executes the Edge Worker."""
- jobs: list[_Job] = []
+ jobs: list[Job] = []
"""List of jobs that the worker is running currently."""
last_hb: datetime | None = None
"""Timestamp of last heart beat sent to server."""
@@ -180,6 +167,11 @@ class _EdgeWorkerCli:
"""Flag if job processing should be completed and no new jobs fetched for a graceful stop/shutdown."""
maintenance_mode: bool = False
"""Flag if job processing should be completed and no new jobs fetched for maintenance mode. """
+ maintenance_comments: str | None = None
+ """Comments for maintenance mode."""
+
+ edge_instance: _EdgeWorkerCli | None = None
+ """Singleton instance of the worker."""
def __init__(
self,
@@ -198,21 +190,35 @@ def __init__(
self.concurrency = concurrency
self.free_concurrency = concurrency
+ _EdgeWorkerCli.edge_instance = self
+
@staticmethod
def signal_handler(sig: signal.Signals, frame):
if sig == SIG_STATUS:
- logger.info("Request to get status of Edge Worker received.")
+ marker_path = Path(_maintenance_marker_file_path(None))
+ if marker_path.exists():
+ request = MaintenanceMarker.from_json(marker_path.read_text())
+ logger.info("Requested to set maintenance mode to %s", request.maintenance)
+ _EdgeWorkerCli.maintenance_mode = request.maintenance == "on"
+ if _EdgeWorkerCli.maintenance_mode and request.comments:
+ logger.info("Comments: %s", request.comments)
+ _EdgeWorkerCli.maintenance_comments = request.comments
+ marker_path.unlink()
+ # send heartbeat immediately to update state
+ if _EdgeWorkerCli.edge_instance:
+ _EdgeWorkerCli.edge_instance.heartbeat(_EdgeWorkerCli.maintenance_comments)
+ else:
+ logger.info("Request to get status of Edge Worker received.")
status_path = Path(_status_file_path(None))
status_path.write_text(
- json.dumps(
- {
- "job_count": len(_EdgeWorkerCli.jobs),
- "jobs": [job.edge_job.key for job in _EdgeWorkerCli.jobs],
- "state": _EdgeWorkerCli._get_state(),
- "maintenance": _EdgeWorkerCli.maintenance_mode,
- "drain": _EdgeWorkerCli.drain,
- }
- )
+ WorkerStatus(
+ job_count=len(_EdgeWorkerCli.jobs),
+ jobs=[job.edge_job.key for job in _EdgeWorkerCli.jobs],
+ state=_EdgeWorkerCli._get_state(),
+ maintenance=_EdgeWorkerCli.maintenance_mode,
+ maintenance_comments=_EdgeWorkerCli.maintenance_comments,
+ drain=_EdgeWorkerCli.drain,
+ ).json
)
else:
logger.info("Request to shut down Edge Worker received, waiting for jobs to complete.")
@@ -316,7 +322,7 @@ def _launch_job(self, edge_job: EdgeJobFetched):
else:
# Airflow 2.10
process, logfile = self._launch_job_af2_10(edge_job)
- _EdgeWorkerCli.jobs.append(_Job(edge_job, process, logfile, 0))
+ _EdgeWorkerCli.jobs.append(Job(edge_job, process, logfile, 0))
def start(self):
"""Start the execution in a loop until terminated."""
@@ -432,14 +438,19 @@ def check_running_jobs(self) -> None:
self.free_concurrency = self.concurrency - used_concurrency
- def heartbeat(self) -> bool:
+ def heartbeat(self, new_maintenance_comments: str | None = None) -> bool:
"""Report liveness state of worker to central site with stats."""
state = _EdgeWorkerCli._get_state()
sysinfo = self._get_sysinfo()
worker_state_changed: bool = False
try:
worker_info = worker_set_state(
- self.hostname, state, len(_EdgeWorkerCli.jobs), self.queues, sysinfo
+ self.hostname,
+ state,
+ len(_EdgeWorkerCli.jobs),
+ self.queues,
+ sysinfo,
+ new_maintenance_comments,
)
self.queues = worker_info.queues
if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST:
@@ -451,6 +462,11 @@ def heartbeat(self) -> bool:
):
logger.info("Exit Maintenance mode requested!")
_EdgeWorkerCli.maintenance_mode = False
+ if _EdgeWorkerCli.maintenance_mode:
+ _EdgeWorkerCli.maintenance_comments = worker_info.maintenance_comments
+ else:
+ _EdgeWorkerCli.maintenance_comments = None
+
worker_state_changed = worker_info.state != state
except EdgeWorkerVersionException:
logger.info("Version mismatch of Edge worker and Core. Shutting down worker.")
@@ -488,45 +504,108 @@ def worker(args):
@providers_configuration_loaded
def status(args):
"""Check for Airflow Edge Worker status."""
- pid = read_pid_from_pidfile(_pid_file_path(args.pid))
- # Send SIGINT
- if pid:
- logger.debug("Sending SIGUSR2 to worker pid %i.", pid)
- status_min_date = time() - 1
- status_path = Path(_status_file_path(args.pid))
- worker_process = psutil.Process(pid)
- worker_process.send_signal(SIG_STATUS)
- while psutil.pid_exists(pid) and (
- not status_path.exists() or status_path.stat().st_mtime < status_min_date
- ):
- sleep(0.1)
- if not psutil.pid_exists(pid):
- logger.warning("PID of worker dis-appeared while checking for status.")
- sys.exit(2)
- if not status_path.exists() or status_path.stat().st_mtime < status_min_date:
- logger.warning("Could not read status of worker.")
- sys.exit(3)
- status = json.loads(status_path.read_text())
- print(json.dumps(status, indent=4))
+ pid = _get_pid(args.pid)
+
+ # Send Signal as notification to drop status JSON
+ logger.debug("Sending SIGUSR2 to worker pid %i.", pid)
+ status_min_date = time() - 1
+ status_path = Path(_status_file_path(args.pid))
+ worker_process = psutil.Process(pid)
+ worker_process.send_signal(SIG_STATUS)
+ while psutil.pid_exists(pid) and (
+ not status_path.exists() or status_path.stat().st_mtime < status_min_date
+ ):
+ sleep(0.1)
+ if not psutil.pid_exists(pid):
+ logger.warning("PID of worker dis-appeared while checking for status.")
+ sys.exit(2)
+ if not status_path.exists() or status_path.stat().st_mtime < status_min_date:
+ logger.warning("Could not read status of worker.")
+ sys.exit(3)
+ status = WorkerStatus.from_json(status_path.read_text())
+ print(json.dumps(asdict(status), indent=4))
- else:
- logger.warning("Could not find PID of worker.")
- sys.exit(1)
+
+@cli_utils.action_cli(check_db=False)
+@providers_configuration_loaded
+def maintenance(args):
+ """Set or Unset maintenance mode of worker."""
+ if args.maintenance == "on" and not args.comments:
+ logger.error("Comments are required when setting maintenance mode.")
+ sys.exit(4)
+
+ pid = _get_pid(args.pid)
+
+ # Write marker JSON file
+ from getpass import getuser
+
+ marker_path = Path(_maintenance_marker_file_path(args.pid))
+ logger.debug("Writing maintenance marker file to %s.", marker_path)
+ marker_path.write_text(
+ MaintenanceMarker(
+ maintenance=args.maintenance,
+ comments=f'[{datetime.now().strftime("%Y-%m-%d %H:%M")}] - {getuser()} put '
+ f'node into maintenance mode via cli\nComment: {args.comments}'
+ if args.maintenance == "on"
+ else None,
+ ).json
+ )
+
+ # Send Signal as notification to fetch maintenance marker
+ logger.debug("Sending SIGUSR2 to worker pid %i.", pid)
+ status_min_date = time() - 1
+ status_path = Path(_status_file_path(args.pid))
+ worker_process = psutil.Process(pid)
+ worker_process.send_signal(SIG_STATUS)
+ while psutil.pid_exists(pid) and (
+ not status_path.exists() or status_path.stat().st_mtime < status_min_date
+ ):
+ sleep(0.1)
+ if not psutil.pid_exists(pid):
+ logger.warning("PID of worker dis-appeared while checking for status.")
+ sys.exit(2)
+ if not status_path.exists() or status_path.stat().st_mtime < status_min_date:
+ logger.warning("Could not read status of worker.")
+ sys.exit(3)
+ status = WorkerStatus.from_json(status_path.read_text())
+
+ if args.wait:
+ if args.maintenance == "on" and status.state != EdgeWorkerState.MAINTENANCE_MODE:
+ logger.info("Waiting for worker to be drained...")
+ while True:
+ sleep(4.5)
+ worker_process.send_signal(SIG_STATUS)
+ sleep(0.5)
+ status = WorkerStatus.from_json(status_path.read_text())
+ if status.state == EdgeWorkerState.MAINTENANCE_MODE:
+ logger.info("Worker was drained successfully!")
+ break
+ if status.state not in [
+ EdgeWorkerState.MAINTENANCE_REQUEST,
+ EdgeWorkerState.MAINTENANCE_PENDING,
+ ]:
+ logger.info("Worker maintenance was exited by someone else!")
+ break
+ if args.maintenance == "off" and status.state == EdgeWorkerState.MAINTENANCE_MODE:
+ logger.info("Waiting for worker to exit maintenance...")
+ while status.state in [EdgeWorkerState.MAINTENANCE_MODE, EdgeWorkerState.MAINTENANCE_EXIT]:
+ sleep(4.5)
+ worker_process.send_signal(SIG_STATUS)
+ sleep(0.5)
+ status = WorkerStatus.from_json(status_path.read_text())
+
+ print(json.dumps(asdict(status), indent=4))
@cli_utils.action_cli(check_db=False)
@providers_configuration_loaded
def stop(args):
"""Stop a running Airflow Edge Worker."""
- pid = read_pid_from_pidfile(_pid_file_path(args.pid))
+ pid = _get_pid(args.pid)
# Send SIGINT
- if pid:
- logger.info("Sending SIGINT to worker pid %i.", pid)
- worker_process = psutil.Process(pid)
- worker_process.send_signal(signal.SIGINT)
- else:
- logger.warning("Could not find PID of worker.")
- sys.exit(1)
+ logger.info("Sending SIGINT to worker pid %i.", pid)
+ worker_process = psutil.Process(pid)
+ worker_process.send_signal(signal.SIGINT)
if args.wait:
logger.info("Waiting for worker to stop...")
@@ -549,7 +628,18 @@ def stop(args):
("-H", "--edge-hostname"),
help="Set the hostname of worker if you have multiple workers on a single machine",
)
-ARG_WAIT = Arg(
+ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off"))
+ARG_MAINTENANCE_COMMENT = Arg(
+ ("-c", "--comments"),
+ help="Maintenance comments to report reason. Required if maintenance is turned on.",
+)
+ARG_WAIT_MAINT = Arg(
+ ("-w", "--wait"),
+ default=False,
+ help="Wait until edge worker has reached desired state.",
+ action="store_true",
+)
+ARG_WAIT_STOP = Arg(
("-w", "--wait"),
default=False,
help="Wait until edge worker is shut down.",
@@ -577,12 +667,24 @@ def stop(args):
ARG_VERBOSE,
),
),
+ ActionCommand(
+ name=maintenance.__name__,
+ help=maintenance.__doc__,
+ func=maintenance,
+ args=(
+ ARG_MAINTENANCE,
+ ARG_MAINTENANCE_COMMENT,
+ ARG_WAIT_MAINT,
+ ARG_PID,
+ ARG_VERBOSE,
+ ),
+ ),
ActionCommand(
name=stop.__name__,
help=stop.__doc__,
func=stop,
args=(
- ARG_WAIT,
+ ARG_WAIT_STOP,
ARG_PID,
ARG_VERBOSE,
),
diff --git a/providers/edge/src/airflow/providers/edge/get_provider_info.py b/providers/edge/src/airflow/providers/edge/get_provider_info.py
index 00f7247d6b55b..c646666d30186 100644
--- a/providers/edge/src/airflow/providers/edge/get_provider_info.py
+++ b/providers/edge/src/airflow/providers/edge/get_provider_info.py
@@ -28,7 +28,7 @@ def get_provider_info():
"description": "Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites\n",
"state": "not-ready",
"source-date-epoch": 1737371680,
- "versions": ["0.18.1pre0"],
+ "versions": ["0.19.0pre0"],
"plugins": [
{
"name": "edge_executor",
diff --git a/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
index 7141f3b9b0a32..1db7fc557deec 100644
--- a/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
+++ b/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml
@@ -617,6 +617,13 @@ components:
description: System information of the worker.
title: Sysinfo
type: object
+ maintenance_comments:
+ description: Comments about the maintenance state of the worker.
+ title: Maintenance Comments
+ anyOf:
+ - type: string
+ - type: object
+ nullable: true
title: WorkerStateBody
WorkerQueuesBody:
description: Queues that a worker supports to run jobs on.
@@ -656,6 +663,13 @@ components:
state:
$ref: '#/components/schemas/EdgeWorkerState'
description: State of the worker from the view of the worker.
+ maintenance_comments:
+ description: Comments about the maintenance state of the worker.
+ title: Maintenance Comments
+ anyOf:
+ - type: string
+ - type: object
+ nullable: true
title: WorkerSetStateReturn
EdgeJobFetched:
description: Job that is to be executed on the edge worker.
diff --git a/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py b/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py
index 888367a680d82..6edef1f3e48b1 100644
--- a/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py
+++ b/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py
@@ -142,6 +142,10 @@ class WorkerStateBody(WorkerQueuesBase):
],
),
]
+ maintenance_comments: Annotated[
+ str | None,
+ Field(description="Comments about the maintenance state of the worker."),
+ ] = None
class WorkerQueueUpdateBody(BaseModel):
@@ -174,3 +178,7 @@ class WorkerSetStateReturn(BaseModel):
description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues."
),
]
+ maintenance_comments: Annotated[
+ str | None,
+ Field(description="Comments about the maintenance state of the worker."),
+ ] = None
diff --git a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
index ae50438a2345b..32423ddfea765 100644
--- a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
+++ b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py
@@ -152,6 +152,7 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) ->
jobs_active=body["jobs_active"],
queues=body["queues"],
sysinfo=body["sysinfo"],
+ maintenance_comments=body.get("maintenance_comments"),
)
return set_state(worker_name, request_obj, session).model_dump()
except HTTPException as e:
diff --git a/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py b/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py
index 602af8b0b2977..6c68f9b29aea1 100644
--- a/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py
+++ b/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py
@@ -113,9 +113,7 @@ def _assert_version(sysinfo: dict[str, str | int]) -> None:
)
-def redefine_state_if_maintenance(
- worker_state: EdgeWorkerState, body_state: EdgeWorkerState
-) -> EdgeWorkerState:
+def redefine_state(worker_state: EdgeWorkerState, body_state: EdgeWorkerState) -> EdgeWorkerState:
"""Redefine the state of the worker based on maintenance request."""
if (
worker_state == EdgeWorkerState.MAINTENANCE_REQUEST
@@ -138,6 +136,20 @@ def redefine_state_if_maintenance(
return body_state
+def redefine_maintenance_comments(
+ worker_maintenance_comment: str | None, body_maintenance_comments: str | None
+) -> str | None:
+ """Add new maintenance comments or overwrite the old ones if it is too long."""
+ if body_maintenance_comments:
+ if (
+ worker_maintenance_comment
+ and len(body_maintenance_comments) + len(worker_maintenance_comment) < 1020
+ ):
+ return f"{worker_maintenance_comment}\n\n{body_maintenance_comments}"
+ return body_maintenance_comments
+ return worker_maintenance_comment
+
+
@worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)])
def register(
worker_name: Annotated[str, _worker_name_doc],
@@ -150,7 +162,10 @@ def register(
worker: EdgeWorkerModel = session.scalar(query)
if not worker:
worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues)
- worker.state = redefine_state_if_maintenance(worker.state, body.state)
+ worker.state = redefine_state(worker.state, body.state)
+ worker.maintenance_comment = redefine_maintenance_comments(
+ worker.maintenance_comment, body.maintenance_comments
+ )
worker.queues = body.queues
worker.sysinfo = json.dumps(body.sysinfo)
worker.last_update = timezone.utcnow()
@@ -167,7 +182,10 @@ def set_state(
"""Set state of worker and returns the current assigned queues."""
query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name)
worker: EdgeWorkerModel = session.scalar(query)
- worker.state = redefine_state_if_maintenance(worker.state, body.state)
+ worker.state = redefine_state(worker.state, body.state)
+ worker.maintenance_comment = redefine_maintenance_comments(
+ worker.maintenance_comment, body.maintenance_comments
+ )
worker.jobs_active = body.jobs_active
worker.sysinfo = json.dumps(body.sysinfo)
worker.last_update = timezone.utcnow()
@@ -183,7 +201,9 @@ def set_state(
queues=worker.queues,
)
_assert_version(body.sysinfo) # Exception only after worker state is in the DB
- return WorkerSetStateReturn(state=worker.state, queues=worker.queues)
+ return WorkerSetStateReturn(
+ state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment
+ )
@worker_router.patch(
diff --git a/providers/edge/tests/unit/edge/cli/test_dataclasses.py b/providers/edge/tests/unit/edge/cli/test_dataclasses.py
new file mode 100644
index 0000000000000..3bc68d26b5ed8
--- /dev/null
+++ b/providers/edge/tests/unit/edge/cli/test_dataclasses.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.edge.cli.dataclasses import MaintenanceMarker, WorkerStatus
+from airflow.providers.edge.models.edge_worker import EdgeWorkerState
+
+MOCK_ENDPOINT = "https://invalid-api-test-endpoint"
+
+
+class TestMaintenanceMarker:
+ def test_maintenance_marker_json(self):
+ marker = MaintenanceMarker(maintenance="maintenance", comments="comments")
+ assert marker == MaintenanceMarker.from_json(marker.json)
+
+
+class TestWorkerStatus:
+ def test_worker_status_json(self):
+ status = WorkerStatus(
+ job_count=1,
+ jobs=[],
+ state=EdgeWorkerState.RUNNING,
+ maintenance=False,
+ maintenance_comments=None,
+ drain=False,
+ )
+ assert status == WorkerStatus.from_json(status.json)
diff --git a/providers/edge/tests/unit/edge/cli/test_edge_command.py b/providers/edge/tests/unit/edge/cli/test_edge_command.py
index f12756c5b7a1a..55c63a4f474aa 100644
--- a/providers/edge/tests/unit/edge/cli/test_edge_command.py
+++ b/providers/edge/tests/unit/edge/cli/test_edge_command.py
@@ -27,7 +27,8 @@
import time_machine
from requests import HTTPError, Response
-from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile
+from airflow.providers.edge.cli.dataclasses import Job
+from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _write_pid_to_pidfile
from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException
from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched, WorkerSetStateReturn
from airflow.utils import timezone
@@ -114,12 +115,12 @@ def returncode(self):
class TestEdgeWorkerCli:
@pytest.fixture
- def mock_joblist(self, tmp_path: Path) -> list[_Job]:
+ def mock_joblist(self, tmp_path: Path) -> list[Job]:
logfile = tmp_path / "file.log"
logfile.touch()
return [
- _Job(
+ Job(
edge_job=EdgeJobFetched(
dag_id="test",
task_id="test1",
@@ -136,7 +137,7 @@ def mock_joblist(self, tmp_path: Path) -> list[_Job]:
]
@pytest.fixture
- def worker_with_job(self, tmp_path: Path, mock_joblist: list[_Job]) -> _EdgeWorkerCli:
+ def worker_with_job(self, tmp_path: Path, mock_joblist: list[Job]) -> _EdgeWorkerCli:
test_worker = _EdgeWorkerCli(str(tmp_path / "mock.pid"), "mock", None, 8, 5, 5)
_EdgeWorkerCli.jobs = mock_joblist
return test_worker
diff --git a/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py b/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py
index 9535c9e5192c6..3e19ff3c04024 100644
--- a/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py
+++ b/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py
@@ -145,12 +145,12 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo
),
],
)
- def test_redefine_state_if_maintenance(
+ def test_redefine_state(
self, worker_state: EdgeWorkerState, body_state: EdgeWorkerState, expected_state: EdgeWorkerState
):
- from airflow.providers.edge.worker_api.routes.worker import redefine_state_if_maintenance
+ from airflow.providers.edge.worker_api.routes.worker import redefine_state
- assert redefine_state_if_maintenance(worker_state, body_state) == expected_state
+ assert redefine_state(worker_state, body_state) == expected_state
def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli):
queues = ["default", "default2"]