17 changes: 17 additions & 0 deletions airflow/config_templates/config.yml
@@ -1373,6 +1373,23 @@
type: int
example: ~
default: "2"
- name: task_track_started
description: |
Celery task will report its status as 'started' when the task is executed by a worker.
This is used in Airflow to keep track of the running tasks, and if a Scheduler is restarted
or run in HA mode, it can adopt the orphan tasks launched by the previous SchedulerJob.
version_added: 2.0.0
type: boolean
example: ~
default: "True"
- name: task_adoption_timeout
description: |
Time in seconds after which adopted tasks are cleared by CeleryExecutor. This is helpful to clear
stalled tasks.
version_added: 2.0.0
type: int
example: ~
default: "600"
- name: celery_broker_transport_options
description: |
This section is for specifying options which can be passed to the
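
For context, a minimal sketch (not part of this diff) of how the two new `[celery]` options would be read in code with Airflow's standard `conf` accessors; like any Airflow option they can also be overridden via the usual `AIRFLOW__CELERY__TASK_TRACK_STARTED` / `AIRFLOW__CELERY__TASK_ADOPTION_TIMEOUT` environment variables:

```python
# Illustrative only -- shows how the new [celery] options map onto Airflow's
# config accessors. The fallbacks mirror the defaults declared above.
from airflow.configuration import conf

task_track_started = conf.getboolean('celery', 'task_track_started', fallback=True)
task_adoption_timeout = conf.getint('celery', 'task_adoption_timeout', fallback=600)

print(task_track_started, task_adoption_timeout)
```
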
9 changes: 9 additions & 0 deletions airflow/config_templates/default_airflow.cfg
@@ -685,6 +685,15 @@ pool = prefork
# ``fetch_celery_task_state`` operations.
operation_timeout = 2

# Celery task will report its status as 'started' when the task is executed by a worker.
# This is used in Airflow to keep track of the running tasks, and if a Scheduler is restarted
# or run in HA mode, it can adopt the orphan tasks launched by the previous SchedulerJob.
task_track_started = True

# Time in seconds after which adopted tasks are cleared by CeleryExecutor. This is helpful to clear
# stalled tasks.
task_adoption_timeout = 600

[celery_broker_transport_options]

# This section is for specifying options which can be passed to the
1 change: 1 addition & 0 deletions airflow/config_templates/default_celery.py
@@ -43,6 +43,7 @@ def _broker_supports_visibility_timeout(url):
'task_acks_late': True,
'task_default_queue': conf.get('celery', 'DEFAULT_QUEUE'),
'task_default_exchange': conf.get('celery', 'DEFAULT_QUEUE'),
'task_track_started': conf.get('celery', 'task_track_started', fallback=True),
'broker_url': broker_url,
'broker_transport_options': broker_transport_options,
'result_backend': conf.get('celery', 'RESULT_BACKEND'),
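
The new `task_track_started` key makes Celery workers report a STARTED state instead of leaving tasks at PENDING until they finish, which is what the executor later relies on to confirm an adopted task actually reached a worker. A hedged sketch of how a settings mapping like this is applied to a Celery app (the app name and broker URL below are placeholders, not Airflow's real wiring):

```python
# Illustrative sketch only: applying a settings mapping such as
# DEFAULT_CELERY_CONFIG to a Celery app.
from celery import Celery

celery_settings = {
    'task_track_started': True,          # workers report STARTED rather than staying PENDING
    'broker_url': 'redis://localhost:6379/0',
}

app = Celery('example_app')
app.conf.update(celery_settings)
```
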
14 changes: 14 additions & 0 deletions airflow/executors/base_executor.py
@@ -265,6 +265,20 @@ def terminate(self):
"""
raise NotImplementedError()

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
"""
Try to adopt running task instances that have been abandoned by a SchedulerJob dying.

Anything that is not adopted will be cleared by the scheduler (and then become eligible for
re-scheduling)

:return: any TaskInstances that were unable to be adopted
:rtype: list[airflow.models.TaskInstance]
"""
# By default, assume Executors cannot adopt tasks, so just say we failed to adopt anything.
# Subclasses can do better!
return tis

@staticmethod
def validate_command(command: List[str]) -> None:
"""Check if the command to execute is airflow command"""
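
The base implementation above defines the adoption contract: an executor returns exactly the task instances it could not adopt, and the scheduler clears those. A hedged sketch of what an overriding subclass might look like (the class and its `_can_track` check are hypothetical, not part of Airflow):

```python
# Hypothetical executor illustrating the try_adopt_task_instances contract:
# keep tracking what you can still see, return everything else.
from typing import List

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance


class ExampleAdoptingExecutor(BaseExecutor):
    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
        not_adopted = []
        for ti in tis:
            if self._can_track(ti):        # hypothetical check against the execution backend
                self.running.add(ti.key)   # adopt: keep reporting on this running task
            else:
                not_adopted.append(ti)     # scheduler will clear and re-queue these
        return not_adopted

    def _can_track(self, ti: TaskInstance) -> bool:
        # Placeholder: a real executor would query its backend (Celery, Kubernetes, ...)
        return False
```
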
155 changes: 134 additions & 21 deletions airflow/executors/celery_executor.py
@@ -21,14 +21,17 @@
For more information on how the CeleryExecutor works, take a look at the guide:
:ref:`executor:CeleryExecutor`
"""
import datetime
import logging
import math
import operator
import os
import subprocess
import time
import traceback
from collections import OrderedDict
from multiprocessing import Pool, cpu_count
from typing import Any, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union

from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
@@ -39,11 +42,12 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.timezone import utcnow

log = logging.getLogger(__name__)

@@ -142,7 +146,11 @@ def __init__(self):
self._sync_parallelism = max(1, cpu_count() - 1)
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
self.last_state = {}
# Mapping of tasks we've adopted, ordered by the earliest date they time out
self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
self.task_adoption_timeout = datetime.timedelta(
seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
)

def start(self) -> None:
self.log.debug(
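
A minimal sketch of the bookkeeping the hunk above introduces (values are made up): each adopted task key maps to the wall-clock deadline after which it is treated as stalled if Celery never reports progress:

```python
# Illustrative values only -- the real keys are TaskInstanceKey objects and the
# real deadline is ti.queued_dttm + task_adoption_timeout.
import datetime
from collections import OrderedDict

task_adoption_timeout = datetime.timedelta(seconds=600)
adopted_task_timeouts = OrderedDict()

execution_date = datetime.datetime(2020, 8, 1, 10, 0, 0)
queued_dttm = datetime.datetime(2020, 8, 1, 12, 0, 0)
key = ('example_dag', 'example_task', execution_date, 1)  # stand-in for a TaskInstanceKey
adopted_task_timeouts[key] = queued_dttm + task_adoption_timeout
```
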
@@ -199,7 +207,14 @@ def _process_tasks(self, task_tuples_to_send: List[TaskInstanceInCelery]) -> None:
result.backend = cached_celery_backend
self.running.add(key)
self.tasks[key] = result
self.last_state[key] = celery_states.PENDING

# Store the Celery task_id in the event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other events are success/failed at
# which point we don't need the ID anymore anyway
self.event_buffer[key] = (State.QUEUED, result.task_id)

# If the task runs _really quickly_ we may already have a result!
self.update_task_state(key, result.state, getattr(result, 'info', None))

def _send_tasks_to_celery(self, task_tuples_to_send):
if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
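
Because the Celery `task_id` is now surfaced as the `info` element of the `(State.QUEUED, info)` event-buffer tuple, the scheduler can persist it against the TI row as `external_executor_id` and use it for adoption later. A hedged sketch of that consumer side, where `executor` and `record_external_executor_id` are placeholders, not Airflow functions:

```python
# Sketch of the consumer side (simplified): queued events carry the Celery
# task_id as their info payload.
from airflow.utils.state import State


def persist_queued_task_ids(executor, record_external_executor_id) -> None:
    for key, (state, info) in executor.get_event_buffer().items():
        if state == State.QUEUED and info is not None:
            record_external_executor_id(key, info)   # info is the Celery AsyncResult.task_id
```
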
@@ -223,6 +238,48 @@ def sync(self) -> None:
return
self.update_all_task_states()

if self.adopted_task_timeouts:
self._check_for_stalled_adopted_tasks()

def _check_for_stalled_adopted_tasks(self):
"""
See if any of the tasks we adopted from another Executor run have not
progressed after the configured timeout.

If they haven't, they likely never made it to Celery, and we should
just resend them. We do that by clearing the state and letting the
normal scheduler loop deal with it.
"""
now = utcnow()

timedout_keys = []
for key, stalled_after in self.adopted_task_timeouts.items():
if stalled_after > now:
# Since items are stored sorted, if we get to a stalled_after
# in the future then we can stop
break

# If the task gets updated to STARTED (which Celery does) or has
# already finished, then it will be removed from this list -- so
# the only time it's still in this list is when it a) never made it
# to celery in the first place (i.e. a race condition somewhere in
# the dying executor) or b) the celery queue is really long and it just
# hasn't started yet -- better to cancel it and let the scheduler
# re-queue it rather than have this task risk stalling forever
timedout_keys.append(key)

if timedout_keys:
self.log.error(
"Adopted tasks were still pending after %s, assuming they never made it to celery and "
"clearing:\n\t%s",
self.task_adoption_timeout,
"\n\t".join([repr(x) for x in timedout_keys])
)
for key in timedout_keys:
self.event_buffer[key] = (State.FAILED, None)
del self.tasks[key]
del self.adopted_task_timeouts[key]

def update_all_task_states(self) -> None:
"""Updates states of the tasks."""

@@ -235,25 +292,25 @@ def update_all_task_states(self) -> None:
if state:
self.update_task_state(key, state, info)

def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
super().change_state(key, state, info)
self.tasks.pop(key, None)
self.adopted_task_timeouts.pop(key, None)

def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
"""Updates state of a single task."""
try:
if self.last_state[key] != state:
if state == celery_states.SUCCESS:
self.success(key, info)
del self.tasks[key]
del self.last_state[key]
elif state == celery_states.FAILURE:
self.fail(key, info)
del self.tasks[key] # noqa
del self.last_state[key]
elif state == celery_states.REVOKED:
self.fail(key, info)
del self.tasks[key] # noqa
del self.last_state[key]
else:
self.log.info("Unexpected state: %s", state)
self.last_state[key] = state
if state == celery_states.SUCCESS:
self.success(key, info)
elif state in (celery_states.FAILURE, celery_states.REVOKED):
self.fail(key, info)
elif state == celery_states.STARTED:
# It's now actually running, so we know it made it to celery okay!
self.adopted_task_timeouts.pop(key, None)
elif state == celery_states.PENDING:
pass
else:
self.log.info("Unexpected state for %s: %s", key, state)
except Exception: # noqa pylint: disable=broad-except
self.log.exception("Error syncing the Celery executor, ignoring it.")

@@ -274,6 +331,62 @@ def execute_async(self,
def terminate(self):
pass

def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
# See which of the TIs are still alive (or have finished even!)
#
# Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a made
# up id it just returns PENDING state for it), we have to store Celery's task_id against the TI row to
# look at in the future.
#
# This process is not perfect -- we could have sent the task to celery, and crashed before we were
# able to record the AsyncResult.task_id in the TaskInstance table, in which case we won't adopt the
# task (it'll either run and update the TI state, or the scheduler will clear and re-queue it. Either
# way it won't get executed more than once)
#
# (If we swapped it around, and generated a task_id for Celery, stored that in TI and enqueued that
# there is also still a race condition where we could generate and store the task_id, but die before
# we managed to enqueue the command. Since neither way is perfect we always have to deal with this
# process not being perfect.)

celery_tasks = {}
not_adopted_tis = []

for ti in tis:
if ti.external_executor_id is not None:
celery_tasks[ti.external_executor_id] = (AsyncResult(ti.external_executor_id), ti)
else:
not_adopted_tis.append(ti)

if not celery_tasks:
# Nothing to adopt
return tis

states_by_celery_task_id = self.bulk_state_fetcher.get_many(
map(operator.itemgetter(0), celery_tasks.values())
)

adopted = []
cached_celery_backend = next(iter(celery_tasks.values()))[0].backend

for celery_task_id, (state, info) in states_by_celery_task_id.items():
result, ti = celery_tasks[celery_task_id]
result.backend = cached_celery_backend

# Set the correct elements of the state dicts, then update this
# like we just queried it.
self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
self.tasks[ti.key] = result
self.running.add(ti.key)
self.update_task_state(ti.key, state, info)
adopted.append(f"{ti} in state {state}")

if adopted:
task_instance_str = '\n\t'.join(adopted)
self.log.info("Adopted the following %d tasks from a dead executor\n\t%s",
len(adopted), task_instance_str)

return not_adopted_tis


def fetch_celery_task_state(async_result: AsyncResult) -> \
Tuple[str, Union[str, ExceptionWithTraceback], Any]:
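
A hedged sketch of the caller side of `try_adopt_task_instances` (the actual logic lives in SchedulerJob and is outside this diff): orphaned running task instances are offered to the executor, and whatever comes back un-adopted is reset so the normal scheduling loop re-queues it.

```python
# Illustrative helper only -- the scheduler's real reset logic also persists the
# change through its own session handling.
from typing import List

from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstance import TaskInstance
from airflow.utils.state import State


def adopt_or_reset(executor: BaseExecutor, orphaned_tis: List[TaskInstance]) -> List[TaskInstance]:
    """Return the task instances that were reset rather than adopted."""
    to_reset = executor.try_adopt_task_instances(orphaned_tis)
    for ti in to_reset:
        ti.state = State.NONE   # clear so the task becomes eligible for re-scheduling
    return to_reset
```
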
@@ -309,7 +422,7 @@ class BulkStateFetcher(LoggingMixin):
Gets status for many Celery tasks using the best method available

If BaseKeyValueStoreBackend is used as result backend, the mget method is used.
If DatabaseBackend is used as result backend, the SELECT ... WHERE task_id IN (...) query is used
If DatabaseBackend is used as result backend, the SELECT ...WHERE task_id IN (...) query is used
Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
"""
def __init__(self, sync_parralelism=None):
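
The docstring above describes a three-way strategy choice; a small illustrative sketch of that dispatch idea (not the actual `BulkStateFetcher` code):

```python
# Illustrative only: pick the cheapest bulk-fetch strategy the configured Celery
# result backend supports, as the BulkStateFetcher docstring describes.
from celery.backends.base import BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend


def choose_bulk_fetch_strategy(backend) -> str:
    if isinstance(backend, BaseKeyValueStoreBackend):
        return 'mget'        # one round trip for all task ids
    if isinstance(backend, DatabaseBackend):
        return 'select_in'   # single SELECT ... WHERE task_id IN (...) query
    return 'pool'            # fall back to fetching each task state individually
```
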