diff --git a/providers/edge/README.rst b/providers/edge/README.rst index d5bbae1780659..48031b6d2ffd6 100644 --- a/providers/edge/README.rst +++ b/providers/edge/README.rst @@ -23,7 +23,7 @@ Package ``apache-airflow-providers-edge`` -Release: ``0.18.1pre0`` +Release: ``0.19.0pre0`` Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites @@ -36,7 +36,7 @@ This is a provider package for ``edge`` provider. All classes for this provider are in ``airflow.providers.edge`` python package. You can find package information and changelog for the provider -in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0/>`_. +in the `documentation <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0/>`_. Installation ------------ @@ -59,4 +59,4 @@ PIP package Version required ================== =================== The changelog for the provider package can be found in the -`changelog <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0/changelog.html>`_. +`changelog <https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0/changelog.html>`_. diff --git a/providers/edge/docs/changelog.rst b/providers/edge/docs/changelog.rst index 4bfcfc7009f67..5b8bef99440ce 100644 --- a/providers/edge/docs/changelog.rst +++ b/providers/edge/docs/changelog.rst @@ -27,6 +27,15 @@ Changelog --------- +0.19.0pre0 +.......... + +Misc +~~~~ + +* ``Edge worker can be set to maintenance via CLI and also return to normal operation.`` + + 0.18.1pre0 .......... 
diff --git a/providers/edge/provider.yaml b/providers/edge/provider.yaml index 18f7d1fca6cc0..562583a9a3687 100644 --- a/providers/edge/provider.yaml +++ b/providers/edge/provider.yaml @@ -25,7 +25,7 @@ source-date-epoch: 1737371680 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.18.1pre0 + - 0.19.0pre0 plugins: - name: edge_executor diff --git a/providers/edge/pyproject.toml b/providers/edge/pyproject.toml index 2619f89bec61e..ce404ea9543fa 100644 --- a/providers/edge/pyproject.toml +++ b/providers/edge/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "flit_core.buildapi" [project] name = "apache-airflow-providers-edge" -version = "0.18.1pre0" +version = "0.19.0pre0" description = "Provider package apache-airflow-providers-edge for Apache Airflow" readme = "README.rst" authors = [ @@ -61,8 +61,8 @@ dependencies = [ ] [project.urls] -"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0" -"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.18.1pre0/changelog.html" +"Documentation" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0" +"Changelog" = "https://airflow.apache.org/docs/apache-airflow-providers-edge/0.19.0pre0/changelog.html" "Bug Tracker" = "https://github.com/apache/airflow/issues" "Source Code" = "https://github.com/apache/airflow" "Slack Chat" = "https://s.apache.org/airflow-slack" diff --git a/providers/edge/src/airflow/providers/edge/__init__.py b/providers/edge/src/airflow/providers/edge/__init__.py index 617fe191824cd..5b8039cec8dc1 100644 --- a/providers/edge/src/airflow/providers/edge/__init__.py +++ b/providers/edge/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.18.1pre0" +__version__ = "0.19.0pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git 
a/providers/edge/src/airflow/providers/edge/cli/api_client.py b/providers/edge/src/airflow/providers/edge/cli/api_client.py index c19504787d75c..6c9266890251c 100644 --- a/providers/edge/src/airflow/providers/edge/cli/api_client.py +++ b/providers/edge/src/airflow/providers/edge/cli/api_client.py @@ -109,7 +109,12 @@ def worker_register( def worker_set_state( - hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict + hostname: str, + state: EdgeWorkerState, + jobs_active: int, + queues: list[str] | None, + sysinfo: dict, + maintenance_comments: str | None = None, ) -> WorkerSetStateReturn: """Update the state of the worker in the central site and thereby implicitly heartbeat.""" try: @@ -117,7 +122,11 @@ def worker_set_state( "PATCH", f"worker/{quote(hostname)}", WorkerStateBody( - state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo + state=state, + jobs_active=jobs_active, + queues=queues, + sysinfo=sysinfo, + maintenance_comments=maintenance_comments, ).model_dump_json(exclude_unset=True), ) except requests.HTTPError as e: diff --git a/providers/edge/src/airflow/providers/edge/cli/dataclasses.py b/providers/edge/src/airflow/providers/edge/cli/dataclasses.py new file mode 100644 index 0000000000000..3572c550de84f --- /dev/null +++ b/providers/edge/src/airflow/providers/edge/cli/dataclasses.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from multiprocessing import Process +from pathlib import Path +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from psutil import Popen + + from airflow.providers.edge.models.edge_worker import EdgeWorkerState + from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched + + +@dataclass +class MaintenanceMarker: + """Maintenance mode status.""" + + maintenance: str + comments: str | None + + @property + def json(self) -> str: + """Get the maintenance status as JSON.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> MaintenanceMarker: + """Create a Maintenance object from JSON.""" + return MaintenanceMarker(**json.loads(json_str)) + + +@dataclass +class WorkerStatus: + """Status of the worker.""" + + job_count: int + jobs: list + state: EdgeWorkerState + maintenance: bool + maintenance_comments: str | None + drain: bool + + @property + def json(self) -> str: + """Get the status as JSON.""" + return json.dumps(asdict(self)) + + @staticmethod + def from_json(json_str: str) -> WorkerStatus: + """Create a WorkerStatus object from JSON.""" + return WorkerStatus(**json.loads(json_str)) + + +@dataclass +class Job: + """Holds all information for a task/job to be executed as bundle.""" + + edge_job: EdgeJobFetched + process: Popen | Process + logfile: Path + logsize: int + """Last size of log file, point of last chunk push.""" + + @property + def is_running(self) -> bool: + """Check if the job is still running.""" + 
if hasattr(self.process, "returncode") and hasattr(self.process, "poll"): + self.process.poll() + return self.process.returncode is None + return self.process.exitcode is None + + @property + def is_success(self) -> bool: + """Check if the job was successful.""" + if hasattr(self.process, "returncode"): + return self.process.returncode == 0 + return self.process.exitcode == 0 diff --git a/providers/edge/src/airflow/providers/edge/cli/edge_command.py b/providers/edge/src/airflow/providers/edge/cli/edge_command.py index 9ef6ed755b3f7..f10f7daf8d678 100644 --- a/providers/edge/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/edge/src/airflow/providers/edge/cli/edge_command.py @@ -22,7 +22,7 @@ import platform import signal import sys -from dataclasses import dataclass +from dataclasses import asdict from datetime import datetime from http import HTTPStatus from multiprocessing import Process @@ -47,6 +47,7 @@ worker_register, worker_set_state, ) +from airflow.providers.edge.cli.dataclasses import Job, MaintenanceMarker, WorkerStatus from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException from airflow.providers.edge.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import cli as cli_utils, timezone @@ -115,10 +116,22 @@ def _pid_file_path(pid_file: str | None) -> str: return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0] +def _get_pid(pid_file: str | None) -> int: + pid = read_pid_from_pidfile(_pid_file_path(pid_file)) + if not pid: + logger.warning("Could not find PID of worker.") + sys.exit(1) + return pid + + def _status_file_path(pid_file: str | None) -> str: return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[1] +def _maintenance_marker_file_path(pid_file: str | None) -> str: + return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[1][:-4] + ".in" + + def _write_pid_to_pidfile(pid_file_path: str): """Write PIDs for 
Edge Workers to disk, handling existing PID files.""" if Path(pid_file_path).exists(): @@ -143,36 +156,10 @@ def _write_pid_to_pidfile(pid_file_path: str): write_pid_to_pidfile(pid_file_path) -@dataclass -class _Job: - """Holds all information for a task/job to be executed as bundle.""" - - edge_job: EdgeJobFetched - process: Popen | Process - logfile: Path - logsize: int - """Last size of log file, point of last chunk push.""" - - @property - def is_running(self) -> bool: - """Check if the job is still running.""" - if isinstance(self.process, Popen): - self.process.poll() - return self.process.returncode is None - return self.process.exitcode is None - - @property - def is_success(self) -> bool: - """Check if the job was successful.""" - if isinstance(self.process, Popen): - return self.process.returncode == 0 - return self.process.exitcode == 0 - - class _EdgeWorkerCli: """Runner instance which executes the Edge Worker.""" - jobs: list[_Job] = [] + jobs: list[Job] = [] """List of jobs that the worker is running currently.""" last_hb: datetime | None = None """Timestamp of last heart beat sent to server.""" @@ -180,6 +167,11 @@ class _EdgeWorkerCli: """Flag if job processing should be completed and no new jobs fetched for a graceful stop/shutdown.""" maintenance_mode: bool = False """Flag if job processing should be completed and no new jobs fetched for maintenance mode. 
""" + maintenance_comments: str | None = None + """Comments for maintenance mode.""" + + edge_instance: _EdgeWorkerCli | None = None + """Singleton instance of the worker.""" def __init__( self, @@ -198,21 +190,35 @@ def __init__( self.concurrency = concurrency self.free_concurrency = concurrency + _EdgeWorkerCli.edge_instance = self + @staticmethod def signal_handler(sig: signal.Signals, frame): if sig == SIG_STATUS: - logger.info("Request to get status of Edge Worker received.") + marker_path = Path(_maintenance_marker_file_path(None)) + if marker_path.exists(): + request = MaintenanceMarker.from_json(marker_path.read_text()) + logger.info("Requested to set maintenance mode to %s", request.maintenance) + _EdgeWorkerCli.maintenance_mode = request.maintenance == "on" + if _EdgeWorkerCli.maintenance_mode and request.comments: + logger.info("Comments: %s", request.comments) + _EdgeWorkerCli.maintenance_comments = request.comments + marker_path.unlink() + # send heartbeat immediately to update state + if _EdgeWorkerCli.edge_instance: + _EdgeWorkerCli.edge_instance.heartbeat(_EdgeWorkerCli.maintenance_comments) + else: + logger.info("Request to get status of Edge Worker received.") status_path = Path(_status_file_path(None)) status_path.write_text( - json.dumps( - { - "job_count": len(_EdgeWorkerCli.jobs), - "jobs": [job.edge_job.key for job in _EdgeWorkerCli.jobs], - "state": _EdgeWorkerCli._get_state(), - "maintenance": _EdgeWorkerCli.maintenance_mode, - "drain": _EdgeWorkerCli.drain, - } - ) + WorkerStatus( + job_count=len(_EdgeWorkerCli.jobs), + jobs=[job.edge_job.key for job in _EdgeWorkerCli.jobs], + state=_EdgeWorkerCli._get_state(), + maintenance=_EdgeWorkerCli.maintenance_mode, + maintenance_comments=_EdgeWorkerCli.maintenance_comments, + drain=_EdgeWorkerCli.drain, + ).json ) else: logger.info("Request to shut down Edge Worker received, waiting for jobs to complete.") @@ -316,7 +322,7 @@ def _launch_job(self, edge_job: EdgeJobFetched): else: # Airflow 2.10 
process, logfile = self._launch_job_af2_10(edge_job) - _EdgeWorkerCli.jobs.append(_Job(edge_job, process, logfile, 0)) + _EdgeWorkerCli.jobs.append(Job(edge_job, process, logfile, 0)) def start(self): """Start the execution in a loop until terminated.""" @@ -432,14 +438,19 @@ def check_running_jobs(self) -> None: self.free_concurrency = self.concurrency - used_concurrency - def heartbeat(self) -> bool: + def heartbeat(self, new_maintenance_comments: str | None = None) -> bool: """Report liveness state of worker to central site with stats.""" state = _EdgeWorkerCli._get_state() sysinfo = self._get_sysinfo() worker_state_changed: bool = False try: worker_info = worker_set_state( - self.hostname, state, len(_EdgeWorkerCli.jobs), self.queues, sysinfo + self.hostname, + state, + len(_EdgeWorkerCli.jobs), + self.queues, + sysinfo, + new_maintenance_comments, ) self.queues = worker_info.queues if worker_info.state == EdgeWorkerState.MAINTENANCE_REQUEST: @@ -451,6 +462,11 @@ def heartbeat(self) -> bool: ): logger.info("Exit Maintenance mode requested!") _EdgeWorkerCli.maintenance_mode = False + if _EdgeWorkerCli.maintenance_mode: + _EdgeWorkerCli.maintenance_comments = worker_info.maintenance_comments + else: + _EdgeWorkerCli.maintenance_comments = None + worker_state_changed = worker_info.state != state except EdgeWorkerVersionException: logger.info("Version mismatch of Edge worker and Core. 
Shutting down worker.") @@ -488,45 +504,108 @@ def worker(args): @providers_configuration_loaded def status(args): """Check for Airflow Edge Worker status.""" - pid = read_pid_from_pidfile(_pid_file_path(args.pid)) - # Send SIGINT - if pid: - logger.debug("Sending SIGUSR2 to worker pid %i.", pid) - status_min_date = time() - 1 - status_path = Path(_status_file_path(args.pid)) - worker_process = psutil.Process(pid) - worker_process.send_signal(SIG_STATUS) - while psutil.pid_exists(pid) and ( - not status_path.exists() or status_path.stat().st_mtime < status_min_date - ): - sleep(0.1) - if not psutil.pid_exists(pid): - logger.warning("PID of worker dis-appeared while checking for status.") - sys.exit(2) - if not status_path.exists() or status_path.stat().st_mtime < status_min_date: - logger.warning("Could not read status of worker.") - sys.exit(3) - status = json.loads(status_path.read_text()) - print(json.dumps(status, indent=4)) + pid = _get_pid(args.pid) + + # Send Signal as notification to drop status JSON + logger.debug("Sending SIGUSR2 to worker pid %i.", pid) + status_min_date = time() - 1 + status_path = Path(_status_file_path(args.pid)) + worker_process = psutil.Process(pid) + worker_process.send_signal(SIG_STATUS) + while psutil.pid_exists(pid) and ( + not status_path.exists() or status_path.stat().st_mtime < status_min_date + ): + sleep(0.1) + if not psutil.pid_exists(pid): + logger.warning("PID of worker dis-appeared while checking for status.") + sys.exit(2) + if not status_path.exists() or status_path.stat().st_mtime < status_min_date: + logger.warning("Could not read status of worker.") + sys.exit(3) + status = WorkerStatus.from_json(status_path.read_text()) + print(json.dumps(asdict(status), indent=4)) - else: - logger.warning("Could not find PID of worker.") - sys.exit(1) + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def maintenance(args): + """Set or Unset maintenance mode of worker.""" + if args.maintenance == "on" and 
not args.comments: + logger.error("Comments are required when setting maintenance mode.") + sys.exit(4) + + pid = _get_pid(args.pid) + + # Write marker JSON file + from getpass import getuser + + marker_path = Path(_maintenance_marker_file_path(args.pid)) + logger.debug("Writing maintenance marker file to %s.", marker_path) + marker_path.write_text( + MaintenanceMarker( + maintenance=args.maintenance, + comments=f'[{datetime.now().strftime("%Y-%m-%d %H:%M")}] - {getuser()} put ' + f'node into maintenance mode via cli\nComment: {args.comments}' + if args.maintenance == "on" + else None, + ).json + ) + + # Send Signal as notification to fetch maintenance marker + logger.debug("Sending SIGUSR2 to worker pid %i.", pid) + status_min_date = time() - 1 + status_path = Path(_status_file_path(args.pid)) + worker_process = psutil.Process(pid) + worker_process.send_signal(SIG_STATUS) + while psutil.pid_exists(pid) and ( + not status_path.exists() or status_path.stat().st_mtime < status_min_date + ): + sleep(0.1) + if not psutil.pid_exists(pid): + logger.warning("PID of worker dis-appeared while checking for status.") + sys.exit(2) + if not status_path.exists() or status_path.stat().st_mtime < status_min_date: + logger.warning("Could not read status of worker.") + sys.exit(3) + status = WorkerStatus.from_json(status_path.read_text()) + + if args.wait: + if args.maintenance == "on" and status.state != EdgeWorkerState.MAINTENANCE_MODE: + logger.info("Waiting for worker to be drained...") + while True: + sleep(4.5) + worker_process.send_signal(SIG_STATUS) + sleep(0.5) + status = WorkerStatus.from_json(status_path.read_text()) + if status.state == EdgeWorkerState.MAINTENANCE_MODE: + logger.info("Worker was drained successfully!") + break + if status.state not in [ + EdgeWorkerState.MAINTENANCE_REQUEST, + EdgeWorkerState.MAINTENANCE_PENDING, + ]: + logger.info("Worker maintenance was exited by someone else!") + break + if args.maintenance == "off" and status.state == 
EdgeWorkerState.MAINTENANCE_MODE: + logger.info("Waiting for worker to exit maintenance...") + while status.state in [EdgeWorkerState.MAINTENANCE_MODE, EdgeWorkerState.MAINTENANCE_EXIT]: + sleep(4.5) + worker_process.send_signal(SIG_STATUS) + sleep(0.5) + status = WorkerStatus.from_json(status_path.read_text()) + + print(json.dumps(asdict(status), indent=4)) @cli_utils.action_cli(check_db=False) @providers_configuration_loaded def stop(args): """Stop a running Airflow Edge Worker.""" - pid = read_pid_from_pidfile(_pid_file_path(args.pid)) + pid = _get_pid(args.pid) # Send SIGINT - if pid: - logger.info("Sending SIGINT to worker pid %i.", pid) - worker_process = psutil.Process(pid) - worker_process.send_signal(signal.SIGINT) - else: - logger.warning("Could not find PID of worker.") - sys.exit(1) + logger.info("Sending SIGINT to worker pid %i.", pid) + worker_process = psutil.Process(pid) + worker_process.send_signal(signal.SIGINT) if args.wait: logger.info("Waiting for worker to stop...") @@ -549,7 +628,18 @@ def stop(args): ("-H", "--edge-hostname"), help="Set the hostname of worker if you have multiple workers on a single machine", ) -ARG_WAIT = Arg( +ARG_MAINTENANCE = Arg(("maintenance",), help="Desired maintenance state", choices=("on", "off")) +ARG_MAINTENANCE_COMMENT = Arg( + ("-c", "--comments"), + help="Maintenance comments to report reason. 
Required if maintenance is turned on.", +) +ARG_WAIT_MAINT = Arg( + ("-w", "--wait"), + default=False, + help="Wait until edge worker has reached desired state.", + action="store_true", +) +ARG_WAIT_STOP = Arg( ("-w", "--wait"), default=False, help="Wait until edge worker is shut down.", @@ -577,12 +667,24 @@ def stop(args): ARG_VERBOSE, ), ), + ActionCommand( + name=maintenance.__name__, + help=maintenance.__doc__, + func=maintenance, + args=( + ARG_MAINTENANCE, + ARG_MAINTENANCE_COMMENT, + ARG_WAIT_MAINT, + ARG_PID, + ARG_VERBOSE, + ), + ), ActionCommand( name=stop.__name__, help=stop.__doc__, func=stop, args=( - ARG_WAIT, + ARG_WAIT_STOP, ARG_PID, ARG_VERBOSE, ), diff --git a/providers/edge/src/airflow/providers/edge/get_provider_info.py b/providers/edge/src/airflow/providers/edge/get_provider_info.py index 00f7247d6b55b..c646666d30186 100644 --- a/providers/edge/src/airflow/providers/edge/get_provider_info.py +++ b/providers/edge/src/airflow/providers/edge/get_provider_info.py @@ -28,7 +28,7 @@ def get_provider_info(): "description": "Handle edge workers on remote sites via HTTP(s) connection and orchestrates work over distributed sites\n", "state": "not-ready", "source-date-epoch": 1737371680, - "versions": ["0.18.1pre0"], + "versions": ["0.19.0pre0"], "plugins": [ { "name": "edge_executor", diff --git a/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml index 7141f3b9b0a32..1db7fc557deec 100644 --- a/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml +++ b/providers/edge/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml @@ -617,6 +617,13 @@ components: description: System information of the worker. title: Sysinfo type: object + maintenance_comments: + description: Comments about the maintenance state of the worker. 
+ title: Maintenance Comments + anyOf: + - type: string + - type: object + nullable: true title: WorkerStateBody WorkerQueuesBody: description: Queues that a worker supports to run jobs on. @@ -656,6 +663,13 @@ components: state: $ref: '#/components/schemas/EdgeWorkerState' description: State of the worker from the view of the worker. + maintenance_comments: + description: Comments about the maintenance state of the worker. + title: Maintenance Comments + anyOf: + - type: string + - type: object + nullable: true title: WorkerSetStateReturn EdgeJobFetched: description: Job that is to be executed on the edge worker. diff --git a/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py b/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py index 888367a680d82..6edef1f3e48b1 100644 --- a/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py +++ b/providers/edge/src/airflow/providers/edge/worker_api/datamodels.py @@ -142,6 +142,10 @@ class WorkerStateBody(WorkerQueuesBase): ], ), ] + maintenance_comments: Annotated[ + str | None, + Field(description="Comments about the maintenance state of the worker."), + ] = None class WorkerQueueUpdateBody(BaseModel): @@ -174,3 +178,7 @@ class WorkerSetStateReturn(BaseModel): description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues." 
), ] + maintenance_comments: Annotated[ + str | None, + Field(description="Comments about the maintenance state of the worker."), + ] = None diff --git a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py index ae50438a2345b..32423ddfea765 100644 --- a/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py +++ b/providers/edge/src/airflow/providers/edge/worker_api/routes/_v2_routes.py @@ -152,6 +152,7 @@ def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> jobs_active=body["jobs_active"], queues=body["queues"], sysinfo=body["sysinfo"], + maintenance_comments=body.get("maintenance_comments"), ) return set_state(worker_name, request_obj, session).model_dump() except HTTPException as e: diff --git a/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py b/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py index 602af8b0b2977..6c68f9b29aea1 100644 --- a/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py +++ b/providers/edge/src/airflow/providers/edge/worker_api/routes/worker.py @@ -113,9 +113,7 @@ def _assert_version(sysinfo: dict[str, str | int]) -> None: ) -def redefine_state_if_maintenance( - worker_state: EdgeWorkerState, body_state: EdgeWorkerState -) -> EdgeWorkerState: +def redefine_state(worker_state: EdgeWorkerState, body_state: EdgeWorkerState) -> EdgeWorkerState: """Redefine the state of the worker based on maintenance request.""" if ( worker_state == EdgeWorkerState.MAINTENANCE_REQUEST @@ -138,6 +136,20 @@ def redefine_state_if_maintenance( return body_state +def redefine_maintenance_comments( + worker_maintenance_comment: str | None, body_maintenance_comments: str | None +) -> str | None: + """Add new maintenance comments or overwrite the old ones if it is too long.""" + if body_maintenance_comments: + if ( + worker_maintenance_comment + and 
len(body_maintenance_comments) + len(worker_maintenance_comment) < 1020 + ): + return f"{worker_maintenance_comment}\n\n{body_maintenance_comments}" + return body_maintenance_comments + return worker_maintenance_comment + + @worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)]) def register( worker_name: Annotated[str, _worker_name_doc], @@ -150,7 +162,10 @@ def register( worker: EdgeWorkerModel = session.scalar(query) if not worker: worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues) - worker.state = redefine_state_if_maintenance(worker.state, body.state) + worker.state = redefine_state(worker.state, body.state) + worker.maintenance_comment = redefine_maintenance_comments( + worker.maintenance_comment, body.maintenance_comments + ) worker.queues = body.queues worker.sysinfo = json.dumps(body.sysinfo) worker.last_update = timezone.utcnow() @@ -167,7 +182,10 @@ def set_state( """Set state of worker and returns the current assigned queues.""" query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) worker: EdgeWorkerModel = session.scalar(query) - worker.state = redefine_state_if_maintenance(worker.state, body.state) + worker.state = redefine_state(worker.state, body.state) + worker.maintenance_comment = redefine_maintenance_comments( + worker.maintenance_comment, body.maintenance_comments + ) worker.jobs_active = body.jobs_active worker.sysinfo = json.dumps(body.sysinfo) worker.last_update = timezone.utcnow() @@ -183,7 +201,9 @@ def set_state( queues=worker.queues, ) _assert_version(body.sysinfo) # Exception only after worker state is in the DB - return WorkerSetStateReturn(state=worker.state, queues=worker.queues) + return WorkerSetStateReturn( + state=worker.state, queues=worker.queues, maintenance_comments=worker.maintenance_comment + ) @worker_router.patch( diff --git a/providers/edge/tests/unit/edge/cli/test_dataclasses.py 
b/providers/edge/tests/unit/edge/cli/test_dataclasses.py new file mode 100644 index 0000000000000..3bc68d26b5ed8 --- /dev/null +++ b/providers/edge/tests/unit/edge/cli/test_dataclasses.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from airflow.providers.edge.cli.dataclasses import MaintenanceMarker, WorkerStatus +from airflow.providers.edge.models.edge_worker import EdgeWorkerState + +MOCK_ENDPOINT = "https://invalid-api-test-endpoint" + + +class TestMaintenanceMarker: + def test_maintenance_marker_json(self): + marker = MaintenanceMarker(maintenance="maintenance", comments="comments") + assert marker == MaintenanceMarker.from_json(marker.json) + + +class TestWorkerStatus: + def test_worker_status_json(self): + status = WorkerStatus( + job_count=1, + jobs=[], + state=EdgeWorkerState.RUNNING, + maintenance=False, + maintenance_comments=None, + drain=False, + ) + assert status == WorkerStatus.from_json(status.json) diff --git a/providers/edge/tests/unit/edge/cli/test_edge_command.py b/providers/edge/tests/unit/edge/cli/test_edge_command.py index f12756c5b7a1a..55c63a4f474aa 100644 --- a/providers/edge/tests/unit/edge/cli/test_edge_command.py +++ 
b/providers/edge/tests/unit/edge/cli/test_edge_command.py @@ -27,7 +27,8 @@ import time_machine from requests import HTTPError, Response -from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile +from airflow.providers.edge.cli.dataclasses import Job +from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _write_pid_to_pidfile from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException from airflow.providers.edge.worker_api.datamodels import EdgeJobFetched, WorkerSetStateReturn from airflow.utils import timezone @@ -114,12 +115,12 @@ def returncode(self): class TestEdgeWorkerCli: @pytest.fixture - def mock_joblist(self, tmp_path: Path) -> list[_Job]: + def mock_joblist(self, tmp_path: Path) -> list[Job]: logfile = tmp_path / "file.log" logfile.touch() return [ - _Job( + Job( edge_job=EdgeJobFetched( dag_id="test", task_id="test1", @@ -136,7 +137,7 @@ def mock_joblist(self, tmp_path: Path) -> list[_Job]: ] @pytest.fixture - def worker_with_job(self, tmp_path: Path, mock_joblist: list[_Job]) -> _EdgeWorkerCli: + def worker_with_job(self, tmp_path: Path, mock_joblist: list[Job]) -> _EdgeWorkerCli: test_worker = _EdgeWorkerCli(str(tmp_path / "mock.pid"), "mock", None, 8, 5, 5) _EdgeWorkerCli.jobs = mock_joblist return test_worker diff --git a/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py b/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py index 9535c9e5192c6..3e19ff3c04024 100644 --- a/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py +++ b/providers/edge/tests/unit/edge/worker_api/routes/test_worker.py @@ -145,12 +145,12 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo ), ], ) - def test_redefine_state_if_maintenance( + def test_redefine_state( self, worker_state: EdgeWorkerState, body_state: EdgeWorkerState, expected_state: EdgeWorkerState ): - from airflow.providers.edge.worker_api.routes.worker 
import redefine_state_if_maintenance + from airflow.providers.edge.worker_api.routes.worker import redefine_state - assert redefine_state_if_maintenance(worker_state, body_state) == expected_state + assert redefine_state(worker_state, body_state) == expected_state def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli): queues = ["default", "default2"]