Add support for adopting orphaned task instances #15

Open · wants to merge 4 commits into base: master
48 changes: 47 additions & 1 deletion airflow_aws_executors/batch_executor.py
@@ -1,4 +1,4 @@
"""AWS Batch Executor. Each Airflow task gets deligated out to an AWS Batch Job"""
"""AWS Batch Executor. Each Airflow task gets delegated out to an AWS Batch Job"""

import time
from copy import deepcopy
@@ -7,6 +7,7 @@
import boto3
from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.utils.module_loading import import_string
from airflow.utils.state import State
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
@@ -73,13 +74,15 @@ def __init__(self, *args, **kwargs):
self.active_workers: Optional[BatchJobCollection] = None
self.batch = None
self.submit_job_kwargs = None
self.adopt_task_instances = None

def start(self):
"""Initialize Boto3 Batch Client, and other internal variables"""
region = conf.get('batch', 'region')
self.active_workers = BatchJobCollection()
self.batch = boto3.client('batch', region_name=region)
self.submit_job_kwargs = self._load_submit_kwargs()
self.adopt_task_instances = conf.getboolean('batch', 'adopt_task_instances', fallback=False)

def sync(self):
"""Checks and update state on all running tasks"""
@@ -128,6 +131,9 @@ def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=No
job_id = self._submit_job(key, command, queue, executor_config or {})
self.active_workers.add_job(job_id, key)

# Add batch job_id to executor event buffer, which gets saved in TaskInstance.external_executor_id
self.event_buffer[key] = (State.QUEUED, job_id)

def _submit_job(
self,
key: TaskInstanceKeyType,
@@ -184,14 +190,54 @@ def end(self, heartbeat_interval=10):
def terminate(self):
"""
Kill all Batch Jobs by calling Boto3's TerminateJob API.
Do not kill Batch Jobs if the [batch].adopt_task_instances option is set to True.
"""
if self.adopt_task_instances:
# Leave active Batch Jobs running so that a new scheduler can adopt them
return

for job_id in self.active_workers.get_all_jobs():
self.batch.terminate_job(
jobId=job_id,
reason='Airflow Executor received a SIGTERM'
)
self.end()

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
"""
If the [batch].adopt_task_instances option is set to True, try to adopt running task instances that have been
abandoned by a SchedulerJob dying.
These task instances should have a corresponding AWS Batch Job which can be adopted using the unique job_id.

Anything that is not adopted will be cleared by the scheduler (and then become eligible for re-scheduling)

:return: any TaskInstances that were unable to be adopted
:rtype: list[airflow.models.TaskInstance]
"""
if not self.adopt_task_instances:
# Do not try to adopt task instances, return all orphaned tasks for clearing
return tis

adopted_tis: List[TaskInstance] = []
not_adopted_tis: List[TaskInstance] = []

for ti in tis:
if ti.external_executor_id is not None:
self.active_workers.add_job(ti.external_executor_id, ti.key)
adopted_tis.append(ti)
else:
not_adopted_tis.append(ti)

if adopted_tis:
tasks = [f'{task} in state {task.state}' for task in adopted_tis]
task_instance_str = '\n\t'.join(tasks)
self.log.info(
'Adopted the following %d tasks from a dead executor:\n\t%s',
len(adopted_tis),
task_instance_str,
)

return not_adopted_tis

@staticmethod
def _load_submit_kwargs() -> dict:
submit_kwargs = import_string(
47 changes: 47 additions & 0 deletions airflow_aws_executors/ecs_fargate_executor.py
@@ -8,6 +8,7 @@
import boto3
from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
from airflow.models import TaskInstance
from airflow.utils.module_loading import import_string
from airflow.utils.state import State
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
@@ -93,6 +94,7 @@ def __init__(self, *args, **kwargs):
self.pending_tasks: Optional[deque] = None
self.ecs = None
self.run_task_kwargs = None
self.adopt_task_instances = None

def start(self):
"""Initialize Boto3 ECS Client, and other internal variables"""
@@ -103,6 +105,7 @@ def start(self):
self.pending_tasks = deque()
self.ecs = boto3.client('ecs', region_name=region) # noqa
self.run_task_kwargs = self._load_run_kwargs()
self.adopt_task_instances = conf.getboolean('ecs_fargate', 'adopt_task_instances', fallback=False)

def sync(self):
self.sync_running_tasks()
@@ -207,6 +210,8 @@ def attempt_task_runs(self):
else:
task = run_task_response['tasks'][0]
self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
# Add fargate task arn to executor event buffer, which gets saved in TaskInstance.external_executor_id
self.event_buffer[task_key] = (State.QUEUED, task.task_arn)
if failure_reasons:
self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
dict(failure_reasons))
@@ -267,7 +272,11 @@ def end(self, heartbeat_interval=10):
def terminate(self):
"""
Kill all ECS processes by calling Boto3's StopTask API.
Do not kill ECS processes if the [ecs_fargate].adopt_task_instances option is set to True.
"""
if self.adopt_task_instances:
# Leave active ECS tasks running so that a new scheduler can adopt them
return

for arn in self.active_workers.get_all_arns():
self.ecs.stop_task(
cluster=self.cluster,
@@ -276,6 +285,44 @@
)
self.end()

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
"""
If the [ecs_fargate].adopt_task_instances option is set to True, try to adopt running task instances that have been
abandoned by a SchedulerJob dying.
These task instances should have a corresponding ECS task which can be adopted using the unique task ARN.

Anything that is not adopted will be cleared by the scheduler (and then become eligible for re-scheduling)

:return: any TaskInstances that were unable to be adopted
:rtype: list[airflow.models.TaskInstance]
"""
if not self.adopt_task_instances:
# Do not try to adopt task instances, return all orphaned tasks for clearing
return tis

adopted_tis: List[TaskInstance] = []

task_arns = [ti.external_executor_id for ti in tis if ti.external_executor_id]
if task_arns:
task_descriptions = self.__describe_tasks(task_arns).get('tasks', [])

for task in task_descriptions:
ti = [ti for ti in tis if ti.external_executor_id == task.task_arn][0]
self.active_workers.add_task(task, ti.key, ti.queue, ti.command_as_list(), ti.executor_config)
adopted_tis.append(ti)

if adopted_tis:
tasks = [f'{task} in state {task.state}' for task in adopted_tis]
task_instance_str = '\n\t'.join(tasks)
self.log.info(
'Adopted the following %d tasks from a dead executor:\n\t%s',
len(adopted_tis),
task_instance_str,
)

not_adopted_tis = [ti for ti in tis if ti not in adopted_tis]
return not_adopted_tis

def _load_run_kwargs(self) -> dict:
run_kwargs = import_string(
conf.get(
12 changes: 12 additions & 0 deletions readme.md
@@ -146,6 +146,12 @@ task = PythonOperator(
To change the parameters used to run a task in Batch, the user can overwrite the path to
specify another python dictionary. More documentation can be found in the `Extensibility` section below.
* **default**: airflow_aws_executors.conf.BATCH_SUBMIT_JOB_KWARGS
* `adopt_task_instances`
* **description**: Boolean flag. If set to True, the executor will try to adopt orphaned task instances from a
SchedulerJob shutdown event (for example when a scheduler container is re-deployed or terminated).
If set to False (default), the executor will terminate all active AWS Batch Jobs when the scheduler shuts down.
More documentation can be found in the [airflow docs](https://airflow.apache.org/docs/apache-airflow/stable/scheduler.html#scheduler-tuneables). An example configuration is shown below this list.
* **default**: False
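For illustration, a minimal `airflow.cfg` snippet that enables adoption for the Batch executor could look like the following. The section and option names come from this PR; the region value is only a placeholder, and other required `[batch]` options are omitted:

```ini
[batch]
# AWS region the Batch client is created in (placeholder value)
region = us-east-1
# keep Batch Jobs running on scheduler shutdown so a new scheduler can adopt them
adopt_task_instances = True
```

Equivalently, the option can be exported as the `AIRFLOW__BATCH__ADOPT_TASK_INSTANCES` environment variable.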
#### ECS & FARGATE
`[ecs_fargate]`
* `region`
@@ -181,6 +187,12 @@ task = PythonOperator(
To change the parameters used to run a task in FARGATE or ECS, the user can overwrite the path to
specify another python dictionary. More documentation can be found in the `Extensibility` section below.
* **default**: airflow_aws_executors.conf.ECS_FARGATE_RUN_TASK_KWARGS
* `adopt_task_instances`
* **description**: Boolean flag. If set to True, the executor will try to adopt orphaned task instances from a
SchedulerJob shutdown event (for example when a scheduler container is re-deployed or terminated).
If set to False (default), the executor will terminate all active ECS Tasks when the scheduler shuts down.
More documentation can be found in the [airflow docs](https://airflow.apache.org/docs/apache-airflow/stable/scheduler.html#scheduler-tuneables). An example configuration is shown below this list.
* **default**: False
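An illustrative snippet for the ECS/Fargate executor in `airflow.cfg` might be the following. Only the keys relevant here are shown; the region value is a placeholder:

```ini
[ecs_fargate]
# AWS region the ECS client is created in (placeholder value)
region = us-east-1
# keep ECS tasks running on scheduler shutdown so a new scheduler can adopt them
adopt_task_instances = True
```

The same setting can be supplied as the `AIRFLOW__ECS_FARGATE__ADOPT_TASK_INSTANCES` environment variable.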


*NOTE: Modify airflow.cfg or export environment variables. For example:*
55 changes: 50 additions & 5 deletions tests/test_batch_executor.py
@@ -4,6 +4,7 @@
from airflow_aws_executors.batch_executor import (
AwsBatchExecutor, BatchJobDetailSchema, BatchJob, BatchJobCollection
)
from airflow.models import TaskInstance
from airflow.utils.state import State

from .botocore_helper import get_botocore_model, assert_botocore_call
Expand Down Expand Up @@ -104,14 +105,17 @@ def test_execute(self):
# task is stored in active worker
self.assertEqual(1, len(self.executor.active_workers))

# job_id is stored in executor event buffer
self.assertEqual((State.QUEUED, 'ABC'), self.executor.event_buffer[airflow_key])

@mock.patch('airflow.executors.base_executor.BaseExecutor.fail')
@mock.patch('airflow.executors.base_executor.BaseExecutor.success')
def test_sync(self, success_mock, fail_mock):
"""Test synch from end-to-end. Mocks a successful job & makes sure it's removed"""
after_sync_reponse = self.__mock_sync()
after_sync_response = self.__mock_sync()

# sanity check that container's status code is mocked to success
loaded_batch_job = BatchJobDetailSchema().load(after_sync_reponse)
loaded_batch_job = BatchJobDetailSchema().load(after_sync_response)
self.assertEqual(State.SUCCESS, loaded_batch_job.get_job_state())

self.executor.sync()
@@ -130,11 +134,11 @@
@mock.patch('airflow.executors.base_executor.BaseExecutor.success')
def test_failed_sync(self, success_mock, fail_mock):
"""Test failure states"""
after_sync_reponse = self.__mock_sync()
after_sync_response = self.__mock_sync()

# set container's status code to failure & sanity-check
after_sync_reponse['status'] = 'FAILED'
self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_reponse).get_job_state())
after_sync_response['status'] = 'FAILED'
self.assertEqual(State.FAILED, BatchJobDetailSchema().load(after_sync_response).get_job_state())
self.executor.sync()

# ensure that run_task is called correctly as defined by Botocore docs
@@ -158,6 +162,47 @@ def test_failed_sync(self, success_mock, fail_mock):
self.assertTrue(self.executor.batch.terminate_job.called)
self.assert_botocore_call('TerminateJob', *self.executor.batch.terminate_job.call_args)

def test_terminate_with_task_adoption(self):
"""Test that executor does not shut down active Batch jobs when 'adopt_task_instances' is set to True"""
self.executor.adopt_task_instances = True
self.executor.terminate()

# jobs are not terminated
self.assertFalse(self.executor.batch.terminate_job.called)

def test_try_adopt_task_instances(self):
"""Test that executor can adopt orphaned task instances from a SchedulerJob shutdown event"""
self.executor.adopt_task_instances = True

orphaned_tasks = [
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
]
orphaned_tasks[0].external_executor_id = None # One orphaned task has no external_executor_id
not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)

# adopted tasks are stored in active workers
self.assertEqual(len(orphaned_tasks) - 1, len(self.executor.active_workers))

# one task is unable to be adopted
self.assertEqual(1, len(not_adopted_tasks))

def test_try_adopt_task_instances_disabled(self):
"""Test that executor won't adopt orphaned task instances if 'adopt_task_instances' is set to False (default)"""
orphaned_tasks = [
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
mock.Mock(TaskInstance),
]
not_adopted_tasks = self.executor.try_adopt_task_instances(orphaned_tasks)

# no orphaned tasks are stored in active workers
self.assertEqual(0, len(self.executor.active_workers))

# all tasks are unable to be adopted
self.assertEqual(len(orphaned_tasks), len(not_adopted_tasks))

def test_end(self):
"""The end() function should call sync 3 times, and the task should fail on the 3rd call"""
sync_call_count = 0