Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion providers/edge3/docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ PIP package Version required
``apache-airflow`` ``>=3.0.0,!=3.1.0``
``apache-airflow-providers-common-compat`` ``>=1.10.1``
``pydantic`` ``>=2.11.0``
``retryhttp`` ``>=1.2.0,!=1.3.0``
``retryhttp`` ``>=1.4.0``
``aiofiles`` ``>=23.2.0``
``aiohttp`` ``>=3.9.2``
========================================== ===================

Cross provider package dependencies
Expand Down
4 changes: 3 additions & 1 deletion providers/edge3/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ dependencies = [
"apache-airflow>=3.0.0,!=3.1.0",
"apache-airflow-providers-common-compat>=1.10.1", # use next version
"pydantic>=2.11.0",
"retryhttp>=1.2.0,!=1.3.0",
"retryhttp>=1.4.0",
"aiofiles>=23.2.0",
"aiohttp>=3.9.2",
]

[dependency-groups]
Expand Down
56 changes: 29 additions & 27 deletions providers/edge3/src/airflow/providers/edge3/cli/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
from __future__ import annotations

import json
import logging
import os
from datetime import datetime
Expand All @@ -26,7 +25,7 @@
from typing import TYPE_CHECKING, Any
from urllib.parse import quote, urljoin

import requests
from aiohttp import ClientConnectionError, ClientResponseError, ServerTimeoutError, request
from retryhttp import retry, wait_retry_after
from tenacity import before_sleep_log, wait_random_exponential

Expand All @@ -44,11 +43,11 @@
WorkerSetStateReturn,
WorkerStateBody,
)
from airflow.utils.state import TaskInstanceState # noqa: TC001

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState
from airflow.utils.state import TaskInstanceState

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -92,39 +91,42 @@ def jwt_generator() -> JWTGenerator:
wait_timeouts=_default_wait,
wait_rate_limited=wait_retry_after(fallback=_default_wait), # No infinite timeout on HTTP 429
before_sleep=before_sleep_log(logger, logging.WARNING),
network_errors=ClientConnectionError,
timeouts=ServerTimeoutError,
)
def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
async def _make_generic_request(method: str, rest_path: str, data: str | None = None) -> Any:
authorization = jwt_generator().generate({"method": rest_path})
api_url = conf.get("edge", "api_url")
content_type = {"Content-Type": "application/json"} if data else {}
headers = {
"Content-Type": "application/json",
**content_type,
"Accept": "application/json",
"Authorization": authorization,
}
api_endpoint = urljoin(api_url, rest_path)
response = requests.request(method, url=api_endpoint, data=data, headers=headers)
response.raise_for_status()
if response.status_code == HTTPStatus.NO_CONTENT:
return None
return json.loads(response.content)
async with request(method, url=api_endpoint, data=data, headers=headers) as response:
response.raise_for_status()
if response.status == HTTPStatus.NO_CONTENT:
return None
return await response.json()


def worker_register(
async def worker_register(
hostname: str, state: EdgeWorkerState, queues: list[str] | None, sysinfo: dict
) -> WorkerRegistrationReturn:
"""Register worker with the Edge API."""
try:
result = _make_generic_request(
result = await _make_generic_request(
"POST",
f"worker/{quote(hostname)}",
WorkerStateBody(state=state, jobs_active=0, queues=queues, sysinfo=sysinfo).model_dump_json(
exclude_unset=True
),
)
except requests.HTTPError as e:
if e.response.status_code == 400:
except ClientResponseError as e:
if e.status == HTTPStatus.BAD_REQUEST:
raise EdgeWorkerVersionException(str(e))
if e.response.status_code == 409:
if e.status == HTTPStatus.CONFLICT:
raise EdgeWorkerDuplicateException(
f"A worker with the name '{hostname}' is already active. "
"Please ensure worker names are unique, or stop the existing worker before starting a new one."
Expand All @@ -133,7 +135,7 @@ def worker_register(
return WorkerRegistrationReturn(**result)


def worker_set_state(
async def worker_set_state(
hostname: str,
state: EdgeWorkerState,
jobs_active: int,
Expand All @@ -143,7 +145,7 @@ def worker_set_state(
) -> WorkerSetStateReturn:
"""Update the state of the worker in the central site and thereby implicitly heartbeat."""
try:
result = _make_generic_request(
result = await _make_generic_request(
"PATCH",
f"worker/{quote(hostname)}",
WorkerStateBody(
Expand All @@ -154,16 +156,16 @@ def worker_set_state(
maintenance_comments=maintenance_comments,
).model_dump_json(exclude_unset=True),
)
except requests.HTTPError as e:
if e.response.status_code == 400:
except ClientResponseError as e:
if e.status == HTTPStatus.BAD_REQUEST:
raise EdgeWorkerVersionException(str(e))
raise e
return WorkerSetStateReturn(**result)


def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
async def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -> EdgeJobFetched | None:
"""Fetch a job to execute on the edge worker."""
result = _make_generic_request(
result = await _make_generic_request(
"POST",
f"jobs/fetch/{quote(hostname)}",
WorkerQueuesBody(queues=queues, free_concurrency=free_concurrency).model_dump_json(
Expand All @@ -175,31 +177,31 @@ def jobs_fetch(hostname: str, queues: list[str] | None, free_concurrency: int) -
return None


def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
async def jobs_set_state(key: TaskInstanceKey, state: TaskInstanceState) -> None:
"""Set the state of a job."""
_make_generic_request(
await _make_generic_request(
"PATCH",
f"jobs/state/{key.dag_id}/{key.task_id}/{key.run_id}/{key.try_number}/{key.map_index}/{state}",
)


def logs_logfile_path(task: TaskInstanceKey) -> Path:
async def logs_logfile_path(task: TaskInstanceKey) -> Path:
"""Elaborate the path and filename to expect from task execution."""
result = _make_generic_request(
result = await _make_generic_request(
"GET",
f"logs/logfile_path/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
)
base_log_folder = conf.get("logging", "base_log_folder", fallback="NOT AVAILABLE")
return Path(base_log_folder, result)


def logs_push(
async def logs_push(
task: TaskInstanceKey,
log_chunk_time: datetime,
log_chunk_data: str,
) -> None:
"""Push an incremental log chunk from Edge Worker to central site."""
_make_generic_request(
await _make_generic_request(
"POST",
f"logs/push/{task.dag_id}/{task.task_id}/{task.run_id}/{task.try_number}/{task.map_index}",
PushLogsBody(log_chunk_time=log_chunk_time, log_chunk_data=log_chunk_data).model_dump_json(
Expand Down
13 changes: 3 additions & 10 deletions providers/edge3/src/airflow/providers/edge3/cli/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from psutil import Popen

from airflow.providers.edge3.models.edge_worker import EdgeWorkerState
from airflow.providers.edge3.worker_api.datamodels import EdgeJobFetched

Expand Down Expand Up @@ -74,22 +72,17 @@ class Job:
"""Holds all information for a task/job to be executed as bundle."""

edge_job: EdgeJobFetched
process: Popen | Process
process: Process
logfile: Path
logsize: int
logsize: int = 0
"""Last size of log file, point of last chunk push."""

@property
def is_running(self) -> bool:
"""Check if the job is still running."""
if hasattr(self.process, "returncode") and hasattr(self.process, "poll"):
self.process.poll()
return self.process.returncode is None
return self.process.exitcode is None
return self.process.is_alive()

@property
def is_success(self) -> bool:
"""Check if the job was successful."""
if hasattr(self.process, "returncode"):
return self.process.returncode == 0
return self.process.exitcode == 0
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
import json
import logging
import os
Expand All @@ -41,7 +42,6 @@
pid_file_path,
status_file_path,
)
from airflow.providers.edge3.cli.worker import SIG_STATUS, EdgeWorker
from airflow.providers.edge3.models.edge_worker import EdgeWorkerState
from airflow.utils import cli as cli_utils
from airflow.utils.net import getfqdn
Expand Down Expand Up @@ -88,6 +88,8 @@ def _launch_worker(args):
print(settings.HEADER)
print(EDGE_WORKER_HEADER)

from airflow.providers.edge3.cli.worker import EdgeWorker

edge_worker = EdgeWorker(
pid_file_path=pid_file_path(args.pid),
hostname=args.edge_hostname or getfqdn(),
Expand All @@ -97,7 +99,7 @@ def _launch_worker(args):
heartbeat_interval=conf.getint("edge", "heartbeat_interval"),
daemon=args.daemon,
)
edge_worker.start()
asyncio.run(edge_worker.start())


@cli_utils.action_cli(check_db=False)
Expand All @@ -120,6 +122,8 @@ def worker(args):
@providers_configuration_loaded
def status(args):
"""Check for Airflow Local Edge Worker status."""
from airflow.providers.edge3.cli.worker import SIG_STATUS

pid = get_pid(args.pid)

# Send Signal as notification to drop status JSON
Expand All @@ -146,6 +150,8 @@ def status(args):
@providers_configuration_loaded
def maintenance(args):
"""Set or Unset maintenance mode of local edge worker."""
from airflow.providers.edge3.cli.worker import SIG_STATUS

if args.maintenance == "on" and not args.comments:
logger.error("Comments are required when setting maintenance mode.")
sys.exit(4)
Expand Down
Loading
Loading