diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index fa8ffb4a2c3c6..debbd43c896e4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1619,6 +1619,7 @@ symlinking symlinks sync'ed sys +sysinfo syspath Systemd systemd diff --git a/providers/src/airflow/providers/edge/CHANGELOG.rst b/providers/src/airflow/providers/edge/CHANGELOG.rst index 301ca1d4d8874..8309f111f6a3f 100644 --- a/providers/src/airflow/providers/edge/CHANGELOG.rst +++ b/providers/src/airflow/providers/edge/CHANGELOG.rst @@ -26,6 +26,15 @@ Changelog --------- + +0.8.0pre0 +......... + +Misc +~~~~ + +* ``Migrate worker registration and heartbeat to FastAPI.`` + 0.7.1pre0 ......... diff --git a/providers/src/airflow/providers/edge/__init__.py b/providers/src/airflow/providers/edge/__init__.py index 9b22a264d4413..fd23acee829cc 100644 --- a/providers/src/airflow/providers/edge/__init__.py +++ b/providers/src/airflow/providers/edge/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "0.7.1pre0" +__version__ = "0.8.0pre0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.10.0" diff --git a/providers/src/airflow/providers/edge/cli/api_client.py b/providers/src/airflow/providers/edge/cli/api_client.py new file mode 100644 index 0000000000000..9174191fd8c35 --- /dev/null +++ b/providers/src/airflow/providers/edge/cli/api_client.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +from datetime import datetime +from http import HTTPStatus +from pathlib import Path +from typing import TYPE_CHECKING, Any +from urllib.parse import quote, urljoin, urlparse + +import requests +import tenacity +from requests.exceptions import ConnectionError +from urllib3.exceptions import NewConnectionError + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.edge.worker_api.auth import jwt_signer +from airflow.providers.edge.worker_api.datamodels import WorkerStateBody + +if TYPE_CHECKING: + from airflow.providers.edge.models.edge_worker import EdgeWorkerState + +logger = logging.getLogger(__name__) + + +def _is_retryable_exception(exception: BaseException) -> bool: + """ + Evaluate which exception types to retry. + + This is especially demanded for cases where an application gateway or Kubernetes ingress can + not find a running instance of a webserver hosting the API (HTTP 502+504) or when the + HTTP request fails in general on network level. + + Note that we want to fail on other general errors on the webserver not to send bad requests in an endless loop. 
+ """ + retryable_status_codes = (HTTPStatus.BAD_GATEWAY, HTTPStatus.GATEWAY_TIMEOUT) + return ( + isinstance(exception, AirflowException) + and exception.status_code in retryable_status_codes + or isinstance(exception, (ConnectionError, NewConnectionError)) + ) + + +@tenacity.retry( + stop=tenacity.stop_after_attempt(10), # TODO: Make this configurable + wait=tenacity.wait_exponential(min=1), # TODO: Make this configurable + retry=tenacity.retry_if_exception(_is_retryable_exception), + before_sleep=tenacity.before_log(logger, logging.WARNING), +) +def _make_generic_request(method: str, rest_path: str, data: str) -> Any: + signer = jwt_signer() + api_url = conf.get("edge", "api_url") + path = urlparse(api_url).path.replace("/rpcapi", "") + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": signer.generate_signed_token({"method": str(Path(path, rest_path))}), + } + api_endpoint = urljoin(api_url, rest_path) + response = requests.request(method, url=api_endpoint, data=data, headers=headers) + if response.status_code == HTTPStatus.NO_CONTENT: + return None + if response.status_code != HTTPStatus.OK: + raise AirflowException( + f"Got {response.status_code}:{response.reason} when sending " + f"the internal api request: {response.text}", + HTTPStatus(response.status_code), + ) + return json.loads(response.content) + + +def worker_register( + hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo: dict +) -> datetime: + """Register worker with the Edge API.""" + result = _make_generic_request( + "POST", + f"worker/{quote(hostname)}", + WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo).model_dump_json( + exclude_unset=True + ), + ) + return datetime.fromisoformat(result) + + +def worker_set_state( + hostname: str, state: EdgeWorkerState, jobs_active: int, queues: list[str] | None, sysinfo: dict +) -> list[str] | None: + """Set state of the worker in the Edge API and return the assigned queues.""" + result = 
_make_generic_request( + "PATCH", + f"worker/{quote(hostname)}", + WorkerStateBody(state=state, jobs_active=jobs_active, queues=queues, sysinfo=sysinfo).model_dump_json( + exclude_unset=True + ), + ) + return result diff --git a/providers/src/airflow/providers/edge/cli/edge_command.py b/providers/src/airflow/providers/edge/cli/edge_command.py index 3712049b20776..9d172bffdd5b9 100644 --- a/providers/src/airflow/providers/edge/cli/edge_command.py +++ b/providers/src/airflow/providers/edge/cli/edge_command.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations import logging import os import platform import signal +import sys from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -29,15 +29,17 @@ import psutil from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile +from packaging.version import Version from airflow import __version__ as airflow_version, settings from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.edge import __version__ as edge_provider_version +from airflow.providers.edge.cli.api_client import worker_register, worker_set_state from airflow.providers.edge.models.edge_job import EdgeJob from airflow.providers.edge.models.edge_logs import EdgeLogs -from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState, EdgeWorkerVersionException +from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException from airflow.utils import cli as cli_utils from airflow.utils.platform import IS_WINDOWS from airflow.utils.providers_configuration_loader import providers_configuration_loaded @@ -57,6 +59,45 @@ ) +@providers_configuration_loaded +def 
force_use_internal_api_on_edge_worker(): + """ + Ensure that the environment is configured for the internal API without needing to declare it outside. + + This is only required for an Edge worker and must be done before the Click CLI wrapper is initiated. + That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the + function call can take effect. In an Edge worker, we need to "patch" the environment before starting. + """ + # export Edge API to be used for internal API + os.environ["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1" + os.environ["AIRFLOW_ENABLE_AIP_44"] = "True" + if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]: + AIRFLOW_VERSION = Version(airflow_version) + AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0") + if AIRFLOW_V_3_0_PLUS: + # Obvious TODO Make EdgeWorker compatible with Airflow 3 (again) + raise SystemExit( + "Error: EdgeWorker is currently broken on Airflow 3/main due to removal of AIP-44, rework for AIP-72." 
+ ) + + api_url = conf.get("edge", "api_url") + if not api_url: + raise SystemExit("Error: API URL is not configured, please correct configuration.") + logger.info("Starting worker with API endpoint %s", api_url) + os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url + + from airflow.api_internal import internal_api_call + from airflow.serialization import serialized_objects + + # Note: Need to patch internal settings as statically initialized before we get here + serialized_objects._ENABLE_AIP_44 = True + internal_api_call._ENABLE_AIP_44 = True + internal_api_call.InternalApiConfig.set_use_internal_api("edge-worker") + + +force_use_internal_api_on_edge_worker() + + def _hostname() -> str: if IS_WINDOWS: return platform.uname().node @@ -153,9 +194,9 @@ def _get_sysinfo(self) -> dict: def start(self): """Start the execution in a loop until terminated.""" try: - self.last_hb = EdgeWorker.register_worker( + self.last_hb = worker_register( self.hostname, EdgeWorkerState.STARTING, self.queues, self._get_sysinfo() - ).last_update + ) except EdgeWorkerVersionException as e: logger.info("Version mismatch of Edge worker and Core. Shutting down worker.") raise SystemExit(str(e)) @@ -172,7 +213,7 @@ def start(self): logger.info("Quitting worker, signal being offline.") try: - EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, self._get_sysinfo()) + worker_set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, self.queues, self._get_sysinfo()) except EdgeWorkerVersionException: logger.info("Version mismatch of Edge worker and Core. Quitting worker anyway.") finally: @@ -261,7 +302,7 @@ def heartbeat(self) -> None: ) sysinfo = self._get_sysinfo() try: - self.queues = EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo) + self.queues = worker_set_state(self.hostname, state, len(self.jobs), self.queues, sysinfo) except EdgeWorkerVersionException: logger.info("Version mismatch of Edge worker and Core. 
Shutting down worker.") _EdgeWorkerCli.drain = True diff --git a/providers/src/airflow/providers/edge/executors/edge_executor.py b/providers/src/airflow/providers/edge/executors/edge_executor.py index 48ae5e872e056..4184a8ffe5bf6 100644 --- a/providers/src/airflow/providers/edge/executors/edge_executor.py +++ b/providers/src/airflow/providers/edge/executors/edge_executor.py @@ -33,7 +33,7 @@ from airflow.providers.edge.cli.edge_command import EDGE_COMMANDS from airflow.providers.edge.models.edge_job import EdgeJobModel from airflow.providers.edge.models.edge_logs import EdgeLogsModel -from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerModel, EdgeWorkerState +from airflow.providers.edge.models.edge_worker import EdgeWorkerModel, EdgeWorkerState, reset_metrics from airflow.stats import Stats from airflow.utils import timezone from airflow.utils.db import DBLocks, create_global_lock @@ -145,7 +145,7 @@ def _check_worker_liveness(self, session: Session) -> bool: for worker in lifeless_workers: changed = True worker.state = EdgeWorkerState.UNKNOWN - EdgeWorker.reset_metrics(worker.worker_name) + reset_metrics(worker.worker_name) return changed diff --git a/providers/src/airflow/providers/edge/models/edge_worker.py b/providers/src/airflow/providers/edge/models/edge_worker.py index a1287fdb96c8f..656d7539d0722 100644 --- a/providers/src/airflow/providers/edge/models/edge_worker.py +++ b/providers/src/airflow/providers/edge/models/edge_worker.py @@ -23,12 +23,7 @@ from typing import TYPE_CHECKING, Optional from pydantic import BaseModel, ConfigDict -from sqlalchemy import ( - Column, - Integer, - String, - select, -) +from sqlalchemy import Column, Integer, String, select from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import AirflowException @@ -129,8 +124,56 @@ def remove_queues(self, remove_queues: list[str]) -> None: self.queues = queues +def set_metrics( + worker_name: str, + state: EdgeWorkerState, 
+ jobs_active: int, + concurrency: int, + free_concurrency: int, + queues: list[str] | None, +) -> None: + """Set metric of edge worker.""" + queues = queues if queues else [] + connected = state not in (EdgeWorkerState.UNKNOWN, EdgeWorkerState.OFFLINE) + + Stats.gauge(f"edge_worker.state.{worker_name}", int(connected)) + Stats.gauge( + "edge_worker.state", + int(connected), + tags={"name": worker_name, "state": state}, + ) + + Stats.gauge(f"edge_worker.jobs_active.{worker_name}", jobs_active) + Stats.gauge("edge_worker.jobs_active", jobs_active, tags={"worker_name": worker_name}) + + Stats.gauge(f"edge_worker.concurrency.{worker_name}", concurrency) + Stats.gauge("edge_worker.concurrency", concurrency, tags={"worker_name": worker_name}) + + Stats.gauge(f"edge_worker.free_concurrency.{worker_name}", free_concurrency) + Stats.gauge("edge_worker.free_concurrency", free_concurrency, tags={"worker_name": worker_name}) + + Stats.gauge(f"edge_worker.num_queues.{worker_name}", len(queues)) + Stats.gauge( + "edge_worker.num_queues", + len(queues), + tags={"worker_name": worker_name, "queues": ",".join(queues)}, + ) + + +def reset_metrics(worker_name: str) -> None: + """Reset metrics of worker.""" + set_metrics( + worker_name=worker_name, + state=EdgeWorkerState.UNKNOWN, + jobs_active=0, + concurrency=0, + free_concurrency=-1, + queues=None, + ) + + class EdgeWorker(BaseModel, LoggingMixin): - """Accessor for Edge Worker instances as logical model.""" + """Deprecated Edge Worker internal API, keeping for one minor for graceful migration.""" worker_name: str state: EdgeWorkerState @@ -144,119 +187,6 @@ class EdgeWorker(BaseModel, LoggingMixin): sysinfo: str model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True) - @staticmethod - def set_metrics( - worker_name: str, - state: EdgeWorkerState, - jobs_active: int, - concurrency: int, - free_concurrency: int, - queues: list[str] | None, - ) -> None: - """Set metric of edge worker.""" - queues = queues if 
queues else [] - connected = state not in (EdgeWorkerState.UNKNOWN, EdgeWorkerState.OFFLINE) - Stats.gauge(f"edge_worker.state.{worker_name}", int(connected)) - Stats.gauge( - "edge_worker.state", - int(connected), - tags={"name": worker_name, "state": state}, - ) - - Stats.gauge(f"edge_worker.jobs_active.{worker_name}", jobs_active) - Stats.gauge("edge_worker.jobs_active", jobs_active, tags={"worker_name": worker_name}) - - Stats.gauge(f"edge_worker.concurrency.{worker_name}", concurrency) - Stats.gauge("edge_worker.concurrency", concurrency, tags={"worker_name": worker_name}) - - Stats.gauge(f"edge_worker.free_concurrency.{worker_name}", free_concurrency) - Stats.gauge("edge_worker.free_concurrency", free_concurrency, tags={"worker_name": worker_name}) - - Stats.gauge( - f"edge_worker.num_queues.{worker_name}", - len(queues), - ) - Stats.gauge( - "edge_worker.num_queues", - len(queues), - tags={"worker_name": worker_name, "queues": ",".join(queues)}, - ) - - @staticmethod - def reset_metrics(worker_name: str) -> None: - """Reset metrics of worker.""" - EdgeWorker.set_metrics( - worker_name=worker_name, - state=EdgeWorkerState.UNKNOWN, - jobs_active=0, - concurrency=0, - free_concurrency=-1, - queues=None, - ) - - @staticmethod - def assert_version(sysinfo: dict[str, str]) -> None: - """Check if the Edge Worker version matches the central API site.""" - from airflow import __version__ as airflow_version - from airflow.providers.edge import __version__ as edge_provider_version - - # Note: In future, more stable versions we might be more liberate, for the - # moment we require exact version match for Edge Worker and core version - if "airflow_version" in sysinfo: - airflow_on_worker = sysinfo["airflow_version"] - if airflow_on_worker != airflow_version: - raise EdgeWorkerVersionException( - f"Edge Worker runs on Airflow {airflow_on_worker} " - f"and the core runs on {airflow_version}. Rejecting access due to difference." 
- ) - else: - raise EdgeWorkerVersionException("Edge Worker does not specify the version it is running on.") - - if "edge_provider_version" in sysinfo: - provider_on_worker = sysinfo["edge_provider_version"] - if provider_on_worker != edge_provider_version: - raise EdgeWorkerVersionException( - f"Edge Worker runs on Edge Provider {provider_on_worker} " - f"and the core runs on {edge_provider_version}. Rejecting access due to difference." - ) - else: - raise EdgeWorkerVersionException( - "Edge Worker does not specify the provider version it is running on." - ) - - @staticmethod - @internal_api_call - @provide_session - def register_worker( - worker_name: str, - state: EdgeWorkerState, - queues: list[str] | None, - sysinfo: dict[str, str], - session: Session = NEW_SESSION, - ) -> EdgeWorker: - EdgeWorker.assert_version(sysinfo) - query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) - if not worker: - worker = EdgeWorkerModel(worker_name=worker_name, state=state, queues=queues) - worker.state = state - worker.queues = queues - worker.sysinfo = json.dumps(sysinfo) - worker.last_update = timezone.utcnow() - session.add(worker) - return EdgeWorker( - worker_name=worker_name, - state=state, - queues=queues, - first_online=worker.first_online, - last_update=worker.last_update, - jobs_active=worker.jobs_active or 0, - jobs_taken=worker.jobs_taken or 0, - jobs_success=worker.jobs_success or 0, - jobs_failed=worker.jobs_failed or 0, - sysinfo=worker.sysinfo or "{}", - ) - @staticmethod @internal_api_call @provide_session @@ -277,7 +207,7 @@ def set_state( session.commit() Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1) Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name}) - EdgeWorker.set_metrics( + set_metrics( worker_name=worker_name, state=state, jobs_active=jobs_active, @@ -285,25 +215,21 @@ def set_state( 
free_concurrency=int(sysinfo["free_concurrency"]), queues=worker.queues, ) - EdgeWorker.assert_version(sysinfo) # Exception only after worker state is in the DB - return worker.queues + raise EdgeWorkerVersionException( + "Edge Worker runs on an old version. Rejecting access due to difference." + ) @staticmethod - @provide_session - def add_and_remove_queues( + @internal_api_call + def register_worker( worker_name: str, - new_queues: list[str] | None = None, - remove_queues: list[str] | None = None, - session: Session = NEW_SESSION, - ) -> None: - query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) - worker: EdgeWorkerModel = session.scalar(query) - if new_queues: - worker.add_queues(new_queues) - if remove_queues: - worker.remove_queues(remove_queues) - session.add(worker) - session.commit() + state: EdgeWorkerState, + queues: list[str] | None, + sysinfo: dict[str, str], + ) -> EdgeWorker: + raise EdgeWorkerVersionException( + "Edge Worker runs on an old version. Rejecting access due to difference." + ) EdgeWorker.model_rebuild() diff --git a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml index f1ab5f4c05e1c..8be23c0d07cc3 100644 --- a/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml +++ b/providers/src/airflow/providers/edge/openapi/edge_worker_api_v1.yaml @@ -35,11 +35,154 @@ servers: - url: /edge_worker/v1 description: Airflow Edge Worker API paths: - "/rpcapi": + /worker/{worker_name}: + patch: + description: Set state of worker and returns the current assigned queues. 
+ x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes + operationId: set_state_v2 + parameters: + - description: Hostname or instance name of the worker + in: path + name: worker_name + required: true + schema: + description: Hostname or instance name of the worker + title: Worker Name + type: string + - description: JWT Authorization Token + in: header + name: authorization + required: true + schema: + description: JWT Authorization Token + title: Authorization + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/WorkerStateBody' + description: State of the worker with details + examples: + - jobs_active: 3 + queues: + - large_node + - wisconsin_site + state: running + sysinfo: + airflow_version: 2.10.0 + concurrency: 4 + edge_provider_version: 1.0.0 + title: Worker State + required: true + responses: + '200': + content: + application/json: + schema: + anyOf: + - items: + type: string + type: array + - type: object + nullable: true + title: Response Set State + description: Successful Response + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Set State + tags: + - Worker + post: + description: Register a new worker to the backend. 
+ x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes + operationId: register_v2 + parameters: + - description: Hostname or instance name of the worker + in: path + name: worker_name + required: true + schema: + description: Hostname or instance name of the worker + title: Worker Name + type: string + - description: JWT Authorization Token + in: header + name: authorization + required: true + schema: + description: JWT Authorization Token + title: Authorization + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/WorkerStateBody' + description: State of the worker with details + examples: + - jobs_active: 3 + queues: + - large_node + - wisconsin_site + state: running + sysinfo: + airflow_version: 2.10.0 + concurrency: 4 + edge_provider_version: 1.0.0 + title: Worker State + required: true + responses: + '200': + content: + application/json: + schema: + format: date-time + title: Response Register + type: string + description: Successful Response + '400': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Bad Request + '403': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPExceptionResponse' + description: Forbidden + '422': + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPValidationError' + description: Validation Error + summary: Register + tags: + - Worker + /rpcapi: post: deprecated: false - x-openapi-router-controller: airflow.providers.edge.worker_api.routes.rpc_api - operationId: edge_worker_api_v2 + x-openapi-router-controller: airflow.providers.edge.worker_api.routes._v2_routes + operationId: rpcapi_v2 tags: - JSONRPC parameters: [] @@ -68,7 +211,7 @@ paths: params: title: Parameters type: object - "/health": + /health: get: operationId: health deprecated: false @@ -99,4 +242,89 @@ components: description: JSON-RPC Version (2.0) discriminator: propertyName: method_name + 
EdgeWorkerState: + description: Status of an Edge Worker instance. + enum: + - starting + - running + - idle + - terminating + - offline + - unknown + title: EdgeWorkerState + type: string + WorkerStateBody: + description: Details of the worker state sent to the scheduler. + type: object + required: + - state + - queues + - sysinfo + properties: + jobs_active: + default: 0 + description: Number of active jobs the worker is running. + title: Jobs Active + type: integer + queues: + anyOf: + - items: + type: string + type: array + - type: object + nullable: true + description: List of queues the worker is pulling jobs from. If not provided, + worker pulls from all queues. + title: Queues + state: + $ref: '#/components/schemas/EdgeWorkerState' + description: State of the worker from the view of the worker. + sysinfo: + description: System information of the worker. + title: Sysinfo + type: object + title: WorkerStateBody + HTTPExceptionResponse: + description: HTTPException Model used for error response. 
+ properties: + detail: + anyOf: + - type: string + - type: object + title: Detail + required: + - detail + title: HTTPExceptionResponse + type: object + HTTPValidationError: + properties: + detail: + items: + $ref: '#/components/schemas/ValidationError' + title: Detail + type: array + title: HTTPValidationError + type: object + ValidationError: + properties: + loc: + items: + anyOf: + - type: string + - type: integer + title: Location + type: array + msg: + title: Message + type: string + type: + title: Error Type + type: string + required: + - loc + - msg + - type + title: ValidationError + type: object + tags: [] diff --git a/providers/src/airflow/providers/edge/provider.yaml b/providers/src/airflow/providers/edge/provider.yaml index 95827a44e9b16..25dd75a2624c5 100644 --- a/providers/src/airflow/providers/edge/provider.yaml +++ b/providers/src/airflow/providers/edge/provider.yaml @@ -27,7 +27,7 @@ source-date-epoch: 1729683247 # note that those versions are maintained by release manager - do not update them manually versions: - - 0.7.1pre0 + - 0.8.0pre0 dependencies: - apache-airflow>=2.10.0 diff --git a/providers/src/airflow/providers/edge/worker_api/app.py b/providers/src/airflow/providers/edge/worker_api/app.py index bfe9ef4c5bc53..69a43edb116bf 100644 --- a/providers/src/airflow/providers/edge/worker_api/app.py +++ b/providers/src/airflow/providers/edge/worker_api/app.py @@ -19,7 +19,7 @@ from fastapi import FastAPI from airflow.providers.edge.worker_api.routes.health import health_router -from airflow.providers.edge.worker_api.routes.rpc_api import rpc_api_router +from airflow.providers.edge.worker_api.routes.worker import worker_router def create_edge_worker_api_app() -> FastAPI: @@ -35,5 +35,5 @@ def create_edge_worker_api_app() -> FastAPI: ) edge_worker_api_app.include_router(health_router) - edge_worker_api_app.include_router(rpc_api_router) + edge_worker_api_app.include_router(worker_router) return edge_worker_api_app diff --git 
a/providers/src/airflow/providers/edge/worker_api/auth.py b/providers/src/airflow/providers/edge/worker_api/auth.py new file mode 100644 index 0000000000000..5829e94732b52 --- /dev/null +++ b/providers/src/airflow/providers/edge/worker_api/auth.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import logging +from functools import cache +from uuid import uuid4 + +from itsdangerous import BadSignature +from jwt import ( + ExpiredSignatureError, + ImmatureSignatureError, + InvalidAudienceError, + InvalidIssuedAtError, + InvalidSignatureError, +) + +from airflow.configuration import conf +from airflow.providers.edge.worker_api.datamodels import JsonRpcRequestBase # noqa: TCH001 +from airflow.providers.edge.worker_api.routes._v2_compat import ( + Header, + HTTPException, + Request, + status, +) +from airflow.utils.jwt_signer import JWTSigner + +log = logging.getLogger(__name__) + + +@cache +def jwt_signer() -> JWTSigner: + clock_grace = conf.getint("core", "internal_api_clock_grace", fallback=30) + return JWTSigner( + secret_key=conf.get("core", "internal_api_secret_key"), + expiration_time_in_seconds=clock_grace, + leeway_in_seconds=clock_grace, + audience="api", + ) + + +def _forbidden_response(message: str): + """Log the error and return the response anonymized.""" + error_id = uuid4() + log.exception("%s error_id=%s", message, error_id) + raise HTTPException( + status.HTTP_403_FORBIDDEN, + f"Forbidden. The server side traceback may be identified with error_id={error_id}", + ) + + +def jwt_token_authorization(method: str, authorization: str): + """Check if the JWT token is correct.""" + try: + payload = jwt_signer().verify_token(authorization) + signed_method = payload.get("method") + if not signed_method or signed_method != method: + _forbidden_response( + "Invalid method in token authorization. " + f"signed method='{signed_method}' " + f"called method='{method}'", + ) + except BadSignature: + _forbidden_response("Bad Signature. 
Please use only the tokens provided by the API.") + except InvalidAudienceError: + _forbidden_response("Invalid audience for the request") + except InvalidSignatureError: + _forbidden_response("The signature of the request was wrong") + except ImmatureSignatureError: + _forbidden_response("The signature of the request was sent from the future") + except ExpiredSignatureError: + _forbidden_response( + "The signature of the request has expired. Make sure that all components " + "in your system have synchronized clocks.", + ) + except InvalidIssuedAtError: + _forbidden_response( + "The request was issued in the future. Make sure that all components " + "in your system have synchronized clocks.", + ) + except Exception: + _forbidden_response("Unable to authenticate API via token.") + + +def jwt_token_authorization_rpc( + body: JsonRpcRequestBase, authorization: str = Header(description="JWT Authorization Token") +): + """Check if the JWT token is correct for JSON RPC requests.""" + jwt_token_authorization(body.method, authorization) + + +def jwt_token_authorization_rest( + request: Request, authorization: str = Header(description="JWT Authorization Token") +): + """Check if the JWT token is correct for REST API requests.""" + jwt_token_authorization(request.url.path, authorization) diff --git a/providers/src/airflow/providers/edge/worker_api/datamodels.py b/providers/src/airflow/providers/edge/worker_api/datamodels.py index 9ce181bc72688..170d8c449ffc3 100644 --- a/providers/src/airflow/providers/edge/worker_api/datamodels.py +++ b/providers/src/airflow/providers/edge/worker_api/datamodels.py @@ -16,17 +16,73 @@ # under the License. 
from __future__ import annotations -from typing import Any, Optional +from typing import ( # noqa: UP035 - prevent pytest failing in back-compat + Annotated, + Any, + Dict, + List, + Optional, + Union, +) -from pydantic import BaseModel +from pydantic import BaseModel, Field +from airflow.providers.edge.models.edge_worker import EdgeWorkerState # noqa: TCH001 -class JsonRpcRequest(BaseModel): + +class JsonRpcRequestBase(BaseModel): + """Base JSON RPC request model to define just the method.""" + + method: Annotated[ + str, + Field(description="Fully qualified python module method name that is called via JSON RPC."), + ] + + +class JsonRpcRequest(JsonRpcRequestBase): """JSON RPC request model.""" - method: str - """Fully qualified python module method name that is called via JSON RPC.""" - jsonrpc: str - """JSON RPC version.""" - params: Optional[dict[str, Any]] = None # noqa: UP007 - prevent pytest failing in back-compat - """Parameters passed to the method.""" + jsonrpc: Annotated[str, Field(description="JSON RPC Version", examples=["2.0"])] + params: Annotated[ + Optional[Dict[str, Any]], # noqa: UP006, UP007 - prevent pytest failing in back-compat + Field(description="Dictionary of parameters passed to the method."), + ] + + +class WorkerStateBody(BaseModel): + """Details of the worker state sent to the scheduler.""" + + state: Annotated[EdgeWorkerState, Field(description="State of the worker from the view of the worker.")] + jobs_active: Annotated[int, Field(description="Number of active jobs the worker is running.")] = 0 + queues: Annotated[ + Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in back-compat + Field( + description="List of queues the worker is pulling jobs from. If not provided, worker pulls from all queues." 
+ ), + ] = None + sysinfo: Annotated[ + Dict[str, Union[str, int]], # noqa: UP006, UP007 - prevent pytest failing in back-compat + Field( + description="System information of the worker.", + examples=[ + { + "concurrency": 4, + "airflow_version": "2.0.0", + "edge_provider_version": "1.0.0", + } + ], + ), + ] + + +class WorkerQueueUpdateBody(BaseModel): + """Changed queues for the worker.""" + + new_queues: Annotated[ + Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in back-compat + Field(description="Additional queues to be added to worker."), + ] + remove_queues: Annotated[ + Optional[List[str]], # noqa: UP006, UP007 - prevent pytest failing in back-compat + Field(description="Queues to remove from worker."), + ] diff --git a/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py b/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py index 9774b3e6696f9..553456d410879 100644 --- a/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py +++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_compat.py @@ -27,8 +27,9 @@ if AIRFLOW_V_3_0_PLUS: # Just re-import the types from FastAPI and Airflow Core - from fastapi import Depends, Header, HTTPException, status + from fastapi import Body, Depends, Header, HTTPException, Path, Request, status + from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc else: @@ -37,17 +38,33 @@ from connexion import ProblemException + class Body: # type: ignore[no-redef] + def __init__(self, *_, **__): + pass + class Depends: # type: ignore[no-redef] def __init__(self, *_, **__): pass class Header: # type: ignore[no-redef] + def __init__(self, *_, **__): + pass + + class Path: # type: ignore[no-redef] + def __init__(self, *_, **__): + pass + + class Request: # type: ignore[no-redef] + pass + + class SessionDep: # 
type: ignore[no-redef] pass def create_openapi_http_exception_doc(responses_status_code: list[int]) -> dict: return {} class status: # type: ignore[no-redef] + HTTP_204_NO_CONTENT = 204 HTTP_400_BAD_REQUEST = 400 HTTP_403_FORBIDDEN = 403 HTTP_500_INTERNAL_SERVER_ERROR = 500 @@ -100,3 +117,9 @@ def decorator(func: Callable) -> Callable: return func return decorator + + def patch(self, *_, **__): + def decorator(func: Callable) -> Callable: + return func + + return decorator diff --git a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py similarity index 52% rename from providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py rename to providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py index aa5b30f5ab7e1..6f2e81caa0026 100644 --- a/providers/src/airflow/providers/edge/worker_api/routes/rpc_api.py +++ b/providers/src/airflow/providers/edge/worker_api/routes/_v2_routes.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+"""Compatibility layer for Connexion API to Airflow v2.10 API routes.""" from __future__ import annotations @@ -23,45 +24,39 @@ from typing import TYPE_CHECKING, Any, Callable from uuid import uuid4 -from itsdangerous import BadSignature -from jwt import ( - ExpiredSignatureError, - ImmatureSignatureError, - InvalidAudienceError, - InvalidIssuedAtError, - InvalidSignatureError, -) - -from airflow.configuration import conf from airflow.exceptions import AirflowException -from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest -from airflow.providers.edge.worker_api.routes._v2_compat import ( - AirflowRouter, - Depends, - Header, - HTTPException, - create_openapi_http_exception_doc, - status, -) +from airflow.providers.edge.worker_api.auth import jwt_token_authorization, jwt_token_authorization_rpc +from airflow.providers.edge.worker_api.datamodels import JsonRpcRequest, WorkerStateBody +from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException, status +from airflow.providers.edge.worker_api.routes.worker import register, set_state from airflow.serialization.serialized_objects import BaseSerialization -from airflow.utils.jwt_signer import JWTSigner -from airflow.utils.session import create_session +from airflow.utils.session import NEW_SESSION, create_session, provide_session if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse + log = logging.getLogger(__name__) -rpc_api_router = AirflowRouter(tags=["JSONRPC"]) @cache def _initialize_method_map() -> dict[str, Callable]: + # Note: This is a copy of the (removed) AIP-44 implementation from + # airflow/api_internal/endpoints/rpc_api_endpoint.py + # for compatibility with Airflow 2.10-line. + # Methods are potentially not existing more on main branch for Airflow 3. 
+ from airflow.api.common.trigger_dag import trigger_dag from airflow.cli.commands.task_command import _get_ti_db_access from airflow.dag_processing.manager import DagFileProcessorManager from airflow.dag_processing.processor import DagFileProcessor + + # Airflow 2.10 compatibility + from airflow.datasets import expand_alias_to_datasets # type: ignore[attr-defined] + from airflow.datasets.manager import DatasetManager # type: ignore[attr-defined] from airflow.jobs.job import Job, most_recent_job from airflow.models import Trigger, Variable, XCom from airflow.models.dag import DAG, DagModel + from airflow.models.dagcode import DagCode from airflow.models.dagrun import DagRun from airflow.models.dagwarning import DagWarning from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -76,14 +71,13 @@ def _initialize_method_map() -> dict[str, Callable]: _handle_reschedule, _record_task_map_for_downstreams, _update_rtif, - _update_ti_heartbeat, _xcom_pull, ) from airflow.models.xcom_arg import _get_task_map_length from airflow.providers.edge.models.edge_job import EdgeJob from airflow.providers.edge.models.edge_logs import EdgeLogs from airflow.providers.edge.models.edge_worker import EdgeWorker - from airflow.sdk.definitions.asset import expand_alias_to_assets + from airflow.secrets.metastore import MetastoreBackend from airflow.sensors.base import _orig_start_date from airflow.utils.cli_action_loggers import _default_action_log_internal from airflow.utils.log.file_task_handler import FileTaskHandler @@ -95,19 +89,22 @@ def _initialize_method_map() -> dict[str, Callable]: _get_ti_db_access, _get_task_map_length, _update_rtif, - _update_ti_heartbeat, _orig_start_date, _handle_failure, _handle_reschedule, _add_log, _xcom_pull, _record_task_map_for_downstreams, + trigger_dag, + DagCode.remove_deleted_code, DagModel.deactivate_deleted_dags, DagModel.get_paused_dag_ids, DagModel.get_current, DagFileProcessor._execute_task_callbacks, 
DagFileProcessor.execute_callbacks, DagFileProcessor.execute_callbacks_without_dag, + # Airflow 2.10 compatibility + DagFileProcessor.manage_slas, # type: ignore[attr-defined] DagFileProcessor.save_dag_to_db, DagFileProcessor.update_import_errors, DagFileProcessor._validate_task_pools_and_update_dag_warnings, @@ -116,13 +113,18 @@ def _initialize_method_map() -> dict[str, Callable]: DagFileProcessorManager.clear_nonexistent_import_errors, DagFileProcessorManager.deactivate_stale_dags, DagWarning.purge_inactive_dag_warnings, - expand_alias_to_assets, + expand_alias_to_datasets, + DatasetManager.register_dataset_change, FileTaskHandler._render_filename_db_access, Job._add_to_db, + Job._fetch_from_db, Job._kill, Job._update_heartbeat, Job._update_in_db, most_recent_job, + # Airflow 2.10 compatibility + MetastoreBackend._fetch_connection, # type: ignore[attr-defined] + MetastoreBackend._fetch_variable, # type: ignore[attr-defined] XCom.get_value, XCom.get_one, # XCom.get_many, # Not supported because it returns query @@ -141,6 +143,7 @@ def _initialize_method_map() -> dict[str, Callable]: DagRun._get_log_template, RenderedTaskInstanceFields._update_runtime_evaluated_template_fields, SerializedDagModel.get_serialized_dag, + SerializedDagModel.remove_deleted_dags, SkipMixin._skip, SkipMixin._skip_all_except, TaskInstance._check_and_change_state_before_execution, @@ -149,7 +152,6 @@ def _initialize_method_map() -> dict[str, Callable]: TaskInstance._set_state, TaskInstance.save_to_db, TaskInstance._clear_xcom_data, - TaskInstance._register_asset_changes_int, Trigger.from_object, Trigger.bulk_fetch, Trigger.clean_unused, @@ -158,6 +160,7 @@ def _initialize_method_map() -> dict[str, Callable]: Trigger.ids_for_triggerer, Trigger.assign_unassigned, # Additional things from EdgeExecutor + # These are removed in follow-up PRs as being in transition to FastAPI EdgeJob.reserve_task, EdgeJob.set_state, EdgeLogs.push_logs, @@ -167,17 +170,6 @@ def _initialize_method_map() -> 
dict[str, Callable]: return {f"{func.__module__}.{func.__qualname__}": func for func in functions} -@cache -def _jwt_signer() -> JWTSigner: - clock_grace = conf.getint("core", "internal_api_clock_grace", fallback=30) - return JWTSigner( - secret_key=conf.get("core", "internal_api_secret_key"), - expiration_time_in_seconds=clock_grace, - leeway_in_seconds=clock_grace, - audience="api", - ) - - def error_response(message: str, status: int): """Log the error and return the response as JSON object.""" error_id = uuid4() @@ -187,124 +179,92 @@ def error_response(message: str, status: int): return HTTPException(status, client_message) -def json_request_headers(content_type: str = Header(), accept: str = Header()): - """Check if the request headers are correct.""" - if content_type != "application/json": - raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Content-Type: application/json") - if accept != "application/json": - raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Accept: application/json") - +def rpcapi_v2(body: dict[str, Any]) -> APIResponse: + """Handle Edge Worker API `/edge_worker/v1/rpcapi` endpoint for Airflow 2.10.""" + # Note: Except the method map this _was_ a 100% copy of internal API module + # airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api() + # As of rework for FastAPI in Airflow 3.0, this is updated and to be removed in the future. + from flask import Response, request -def jwt_token_authorization(body: JsonRpcRequest, authorization: str = Header()): - """Check if the JWT token is correct.""" try: - payload = _jwt_signer().verify_token(authorization) - signed_method = payload.get("method") - if not signed_method or signed_method != body.method: - raise BadSignature("Invalid method in token authorization.") - except BadSignature: - raise HTTPException( - status.HTTP_403_FORBIDDEN, "Bad Signature. Please use only the tokens provided by the API." 
- ) - except InvalidAudienceError: - raise HTTPException(status.HTTP_403_FORBIDDEN, "Invalid audience for the request") - except InvalidSignatureError: - raise HTTPException(status.HTTP_403_FORBIDDEN, "The signature of the request was wrong") - except ImmatureSignatureError: - raise HTTPException( - status.HTTP_403_FORBIDDEN, "The signature of the request was sent from the future" - ) - except ExpiredSignatureError: - raise HTTPException( - status.HTTP_403_FORBIDDEN, - "The signature of the request has expired. Make sure that all components " - "in your system have synchronized clocks.", - ) - except InvalidIssuedAtError: - raise HTTPException( - status.HTTP_403_FORBIDDEN, - "The request was issues in the future. Make sure that all components " - "in your system have synchronized clocks.", - ) - except Exception: - raise HTTPException(status.HTTP_403_FORBIDDEN, "Unable to authenticate API via token.") + if request.headers.get("Content-Type", "") != "application/json": + raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Content-Type: application/json") + if request.headers.get("Accept", "") != "application/json": + raise HTTPException(status.HTTP_403_FORBIDDEN, "Expected Accept: application/json") + auth = request.headers.get("Authorization", "") + request_obj = JsonRpcRequest(method=body["method"], jsonrpc=body["jsonrpc"], params=body["params"]) + jwt_token_authorization_rpc(request_obj, auth) + if request_obj.jsonrpc != "2.0": + raise error_response("Expected jsonrpc 2.0 request.", status.HTTP_400_BAD_REQUEST) + log.debug("Got request for %s", request_obj.method) + methods_map = _initialize_method_map() + if request_obj.method not in methods_map: + raise error_response(f"Unrecognized method: {request_obj.method}.", status.HTTP_400_BAD_REQUEST) -def json_rpc_version(body: JsonRpcRequest): - """Check if the JSON RPC Request version is correct.""" - if body.jsonrpc != "2.0": - raise error_response("Expected jsonrpc 2.0 request.", status.HTTP_400_BAD_REQUEST) + 
handler = methods_map[request_obj.method] + params = {} + try: + if request_obj.params: + params = BaseSerialization.deserialize(request_obj.params, use_pydantic_models=True) + except Exception: + raise error_response("Error deserializing parameters.", status.HTTP_400_BAD_REQUEST) + log.debug("Calling method %s\nparams: %s", request_obj.method, params) + try: + # Session must be created there as it may be needed by serializer for lazy-loaded fields. + with create_session() as session: + output = handler(**params, session=session) + output_json = BaseSerialization.serialize(output, use_pydantic_models=True) + log.debug( + "Sending response: %s", json.dumps(output_json) if output_json is not None else None + ) + # In case of AirflowException or other selective known types, transport the exception class back to caller + except (KeyError, AttributeError, AirflowException) as e: + output_json = BaseSerialization.serialize(e, use_pydantic_models=True) + log.debug( + "Sending exception response: %s", json.dumps(output_json) if output_json is not None else None + ) + except Exception: + raise error_response( + f"Error executing method '{request_obj.method}'.", status.HTTP_500_INTERNAL_SERVER_ERROR + ) + response = json.dumps(output_json) if output_json is not None else None + return Response(response=response, headers={"Content-Type": "application/json"}) + except HTTPException as e: + return e.to_response() # type: ignore[attr-defined] -@rpc_api_router.post( - "/rpcapi", - dependencies=[Depends(json_request_headers), Depends(jwt_token_authorization), Depends(json_rpc_version)], - responses=create_openapi_http_exception_doc( - [ - status.HTTP_400_BAD_REQUEST, - status.HTTP_403_FORBIDDEN, - status.HTTP_500_INTERNAL_SERVER_ERROR, - ] - ), -) -def rpcapi(body: JsonRpcRequest) -> Any | None: - """Handle Edge Worker API calls as JSON-RPC.""" - log.debug("Got request for %s", body.method) - methods_map = _initialize_method_map() - if body.method not in methods_map: - raise 
error_response(f"Unrecognized method: {body.method}.", status.HTTP_400_BAD_REQUEST) - handler = methods_map[body.method] - params = {} - try: - if body.params: - params = BaseSerialization.deserialize(body.params, use_pydantic_models=True) - except Exception: - raise error_response("Error deserializing parameters.", status.HTTP_400_BAD_REQUEST) +@provide_session +def register_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any: + """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10.""" + from flask import request - log.debug("Calling method %s\nparams: %s", body.method, params) try: - # Session must be created there as it may be needed by serializer for lazy-loaded fields. - with create_session() as session: - output = handler(**params, session=session) - output_json = BaseSerialization.serialize(output, use_pydantic_models=True) - log.debug("Sending response: %s", json.dumps(output_json) if output_json is not None else None) - return output_json - # In case of AirflowException or other selective known types, transport the exception class back to caller - except (KeyError, AttributeError, AirflowException) as e: - exception_json = BaseSerialization.serialize(e, use_pydantic_models=True) - log.debug( - "Sending exception response: %s", json.dumps(output_json) if output_json is not None else None - ) - return exception_json - except Exception: - raise error_response( - f"Error executing method '{body.method}'.", status.HTTP_500_INTERNAL_SERVER_ERROR + auth = request.headers.get("Authorization", "") + jwt_token_authorization(request.path, auth) + request_obj = WorkerStateBody( + state=body["state"], jobs_active=0, queues=body["queues"], sysinfo=body["sysinfo"] ) + return register(worker_name, request_obj, session) + except HTTPException as e: + return e.to_response() # type: ignore[attr-defined] -def edge_worker_api_v2(body: dict[str, Any]) -> APIResponse: - """Handle Edge Worker API `/edge_worker/v1/rpcapi` 
endpoint for Airflow 2.10.""" - # Note: Except the method map this _was_ a 100% copy of internal API module - # airflow.api_internal.endpoints.rpc_api_endpoint.internal_airflow_api() - # As of rework for FastAPI in Airflow 3.0, this is updated and to be removed in future. - from flask import Response, request +@provide_session +def set_state_v2(worker_name: str, body: dict[str, Any], session=NEW_SESSION) -> Any: + """Handle Edge Worker API `/edge_worker/v1/worker/{worker_name}` endpoint for Airflow 2.10.""" + from flask import request try: - json_request_headers( - content_type=request.headers.get("Content-Type", ""), accept=request.headers.get("Accept", "") - ) - auth = request.headers.get("Authorization", "") - json_rpc = body.get("jsonrpc", "") - method_name = body.get("method", "") - request_obj = JsonRpcRequest(method=method_name, jsonrpc=json_rpc, params=body.get("params")) - jwt_token_authorization(request_obj, auth) - - json_rpc_version(request_obj) - - output_json = rpcapi(request_obj) - response = json.dumps(output_json) if output_json is not None else None - return Response(response=response, headers={"Content-Type": "application/json"}) + jwt_token_authorization(request.path, auth) + request_obj = WorkerStateBody( + state=body["state"], + jobs_active=body["jobs_active"], + queues=body["queues"], + sysinfo=body["sysinfo"], + ) + return set_state(worker_name, request_obj, session) except HTTPException as e: return e.to_response() # type: ignore[attr-defined] diff --git a/providers/src/airflow/providers/edge/worker_api/routes/worker.py b/providers/src/airflow/providers/edge/worker_api/routes/worker.py new file mode 100644 index 0000000000000..369ace0d2df25 --- /dev/null +++ b/providers/src/airflow/providers/edge/worker_api/routes/worker.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from datetime import datetime
+from typing import Annotated
+
+from sqlalchemy import select
+
+from airflow.providers.edge.models.edge_worker import EdgeWorkerModel, set_metrics
+from airflow.providers.edge.worker_api.auth import jwt_token_authorization_rest
+from airflow.providers.edge.worker_api.datamodels import (
+    WorkerQueueUpdateBody,  # noqa: TC001
+    WorkerStateBody,  # noqa: TC001
+)
+from airflow.providers.edge.worker_api.routes._v2_compat import (
+    AirflowRouter,
+    Body,
+    Depends,
+    HTTPException,
+    Path,
+    SessionDep,
+    create_openapi_http_exception_doc,
+    status,
+)
+from airflow.stats import Stats
+from airflow.utils import timezone
+
+worker_router = AirflowRouter(
+    tags=["Worker"],
+    prefix="/worker",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_403_FORBIDDEN,
+        ]
+    ),
+)
+
+
+def _assert_version(sysinfo: dict[str, str | int]) -> None:
+    """Check if the Edge Worker version matches the central API site."""
+    from airflow import __version__ as airflow_version
+    from airflow.providers.edge import __version__ as edge_provider_version
+
+    # Note: In future, more stable versions we might be more liberal, for the
+    # moment we require exact version match for Edge Worker and 
core version + if "airflow_version" in sysinfo: + airflow_on_worker = sysinfo["airflow_version"] + if airflow_on_worker != airflow_version: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + f"Edge Worker runs on Airflow {airflow_on_worker} " + f"and the core runs on {airflow_version}. Rejecting access due to difference.", + ) + else: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the version it is running on." + ) + + if "edge_provider_version" in sysinfo: + provider_on_worker = sysinfo["edge_provider_version"] + if provider_on_worker != edge_provider_version: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + f"Edge Worker runs on Edge Provider {provider_on_worker} " + f"and the core runs on {edge_provider_version}. Rejecting access due to difference.", + ) + else: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "Edge Worker does not specify the provider version it is running on." + ) + + +_worker_name_doc = Path(title="Worker Name", description="Hostname or instance name of the worker") +_worker_state_doc = Body( + title="Worker State", + description="State of the worker with details", + examples=[ + { + "state": "running", + "jobs_active": 3, + "queues": ["large_node", "wisconsin_site"], + "sysinfo": { + "concurrency": 4, + "airflow_version": "2.10.0", + "edge_provider_version": "1.0.0", + }, + } + ], +) +_worker_queue_doc = Body( + title="Changes in worker queues", + description="Changes to be applied to current queues of worker", + examples=[{"new_queues": ["new_queue"], "remove_queues": ["old_queue"]}], +) + + +@worker_router.post("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)]) +def register( + worker_name: Annotated[str, _worker_name_doc], + body: Annotated[WorkerStateBody, _worker_state_doc], + session: SessionDep, +) -> datetime: + """Register a new worker to the backend.""" + _assert_version(body.sysinfo) + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == 
worker_name) + worker: EdgeWorkerModel = session.scalar(query) + if not worker: + worker = EdgeWorkerModel(worker_name=worker_name, state=body.state, queues=body.queues) + worker.state = body.state + worker.queues = body.queues + worker.sysinfo = json.dumps(body.sysinfo) + worker.last_update = timezone.utcnow() + session.add(worker) + return worker.last_update + + +@worker_router.patch("/{worker_name}", dependencies=[Depends(jwt_token_authorization_rest)]) +def set_state( + worker_name: Annotated[str, _worker_name_doc], + body: Annotated[WorkerStateBody, _worker_state_doc], + session: SessionDep, +) -> list[str] | None: + """Set state of worker and returns the current assigned queues.""" + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel = session.scalar(query) + worker.state = body.state + worker.jobs_active = body.jobs_active + worker.sysinfo = json.dumps(body.sysinfo) + worker.last_update = timezone.utcnow() + session.commit() + Stats.incr(f"edge_worker.heartbeat_count.{worker_name}", 1, 1) + Stats.incr("edge_worker.heartbeat_count", 1, 1, tags={"worker_name": worker_name}) + set_metrics( + worker_name=worker_name, + state=body.state, + jobs_active=body.jobs_active, + concurrency=int(body.sysinfo.get("concurrency", -1)), + free_concurrency=int(body.sysinfo["free_concurrency"]), + queues=worker.queues, + ) + _assert_version(body.sysinfo) # Exception only after worker state is in the DB + return worker.queues + + +@worker_router.patch( + "/queues/{worker_name}", + dependencies=[Depends(jwt_token_authorization_rest)], +) +def update_queues( + worker_name: Annotated[str, _worker_name_doc], + body: Annotated[WorkerQueueUpdateBody, _worker_queue_doc], + session: SessionDep, +) -> None: + query = select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == worker_name) + worker: EdgeWorkerModel = session.scalar(query) + if body.new_queues: + worker.add_queues(body.new_queues) + if body.remove_queues: + 
worker.remove_queues(body.remove_queues) + session.add(worker) diff --git a/providers/tests/edge/cli/test_edge_command.py b/providers/tests/edge/cli/test_edge_command.py index f6612b1a99a51..3304831064abd 100644 --- a/providers/tests/edge/cli/test_edge_command.py +++ b/providers/tests/edge/cli/test_edge_command.py @@ -29,7 +29,7 @@ from airflow.exceptions import AirflowException from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli, _Job, _write_pid_to_pidfile from airflow.providers.edge.models.edge_job import EdgeJob -from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState, EdgeWorkerVersionException +from airflow.providers.edge.models.edge_worker import EdgeWorkerState, EdgeWorkerVersionException from airflow.utils.state import TaskInstanceState from tests_common.test_utils.config import conf_vars @@ -74,8 +74,16 @@ def test_write_pid_to_pidfile_created_by_crashed_instance(tmp_path): assert str(os.getpid()) == pid_file_path.read_text().strip() -# Ignore the following error for mocking -# mypy: disable-error-code="attr-defined" +class _MockPopen(Popen): + def __init__(self, returncode=None): + self.generated_returncode = None + + def poll(self): + pass + + @property + def returncode(self): + return self.generated_returncode class TestEdgeWorkerCli: @@ -84,19 +92,6 @@ def dummy_joblist(self, tmp_path: Path) -> list[_Job]: logfile = tmp_path / "file.log" logfile.touch() - class MockPopen(Popen): - generated_returncode = None - - def __init__(self): - pass - - def poll(self): - pass - - @property - def returncode(self): - return self.generated_returncode - return [ _Job( edge_job=EdgeJob( @@ -113,7 +108,7 @@ def returncode(self): edge_worker=None, last_update=None, ), - process=MockPopen(), + process=_MockPopen(), logfile=logfile, logsize=0, ), @@ -168,7 +163,7 @@ def test_fetch_job( logfile_path_call_count, set_state_call_count = expected_calls mock_reserve_task.side_effect = [reserve_result] mock_popen.side_effect = ["dummy"] 
- with conf_vars({("edge", "api_url"): "https://mock.server"}): + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): got_job = worker_with_job.fetch_job() mock_reserve_task.assert_called_once() assert got_job == fetch_result @@ -176,9 +171,8 @@ def test_fetch_job( assert mock_set_state.call_count == set_state_call_count def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli): - worker_with_job.jobs[0].process.generated_returncode = None assert worker_with_job.free_concurrency == worker_with_job.concurrency - with conf_vars({("edge", "api_url"): "https://mock.server"}): + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 1 assert ( @@ -189,8 +183,8 @@ def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli): @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state") def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _EdgeWorkerCli): job = worker_with_job.jobs[0] - job.process.generated_returncode = 0 - with conf_vars({("edge", "api_url"): "https://mock.server"}): + job.process.generated_returncode = 0 # type: ignore[attr-defined] + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 0 mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS) @@ -199,8 +193,8 @@ def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _Edge @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state") def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeWorkerCli): job = worker_with_job.jobs[0] - job.process.generated_returncode = 42 - with conf_vars({("edge", "api_url"): "https://mock.server"}): + job.process.generated_returncode = 42 # type: ignore[attr-defined] + with conf_vars({("edge", "api_url"): 
"https://invalid-api-test-endpoint"}): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 0 mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.FAILED) @@ -210,10 +204,12 @@ def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeW @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs") def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _EdgeWorkerCli): job = worker_with_job.jobs[0] - job.process.generated_returncode = None job.logfile.write_text("some log content") with conf_vars( - {("edge", "api_url"): "https://mock.server", ("edge", "push_log_chunk_size"): "524288"} + { + ("edge", "api_url"): "https://invalid-api-test-endpoint", + ("edge", "push_log_chunk_size"): "524288", + } ): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 1 @@ -225,12 +221,14 @@ def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _Edg @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs") def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with_job: _EdgeWorkerCli): job = worker_with_job.jobs[0] - job.process.generated_returncode = None job.logfile.write_text("hello ") job.logsize = job.logfile.stat().st_size job.logfile.write_text("hello world") with conf_vars( - {("edge", "api_url"): "https://mock.server", ("edge", "push_log_chunk_size"): "524288"} + { + ("edge", "api_url"): "https://invalid-api-test-endpoint", + ("edge", "push_log_chunk_size"): "524288", + } ): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 1 @@ -242,9 +240,10 @@ def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs") def test_check_running_jobs_log_push_chunks(self, mock_push_logs, worker_with_job: _EdgeWorkerCli): job = worker_with_job.jobs[0] - job.process.generated_returncode = None 
job.logfile.write_bytes("log1log2ülog3".encode("latin-1")) - with conf_vars({("edge", "api_url"): "https://mock.server", ("edge", "push_log_chunk_size"): "4"}): + with conf_vars( + {("edge", "api_url"): "https://invalid-api-test-endpoint", ("edge", "push_log_chunk_size"): "4"} + ): worker_with_job.check_running_jobs() assert len(worker_with_job.jobs) == 1 calls = mock_push_logs.call_args_list @@ -262,13 +261,13 @@ def test_check_running_jobs_log_push_chunks(self, mock_push_logs, worker_with_jo pytest.param(False, False, EdgeWorkerState.IDLE, id="idle"), ], ) - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state") + @patch("airflow.providers.edge.cli.edge_command.worker_set_state") def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_with_job: _EdgeWorkerCli): if not jobs: worker_with_job.jobs = [] _EdgeWorkerCli.drain = drain mock_set_state.return_value = ["queue1", "queue2"] - with conf_vars({("edge", "api_url"): "https://mock.server"}): + with conf_vars({("edge", "api_url"): "https://invalid-api-test-endpoint"}): worker_with_job.heartbeat() assert mock_set_state.call_args.args[1] == expected_state queue_list = worker_with_job.queues or [] @@ -276,13 +275,13 @@ def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_wit assert "queue1" in (queue_list) assert "queue2" in (queue_list) - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state") + @patch("airflow.providers.edge.cli.edge_command.worker_set_state") def test_version_mismatch(self, mock_set_state, worker_with_job): mock_set_state.side_effect = EdgeWorkerVersionException("") worker_with_job.heartbeat() assert worker_with_job.drain - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + @patch("airflow.providers.edge.cli.edge_command.worker_register") def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): mock_register_worker.side_effect = AirflowException( 
"Something with 404:NOT FOUND means API is not active" @@ -290,42 +289,28 @@ def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _E with pytest.raises(SystemExit, match=r"API endpoint is not ready"): worker_with_job.start() - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + @patch("airflow.providers.edge.cli.edge_command.worker_register") def test_start_server_error(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): mock_register_worker.side_effect = AirflowException("Something other error not FourhundretFour") with pytest.raises(SystemExit, match=r"Something other"): worker_with_job.start() - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + @patch("airflow.providers.edge.cli.edge_command.worker_register") @patch("airflow.providers.edge.cli.edge_command._EdgeWorkerCli.loop") - @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state") + @patch("airflow.providers.edge.cli.edge_command.worker_set_state") def test_start_and_run_one( - self, mock_set_state, mock_loop, mock_register_worker, worker_with_job: _EdgeWorkerCli + self, mock_set_state, mock_loop, mock_register, worker_with_job: _EdgeWorkerCli ): - mock_register_worker.side_effect = [ - EdgeWorker( - worker_name="test", - state=EdgeWorkerState.STARTING, - queues=None, - first_online=datetime.now(), - last_update=datetime.now(), - jobs_active=0, - jobs_taken=0, - jobs_success=0, - jobs_failed=0, - sysinfo="", - ) - ] - def stop_running(): _EdgeWorkerCli.drain = True worker_with_job.jobs = [] mock_loop.side_effect = stop_running + mock_register.side_effect = [datetime.now()] worker_with_job.start() - mock_register_worker.assert_called_once() + mock_register.assert_called_once() mock_loop.assert_called_once() mock_set_state.assert_called_once() diff --git a/providers/tests/edge/models/test_edge_worker.py b/providers/tests/edge/worker_api/routes/test_worker.py similarity index 67% rename from 
providers/tests/edge/models/test_edge_worker.py rename to providers/tests/edge/worker_api/routes/test_worker.py index 20e394ffd5767..e05a94c5f8719 100644 --- a/providers/tests/edge/models/test_edge_worker.py +++ b/providers/tests/edge/worker_api/routes/test_worker.py @@ -22,11 +22,14 @@ import pytest from airflow.providers.edge.cli.edge_command import _EdgeWorkerCli -from airflow.providers.edge.models.edge_worker import ( - EdgeWorker, - EdgeWorkerModel, - EdgeWorkerState, - EdgeWorkerVersionException, +from airflow.providers.edge.models.edge_worker import EdgeWorkerModel, EdgeWorkerState +from airflow.providers.edge.worker_api.datamodels import WorkerQueueUpdateBody, WorkerStateBody +from airflow.providers.edge.worker_api.routes._v2_compat import HTTPException +from airflow.providers.edge.worker_api.routes.worker import ( + _assert_version, + register, + set_state, + update_queues, ) from airflow.utils import timezone @@ -36,7 +39,7 @@ pytestmark = pytest.mark.db_test -class TestEdgeWorker: +class TestWorkerApiRoutes: @pytest.fixture def cli_worker(self, tmp_path: Path) -> _EdgeWorkerCli: test_worker = _EdgeWorkerCli(str(tmp_path / "dummy.pid"), "dummy", None, 8, 5, 5) @@ -50,28 +53,22 @@ def test_assert_version(self): from airflow import __version__ as airflow_version from airflow.providers.edge import __version__ as edge_provider_version - with pytest.raises(EdgeWorkerVersionException): - EdgeWorker.assert_version({}) + with pytest.raises(HTTPException): + _assert_version({}) - with pytest.raises(EdgeWorkerVersionException): - EdgeWorker.assert_version({"airflow_version": airflow_version}) + with pytest.raises(HTTPException): + _assert_version({"airflow_version": airflow_version}) - with pytest.raises(EdgeWorkerVersionException): - EdgeWorker.assert_version({"edge_provider_version": edge_provider_version}) + with pytest.raises(HTTPException): + _assert_version({"edge_provider_version": edge_provider_version}) - with pytest.raises(EdgeWorkerVersionException): - 
EdgeWorker.assert_version( - {"airflow_version": "1.2.3", "edge_provider_version": edge_provider_version} - ) + with pytest.raises(HTTPException): + _assert_version({"airflow_version": "1.2.3", "edge_provider_version": edge_provider_version}) - with pytest.raises(EdgeWorkerVersionException): - EdgeWorker.assert_version( - {"airflow_version": airflow_version, "edge_provider_version": "2023.10.07"} - ) + with pytest.raises(HTTPException): + _assert_version({"airflow_version": airflow_version, "edge_provider_version": "2023.10.07"}) - EdgeWorker.assert_version( - {"airflow_version": airflow_version, "edge_provider_version": edge_provider_version} - ) + _assert_version({"airflow_version": airflow_version, "edge_provider_version": edge_provider_version}) @pytest.mark.parametrize( "input_queues", @@ -80,12 +77,15 @@ def test_assert_version(self): pytest.param(["default", "default2"], id="with-queues"), ], ) - def test_register_worker( - self, session: Session, input_queues: list[str] | None, cli_worker: _EdgeWorkerCli - ): - EdgeWorker.register_worker( - "test_worker", EdgeWorkerState.STARTING, queues=input_queues, sysinfo=cli_worker._get_sysinfo() + def test_register(self, session: Session, input_queues: list[str] | None, cli_worker: _EdgeWorkerCli): + body = WorkerStateBody( + state=EdgeWorkerState.STARTING, + jobs_active=0, + queues=input_queues, + sysinfo=cli_worker._get_sysinfo(), ) + register("test_worker", body, session) + session.commit() worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() assert len(worker) == 1 @@ -106,9 +106,13 @@ def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli): session.add(rwm) session.commit() - return_queues = EdgeWorker.set_state( - "test2_worker", EdgeWorkerState.RUNNING, 1, cli_worker._get_sysinfo() + body = WorkerStateBody( + state=EdgeWorkerState.RUNNING, + jobs_active=1, + queues=["default2"], + sysinfo=cli_worker._get_sysinfo(), ) + return_queues = set_state("test2_worker", body, session) 
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() assert len(worker) == 1 @@ -127,13 +131,12 @@ def test_set_state(self, session: Session, cli_worker: _EdgeWorkerCli): pytest.param(["init"], None, ["init"], id="check-duplicated"), ], ) - def test_add_and_remove_queues( + def test_update_queues( self, session: Session, add_queues: list[str] | None, remove_queues: list[str] | None, expected_queues: list[str], - cli_worker: _EdgeWorkerCli, ): rwm = EdgeWorkerModel( worker_name="test2_worker", @@ -143,7 +146,8 @@ def test_add_and_remove_queues( ) session.add(rwm) session.commit() - EdgeWorker.add_and_remove_queues("test2_worker", add_queues, remove_queues, session) + body = WorkerQueueUpdateBody(new_queues=add_queues, remove_queues=remove_queues) + update_queues("test2_worker", body, session) worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() assert len(worker) == 1 assert worker[0].worker_name == "test2_worker"