Skip to content

Commit

Permalink
ECS Executor - add support to adopt orphaned tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi committed Mar 6, 2024
1 parent 35fef2b commit ee6f7ba
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 7 deletions.
9 changes: 9 additions & 0 deletions airflow/executors/base_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,15 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, TaskInstanceState.SUCCESS, info)

def queued(self, key: TaskInstanceKey, info=None) -> None:
"""
Set queued state for the event.
:param info: Executor information for the task instance
:param key: Unique key for the task instance
"""
self.change_state(key, TaskInstanceState.QUEUED, info)

def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Return and flush the event buffer.
Expand Down
43 changes: 41 additions & 2 deletions airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import time
from collections import defaultdict, deque
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from botocore.exceptions import ClientError, NoCredentialsError

Expand All @@ -47,12 +47,13 @@
exponential_backoff_retry,
)
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.state import State

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs.utils import (
CommandType,
ExecutorConfigType,
Expand Down Expand Up @@ -240,6 +241,7 @@ def __update_running_task(self, task):
# Get state of current task.
task_state = task.get_task_state()
task_key = self.active_workers.arn_to_key[task.task_arn]

# Mark finished tasks as either a success/failure.
if task_state == State.FAILED:
self.fail(task_key)
Expand Down Expand Up @@ -394,6 +396,7 @@ def attempt_task_runs(self):
else:
task = run_task_response["tasks"][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config, attempt_number)
self.queued(task_key, task.task_arn)
if failure_reasons:
self.log.error(
"Pending ECS tasks failed to launch for the following reasons: %s. Retrying later.",
Expand Down Expand Up @@ -494,3 +497,39 @@ def get_container(self, container_list):
'container "name" must be provided in "containerOverrides" configuration'
)
raise KeyError(f"No such container found by container name: {self.container_name}")

def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Adopt task instances which have an external_executor_id (the ECS task ARN).
Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
"""
with Stats.timer("ecs_executor.adopt_task_instances.duration"):
adopted_tis: list[TaskInstance] = []

if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
task_descriptions = self.__describe_tasks(task_arns).get("tasks", [])

for task in task_descriptions:
ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
self.active_workers.add_task(
task,
ti.key,
ti.queue,
ti.command_as_list(),
ti.executor_config,
ti.prev_attempted_tries,
)
adopted_tis.append(ti)

if adopted_tis:
tasks = [f"{task} in state {task.state}" for task in adopted_tis]
task_instance_str = "\n\t".join(tasks)
self.log.info(
"Adopted the following %d tasks from a dead executor:\n\t%s",
len(adopted_tis),
task_instance_str,
)

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis
12 changes: 7 additions & 5 deletions airflow/providers/amazon/aws/executors/ecs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,30 @@ class RunTaskKwargsConfigKeys(BaseConfigKeys):
ASSIGN_PUBLIC_IP = "assign_public_ip"
CAPACITY_PROVIDER_STRATEGY = "capacity_provider_strategy"
CLUSTER = "cluster"
CONTAINER_NAME = "container_name"
LAUNCH_TYPE = "launch_type"
PLATFORM_VERSION = "platform_version"
SECURITY_GROUPS = "security_groups"
SUBNETS = "subnets"
TASK_DEFINITION = "task_definition"
CONTAINER_NAME = "container_name"


class AllEcsConfigKeys(RunTaskKwargsConfigKeys):
"""All keys loaded into the config which are related to the ECS Executor."""

MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
AWS_CONN_ID = "conn_id"
RUN_TASK_KWARGS = "run_task_kwargs"
REGION_NAME = "region_name"
CHECK_HEALTH_ON_STARTUP = "check_health_on_startup"
MAX_RUN_TASK_ATTEMPTS = "max_run_task_attempts"
REGION_NAME = "region_name"
RUN_TASK_KWARGS = "run_task_kwargs"


class EcsExecutorException(Exception):
"""Thrown when something unexpected has occurred within the ECS ecosystem."""


class EcsExecutorTask:
"""Data Transfer Object for an ECS Fargate Task."""
"""Data Transfer Object for an ECS Task."""

def __init__(
self,
Expand All @@ -111,13 +111,15 @@ def __init__(
containers: list[dict[str, Any]],
started_at: Any | None = None,
stopped_reason: str | None = None,
external_executor_id: str | None = None,
):
self.task_arn = task_arn
self.last_status = last_status
self.desired_status = desired_status
self.containers = containers
self.started_at = started_at
self.stopped_reason = stopped_reason
self.external_executor_id = external_executor_id

def get_task_state(self) -> str:
"""
Expand Down
40 changes: 40 additions & 0 deletions tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config
from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
Expand Down Expand Up @@ -829,6 +830,45 @@ def test_update_running_tasks_failed(self, mock_executor, caplog):
"test failure" in caplog.messages[0]
)

def test_try_adopt_task_instances(self, mock_executor):
"""Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event."""
mock_executor.ecs.describe_tasks.return_value = {
"tasks": [
{
"taskArn": "001",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "some-ecs-container"}],
},
{
"taskArn": "002",
"lastStatus": "RUNNING",
"desiredStatus": "RUNNING",
"containers": [{"name": "another-ecs-container"}],
},
],
"failures": [],
}

orphaned_tasks = [
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
mock.Mock(spec=TaskInstance),
]
orphaned_tasks[0].external_executor_id = "001" # Matches a running task_arn
orphaned_tasks[1].external_executor_id = "002" # Matches a running task_arn
orphaned_tasks[2].external_executor_id = None # One orphaned task has no external_executor_id
for task in orphaned_tasks:
task.prev_attempted_tries = 1

not_adopted_tasks = mock_executor.try_adopt_task_instances(orphaned_tasks)

mock_executor.ecs.describe_tasks.assert_called_once()
# Two of the three tasks should be adopted.
assert len(orphaned_tasks) - 1 == len(mock_executor.active_workers)
# The remaining one task is unable to be adopted.
assert 1 == len(not_adopted_tasks)


class TestEcsExecutorConfig:
@pytest.fixture
Expand Down

0 comments on commit ee6f7ba

Please sign in to comment.