diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 275903fd8984d..c1400ad7aa096 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1373,6 +1373,23 @@
       type: int
       example: ~
       default: "2"
+    - name: task_track_started
+      description: |
+        The Celery task will report its status as 'started' when the task is executed by a worker.
+        This is used in Airflow to keep track of the running tasks, and if a Scheduler is restarted
+        or run in HA mode, it can adopt orphaned tasks launched by the previous SchedulerJob.
+      version_added: 2.0.0
+      type: boolean
+      example: ~
+      default: "True"
+    - name: task_adoption_timeout
+      description: |
+        Time in seconds after which adopted tasks are cleared by CeleryExecutor. This is helpful to clear
+        stalled tasks.
+      version_added: 2.0.0
+      type: int
+      example: ~
+      default: "600"
     - name: celery_broker_transport_options
       description: |
         This section is for specifying options which can be passed to the
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index 68284d3eb0add..bc1c62b15aba0 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -685,6 +685,15 @@ pool = prefork
 # ``fetch_celery_task_state`` operations.
 operation_timeout = 2
 
+# The Celery task will report its status as 'started' when the task is executed by a worker.
+# This is used in Airflow to keep track of the running tasks, and if a Scheduler is restarted
+# or run in HA mode, it can adopt orphaned tasks launched by the previous SchedulerJob.
+task_track_started = True
+
+# Time in seconds after which adopted tasks are cleared by CeleryExecutor. This is helpful to clear
+# stalled tasks.
+task_adoption_timeout = 600
+
 [celery_broker_transport_options]
 
 # This section is for specifying options which can be passed to the
diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py
index bf0b03bbb713f..b12bf14e13e10 100644
--- a/airflow/config_templates/default_celery.py
+++ b/airflow/config_templates/default_celery.py
@@ -43,6 +43,7 @@ def _broker_supports_visibility_timeout(url):
     'task_acks_late': True,
     'task_default_queue': conf.get('celery', 'DEFAULT_QUEUE'),
     'task_default_exchange': conf.get('celery', 'DEFAULT_QUEUE'),
+    'task_track_started': conf.getboolean('celery', 'task_track_started', fallback=True),
     'broker_url': broker_url,
     'broker_transport_options': broker_transport_options,
     'result_backend': conf.get('celery', 'RESULT_BACKEND'),
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 1427f3e601c94..df08a2bd1e252 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -265,6 +265,20 @@ def terminate(self):
         """
         raise NotImplementedError()
 
+    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+        """
+        Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
+
+        Anything that is not adopted will be cleared by the scheduler (and then become eligible for
+        re-scheduling).
+
+        :return: any TaskInstances that were unable to be adopted
+        :rtype: list[airflow.models.TaskInstance]
+        """
+        # By default, assume Executors cannot adopt tasks, so just say we failed to adopt anything.
+        # Subclasses can do better!
+        return tis
+
     @staticmethod
     def validate_command(command: List[str]) -> None:
         """Check if the command to execute is airflow command"""
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 13726674bd2b0..69290daa7c4df 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -21,14 +21,17 @@
 For more information on how the CeleryExecutor works, take a look at the guide:
 :ref:`executor:CeleryExecutor`
 """
+import datetime
 import logging
 import math
+import operator
 import os
 import subprocess
 import time
 import traceback
+from collections import OrderedDict
 from multiprocessing import Pool, cpu_count
-from typing import Any, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
+from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
 
 from celery import Celery, Task, states as celery_states
 from celery.backends.base import BaseKeyValueStoreBackend
@@ -39,11 +42,12 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType
-from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey
+from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.net import get_hostname
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
+from airflow.utils.timezone import utcnow
 
 log = logging.getLogger(__name__)
@@ -142,7 +146,11 @@ def __init__(self):
         self._sync_parallelism = max(1, cpu_count() - 1)
         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
         self.tasks = {}
-        self.last_state = {}
+        # Mapping of tasks we've adopted, ordered by the earliest date they time out
+        self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = OrderedDict()
+        self.task_adoption_timeout = datetime.timedelta(
+            seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600)
+        )
 
     def start(self) -> None:
         self.log.debug(
@@ -199,7 +207,14 @@ def _process_tasks(self, task_tuples_to_send: List[TaskInstanceInCelery]) -> Non
                 result.backend = cached_celery_backend
             self.running.add(key)
             self.tasks[key] = result
-            self.last_state[key] = celery_states.PENDING
+
+            # Store the Celery task_id in the event buffer. This will get "overwritten" if the task
+            # has another event, but that is fine, because the only other events are success/failed, at
+            # which point we don't need the ID anymore anyway
+            self.event_buffer[key] = (State.QUEUED, result.task_id)
+
+            # If the task runs _really quickly_ we may already have a result!
+            self.update_task_state(key, result.state, getattr(result, 'info', None))
 
     def _send_tasks_to_celery(self, task_tuples_to_send):
         if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
@@ -223,6 +238,48 @@ def sync(self) -> None:
             return
 
         self.update_all_task_states()
 
+        if self.adopted_task_timeouts:
+            self._check_for_stalled_adopted_tasks()
+
+    def _check_for_stalled_adopted_tasks(self):
+        """
+        See if any of the tasks we adopted from another Executor run have not
+        progressed after the configured timeout.
+
+        If they haven't, they likely never made it to Celery, and we should
+        just resend them. We do that by clearing the state and letting the
+        normal scheduler loop deal with that.
+        """
+        now = utcnow()
+
+        timedout_keys = []
+        for key, stalled_after in self.adopted_task_timeouts.items():
+            if stalled_after > now:
+                # Since items are stored sorted, if we get to a stalled_after
+                # in the future then we can stop
+                break
+
+            # If the task gets updated to STARTED (which Celery does) or has
+            # already finished, then it will be removed from this list -- so
+            # the only time it's still in this list is when it either a) never made it
+            # to Celery in the first place (i.e. a race condition somewhere in
+            # the dying executor), or b) the Celery queue is very long and it just
+            # hasn't started yet -- better to cancel it and let the scheduler
+            # re-queue it rather than have this task risk stalling forever
+            timedout_keys.append(key)
+
+        if timedout_keys:
+            self.log.error(
+                "Adopted tasks were still pending after %s, assuming they never made it to celery and "
+                "clearing:\n\t%s",
+                self.task_adoption_timeout,
+                "\n\t".join([repr(x) for x in timedout_keys])
+            )
+            for key in timedout_keys:
+                self.event_buffer[key] = (State.FAILED, None)
+                del self.tasks[key]
+                del self.adopted_task_timeouts[key]
+
     def update_all_task_states(self) -> None:
         """Updates states of the tasks."""
@@ -235,25 +292,25 @@ def update_all_task_states(self) -> None:
             if state:
                 self.update_task_state(key, state, info)
 
+    def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
+        super().change_state(key, state, info)
+        self.tasks.pop(key, None)
+        self.adopted_task_timeouts.pop(key, None)
+
     def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) -> None:
         """Updates state of a single task."""
         try:
-            if self.last_state[key] != state:
-                if state == celery_states.SUCCESS:
-                    self.success(key, info)
-                    del self.tasks[key]
-                    del self.last_state[key]
-                elif state == celery_states.FAILURE:
-                    self.fail(key, info)
-                    del self.tasks[key]  # noqa
-                    del self.last_state[key]
-                elif state == celery_states.REVOKED:
-                    self.fail(key, info)
-                    del self.tasks[key]  # noqa
-                    del self.last_state[key]
-                else:
-                    self.log.info("Unexpected state: %s", state)
-                self.last_state[key] = state
+            if state == celery_states.SUCCESS:
+                self.success(key, info)
+            elif state in (celery_states.FAILURE, celery_states.REVOKED):
+                self.fail(key, info)
+            elif state == celery_states.STARTED:
+                # It's now actually running, so we know it made it to Celery okay!
+                self.adopted_task_timeouts.pop(key, None)
+            elif state == celery_states.PENDING:
+                pass
+            else:
+                self.log.info("Unexpected state for %s: %s", key, state)
         except Exception:  # noqa pylint: disable=broad-except
             self.log.exception("Error syncing the Celery executor, ignoring it.")
@@ -274,6 +331,62 @@ def execute_async(self,
 
     def terminate(self):
         pass
 
+    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+        # See which of the TIs are still alive (or have finished even!)
+        #
+        # Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a made
+        # up id it just returns PENDING state for it), we have to store Celery's task_id against the TI row to
+        # look at in future.
+        #
+        # This process is not perfect -- we could have sent the task to celery, and crashed before we were
+        # able to record the AsyncResult.task_id in the TaskInstance table, in which case we won't adopt the
+        # task (it'll either run and update the TI state, or the scheduler will clear and re-queue it. Either
+        # way it won't get executed more than once)
+        #
+        # (If we swapped it around and generated a task_id for Celery, stored that in the TI, and then enqueued
+        # the command, there would still be a race condition where we could generate and store the task_id but
+        # die before we managed to enqueue the command. Since neither way is perfect, we always have to deal
+        # with this process not being perfect.)
+
+        celery_tasks = {}
+        not_adopted_tis = []
+
+        for ti in tis:
+            if ti.external_executor_id is not None:
+                celery_tasks[ti.external_executor_id] = (AsyncResult(ti.external_executor_id), ti)
+            else:
+                not_adopted_tis.append(ti)
+
+        if not celery_tasks:
+            # Nothing to adopt
+            return tis
+
+        states_by_celery_task_id = self.bulk_state_fetcher.get_many(
+            map(operator.itemgetter(0), celery_tasks.values())
+        )
+
+        adopted = []
+        cached_celery_backend = next(iter(celery_tasks.values()))[0].backend
+
+        for celery_task_id, (state, info) in states_by_celery_task_id.items():
+            result, ti = celery_tasks[celery_task_id]
+            result.backend = cached_celery_backend
+
+            # Set the correct elements of the state dicts, then update this
+            # like we just queried it.
+            self.adopted_task_timeouts[ti.key] = ti.queued_dttm + self.task_adoption_timeout
+            self.tasks[ti.key] = result
+            self.running.add(ti.key)
+            self.update_task_state(ti.key, state, info)
+            adopted.append(f"{ti} in state {state}")
+
+        if adopted:
+            task_instance_str = '\n\t'.join(adopted)
+            self.log.info("Adopted the following %d tasks from a dead executor\n\t%s",
+                          len(adopted), task_instance_str)
+
+        return not_adopted_tis
+
 
 def fetch_celery_task_state(async_result: AsyncResult) -> \
         Tuple[str, Union[str, ExceptionWithTraceback], Any]:
@@ -309,7 +422,7 @@ class BulkStateFetcher(LoggingMixin):
     Gets status for many Celery tasks using the best method available
 
     If BaseKeyValueStoreBackend is used as result backend, the mget method is used.
-    If DatabaseBackend is used as result backend, the SELECT ...WHER task_id IN (...) query is used
+    If DatabaseBackend is used as result backend, the SELECT ...WHERE task_id IN (...) query is used
     Otherwise, multiprocessing.Pool will be used. Each task status will be downloaded individually.
""" def __init__(self, sync_parralelism=None): diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index bfb579af1e836..83c731df8d74f 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -34,6 +34,7 @@ from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ +from sqlalchemy.orm import load_only from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings @@ -59,6 +60,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import skip_locked from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -1387,7 +1389,8 @@ def _change_state_for_executable_task_instances( # set TIs to queued state filter_for_tis = TI.filter_for_tis(tis_to_set_to_queued) session.query(TI).filter(filter_for_tis).update( - {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow()}, synchronize_session=False + {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id}, + synchronize_session=False ) session.commit() @@ -1541,7 +1544,7 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio "exited with status %s for try_number %s", ti_key.dag_id, ti_key.task_id, ti_key.execution_date, state, ti_key.try_number ) - if state in (State.FAILED, State.SUCCESS): + if state in (State.FAILED, State.SUCCESS, State.QUEUED): tis_with_right_state.append(ti_key) # Return if no finished tasks @@ -1557,6 +1560,10 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio state, info = event_buffer.pop(buffer_key) # TODO: should we fail RUNNING as well, as we do in Backfills? + if state == State.QUEUED: + ti.external_executor_id = info + continue + if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED: Stats.incr('scheduler.tasks.killed_externally') msg = "Executor reports task instance %s finished (%s) although the " \ @@ -1597,7 +1604,7 @@ def _execute(self) -> None: self.executor.start() self.log.info("Resetting orphaned tasks for active dag runs") - self.reset_state_for_orphaned_tasks() + self.adopt_or_reset_orphaned_tasks() self.register_exit_signals() @@ -1764,7 +1771,7 @@ def heartbeat_callback(self, session: Session = None) -> None: Stats.incr('scheduler_heartbeat', 1, 1) @provide_session - def reset_state_for_orphaned_tasks(self, session: Session = None): + def adopt_or_reset_orphaned_tasks(self, session: Session = None): """ Reset any TaskInstance still in QUEUED or SCHEDULED states that were enqueued by a SchedulerJob that is no longer running. @@ -1781,11 +1788,12 @@ def reset_state_for_orphaned_tasks(self, session: Session = None): if num_failed: self.log.info("Marked %d SchedulerJob instances as failed", num_failed) + Stats.incr(self.__class__.__name__.lower() + '_end', num_failed) - resettable_states = [State.SCHEDULED, State.QUEUED] - tis_to_reset = ( + resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING] + tis_to_reset_or_adopt = ( session.query(TI).filter(TI.state.in_(resettable_states)) - # outerjoin is becase we didn't use to have queued_by_job + # outerjoin is because we didn't use to have queued_by_job # set, so we need to pick up anything pre upgrade. This (and the # "or queued_by_job_id IS NONE") can go as soon as scheduler HA is # released. 
@@ -1795,30 +1803,28 @@ def reset_state_for_orphaned_tasks(self, session: Session = None): .filter(DagRun.run_type != DagRunType.BACKFILL_JOB.value, # pylint: disable=comparison-with-callable DagRun.state == State.RUNNING) - .with_entities(TI.dag_id, TI.task_id, TI.execution_date) + .options(load_only(TI.dag_id, TI.task_id, TI.execution_date)) + # Lock these rows, so that another scheduler can't try and adopt these too + .with_for_update(of=TI, **skip_locked(session=session)) + .all() ) + to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt) - if self.using_sqlite: - tis_to_reset = tis_to_reset.with_for_update(of=TI).all() - if tis_to_reset: - filter_for_tis = TI.filter_for_tis([ - TaskInstanceKey(dag_id, task_id, execution_date, 0) - for (dag_id, task_id, execution_date) in tis_to_reset - ]) - num_reset = session.query(TI).filter( - filter_for_tis, TI.state.in_(resettable_states) - ).update({TI.state: State.NONE}, synchronize_session=False) - else: - num_reset = 0 - else: - tis_to_reset = tis_to_reset.subquery('tis_to_reset') - num_reset = session.query(TI).filter( - TI.dag_id == tis_to_reset.c.dag_id, - TI.task_id == tis_to_reset.c.task_id, - TI.execution_date == tis_to_reset.c.execution_date, - ).update({TI.state: State.NONE}, synchronize_session=False) + reset_tis_message = [] + for ti in to_reset: + reset_tis_message.append(repr(ti)) + ti.state = State.NONE + ti.queued_by_job_id = None + + for ti in set(tis_to_reset_or_adopt) - set(to_reset): + ti.queued_by_job_id = self.id + + Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset)) + Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset)) - if num_reset: - self.log.info("Reset %d orphaned TaskInstances that were in queued state", num_reset) + if to_reset: + task_instance_str = '\n\t'.join(reset_tis_message) + self.log.info("Reset the following %s orphaned TaskInstances:\n\t%s", + len(to_reset), task_instance_str) - return num_reset + return len(to_reset) diff --git a/airflow/migrations/versions/e1a11ece99cc_add_external_executor_id_to_ti.py b/airflow/migrations/versions/e1a11ece99cc_add_external_executor_id_to_ti.py new file mode 100644 index 0000000000000..8662ee418949e --- /dev/null +++ b/airflow/migrations/versions/e1a11ece99cc_add_external_executor_id_to_ti.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add external executor ID to TI + +Revision ID: e1a11ece99cc +Revises: b247b1e3d1ed +Create Date: 2020-09-12 08:23:45.698865 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision = 'e1a11ece99cc'
+down_revision = 'b247b1e3d1ed'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Apply Add external executor ID to TI"""
+    with op.batch_alter_table('task_instance', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('external_executor_id', sa.String(length=250), nullable=True))
+
+
+def downgrade():
+    """Unapply Add external executor ID to TI"""
+    with op.batch_alter_table('task_instance', schema=None) as batch_op:
+        batch_op.drop_column('external_executor_id')
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 5280cb4605e86..21bad5033458d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -222,6 +222,8 @@ class TaskInstance(Base, LoggingMixin):  # pylint: disable=R0902,R0904
     queued_by_job_id = Column(Integer)
     pid = Column(Integer)
     executor_config = Column(PickleType(pickler=dill))
+
+    external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS))
 
     # If adding new fields here then remember to add them to
     # refresh_from_db() or they wont display in the UI correctly
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 0106d7639cd9d..2562dc3791142 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -19,9 +19,11 @@
 import datetime
 import json
 import logging
+from typing import Any, Dict
 
 import pendulum
 from dateutil import relativedelta
+from sqlalchemy.orm.session import Session
 from sqlalchemy.types import DateTime, Text, TypeDecorator
 
 from airflow.configuration import conf
@@ -118,3 +120,23 @@ def process_result_value(self, value, dialect):
             type_map = {key.__name__: key for key in self.attr_keys}
             return type_map[data['type']](**data['attrs'])
         return data
+
+
+def skip_locked(session: Session) -> Dict[str, Any]:
+    """
+    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.
+
+    We do this as we document the fact that on DB engines that don't support this construct, we do not
+    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
+    work, just slightly slower in some circumstances.
+
+    Specifically, don't emit SKIP LOCKED for MySQL < 8 or MariaDB, neither of which supports this construct.
+
+    See https://jira.mariadb.org/browse/MDEV-13115
+    """
+    dialect = session.bind.dialect
+
+    if dialect.name != "mysql" or dialect.supports_for_update_of:
+        return {'skip_locked': True}
+    else:
+        return {}
diff --git a/docs/logging-monitoring/metrics.rst b/docs/logging-monitoring/metrics.rst
index afe780161a1d2..bdff897b654ba 100644
--- a/docs/logging-monitoring/metrics.rst
+++ b/docs/logging-monitoring/metrics.rst
@@ -87,6 +87,8 @@ Name                                    Description
 ``scheduler.tasks.killed_externally``   Number of tasks killed externally
 ``scheduler.tasks.running``             Number of tasks running in executor
 ``scheduler.tasks.starving``            Number of tasks that cannot be scheduled because of no open slot in pool
+``scheduler.orphaned_tasks.cleared``    Number of orphaned tasks cleared by the scheduler
+``scheduler.orphaned_tasks.adopted``    Number of orphaned tasks adopted by the scheduler
 ``sla_email_notification_failure``      Number of failed SLA miss email notification attempts
 ``ti.start..``                          Number of started task in a given dag. Similar to _start but for task
 ``ti.finish...``                        Number of completed task in a given dag.
Similar to _end but for task diff --git a/setup.py b/setup.py index 7907ba669ecb1..636c11d72f6ff 100644 --- a/setup.py +++ b/setup.py @@ -731,7 +731,7 @@ def is_package_excluded(package: str, exclusion_list: List[str]): 'python-slugify>=3.0.0,<5.0', 'requests>=2.20.0, <3', 'setproctitle>=1.1.8, <2', - 'sqlalchemy~=1.3', + 'sqlalchemy>=1.3.18, <2', 'sqlalchemy_jsonfield~=0.9', 'tabulate>=0.7.5, <0.9', 'tenacity>=4.12.0, <5.2', diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py index 47d5ca7549d7a..f324fa753b21f 100644 --- a/tests/executors/test_base_executor.py +++ b/tests/executors/test_base_executor.py @@ -17,11 +17,13 @@ # under the License. import unittest -from datetime import datetime +from datetime import datetime, timedelta from unittest import mock from airflow.executors.base_executor import BaseExecutor -from airflow.models.taskinstance import TaskInstanceKey +from airflow.models.baseoperator import BaseOperator +from airflow.models.dag import DAG +from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.utils.state import State @@ -53,3 +55,18 @@ def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock mock.call('executor.queued_tasks', mock.ANY), mock.call('executor.running_tasks', mock.ANY)] mock_stats_gauge.assert_has_calls(calls) + + def test_try_adopt_task_instances(self): + date = datetime.utcnow() + start_date = datetime.utcnow() - timedelta(days=2) + + with DAG("test_try_adopt_task_instances"): + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + task_2 = BaseOperator(task_id="task_2", start_date=start_date) + task_3 = BaseOperator(task_id="task_3", start_date=start_date) + + key1 = TaskInstance(task=task_1, execution_date=date) + key2 = TaskInstance(task=task_2, execution_date=date) + key3 = TaskInstance(task=task_3, execution_date=date) + tis = [key1, key2, key3] + self.assertEqual(BaseExecutor().try_adopt_task_instances(tis), tis) diff --git a/tests/executors/test_celery_executor.py b/tests/executors/test_celery_executor.py index 76b572e260eab..7408106dc8bfa 100644 --- a/tests/executors/test_celery_executor.py +++ b/tests/executors/test_celery_executor.py @@ -16,11 +16,11 @@ # specific language governing permissions and limitations # under the License. 
import contextlib -import datetime import json import os import sys import unittest +from datetime import datetime, timedelta from unittest import mock # leave this it is used by the test worker @@ -30,6 +30,7 @@ from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend # noqa from celery.backends.database import DatabaseBackend from celery.contrib.testing.worker import start_worker +from celery.result import AsyncResult from kombu.asynchronous import set_event_loop from parameterized import parameterized @@ -37,11 +38,13 @@ from airflow.exceptions import AirflowException from airflow.executors import celery_executor from airflow.executors.celery_executor import BulkStateFetcher -from airflow.models import TaskInstance +from airflow.models.baseoperator import BaseOperator from airflow.models.dag import DAG -from airflow.models.taskinstance import SimpleTaskInstance +from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey from airflow.operators.bash import BashOperator +from airflow.utils import timezone from airflow.utils.state import State +from tests.test_utils import db def _prepare_test_bodies(): @@ -95,6 +98,14 @@ def _prepare_app(broker_url=None, execute=None): class TestCeleryExecutor(unittest.TestCase): + def setUp(self) -> None: + db.clear_db_runs() + db.clear_db_jobs() + + def tearDown(self) -> None: + db.clear_db_runs() + db.clear_db_jobs() + @parameterized.expand(_prepare_test_bodies()) @pytest.mark.integration("redis") @pytest.mark.integration("rabbitmq") @@ -109,10 +120,11 @@ def fake_execute_command(command): with _prepare_app(broker_url, execute=fake_execute_command) as app: executor = celery_executor.CeleryExecutor() + self.assertEqual(executor.tasks, {}) executor.start() with start_worker(app=app, logfile=sys.stdout, loglevel='info'): - execute_date = datetime.datetime.now() + execute_date = datetime.now() task_tuples_to_send = [ (('success', 'fake_simple_ti', execute_date, 0), @@ -129,6 +141,22 @@ def fake_execute_command(command): executor._process_tasks(task_tuples_to_send) + self.assertEqual( + list(executor.tasks.keys()), + [ + ('success', 'fake_simple_ti', execute_date, 0), + ('fail', 'fake_simple_ti', execute_date, 0) + ] + ) + self.assertEqual( + executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], + State.QUEUED + ) + self.assertEqual( + executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0], + State.QUEUED + ) + executor.end(synchronous=True) self.assertEqual(executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0], @@ -139,8 +167,8 @@ def fake_execute_command(command): self.assertNotIn('success', executor.tasks) self.assertNotIn('fail', executor.tasks) - self.assertNotIn('success', executor.last_state) - self.assertNotIn('fail', executor.last_state) + self.assertEqual(executor.queued_tasks, {}) + self.assertEqual(timedelta(0, 600), executor.task_adoption_timeout) @pytest.mark.integration("redis") @pytest.mark.integration("rabbitmq") @@ -157,11 +185,11 @@ def fake_execute_command(): task_id="test", bash_command="true", dag=DAG(dag_id='id'), - start_date=datetime.datetime.now() + start_date=datetime.now() ) - when = datetime.datetime.now() + when = datetime.now() value_tuple = 'command', 1, None, \ - SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.datetime.now())) + SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())) key = ('fail', 'fake_simple_ti', when, 0) executor.queued_tasks[key] = value_tuple executor.heartbeat() 
@@ -210,6 +238,89 @@ def test_command_validation(self, command, expected_exception, mock_check_output command, stderr=mock.ANY, close_fds=mock.ANY, env=mock.ANY, ) + @pytest.mark.backend("mysql", "postgres") + def test_try_adopt_task_instances_none(self): + date = datetime.utcnow() + start_date = datetime.utcnow() - timedelta(days=2) + + with DAG("test_try_adopt_task_instances_none"): + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + + key1 = TaskInstance(task=task_1, execution_date=date) + tis = [key1] + executor = celery_executor.CeleryExecutor() + + self.assertEqual(executor.try_adopt_task_instances(tis), tis) + + @pytest.mark.backend("mysql", "postgres") + def test_try_adopt_task_instances(self): + exec_date = timezone.utcnow() - timedelta(minutes=2) + start_date = timezone.utcnow() - timedelta(days=2) + queued_dttm = timezone.utcnow() - timedelta(minutes=1) + + try_number = 1 + + with DAG("test_try_adopt_task_instances_none") as dag: + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + task_2 = BaseOperator(task_id="task_2", start_date=start_date) + + ti1 = TaskInstance(task=task_1, execution_date=exec_date) + ti1.external_executor_id = '231' + ti1.queued_dttm = queued_dttm + ti2 = TaskInstance(task=task_2, execution_date=exec_date) + ti2.external_executor_id = '232' + ti2.queued_dttm = queued_dttm + + tis = [ti1, ti2] + executor = celery_executor.CeleryExecutor() + self.assertEqual(executor.running, set()) + self.assertEqual(executor.adopted_task_timeouts, {}) + self.assertEqual(executor.tasks, {}) + + not_adopted_tis = executor.try_adopt_task_instances(tis) + + key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number) + key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number) + self.assertEqual(executor.running, {key_1, key_2}) + self.assertEqual( + dict(executor.adopted_task_timeouts), + { + key_1: queued_dttm + executor.task_adoption_timeout, + key_2: queued_dttm + executor.task_adoption_timeout + } + ) + self.assertEqual(executor.tasks, {key_1: AsyncResult("231"), key_2: AsyncResult("232")}) + self.assertEqual(not_adopted_tis, []) + + @pytest.mark.backend("mysql", "postgres") + def test_check_for_stalled_adopted_tasks(self): + exec_date = timezone.utcnow() - timedelta(minutes=40) + start_date = timezone.utcnow() - timedelta(days=2) + queued_dttm = timezone.utcnow() - timedelta(minutes=30) + + try_number = 1 + + with DAG("test_check_for_stalled_adopted_tasks") as dag: + task_1 = BaseOperator(task_id="task_1", start_date=start_date) + task_2 = BaseOperator(task_id="task_2", start_date=start_date) + + key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number) + key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number) + + executor = celery_executor.CeleryExecutor() + executor.adopted_task_timeouts = { + key_1: queued_dttm + executor.task_adoption_timeout, + key_2: queued_dttm + executor.task_adoption_timeout + } + executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")} + executor.sync() + self.assertEqual( + executor.event_buffer, + {key_1: (State.FAILED, None), key_2: (State.FAILED, None)} + ) + self.assertEqual(executor.tasks, {}) + self.assertEqual(executor.adopted_task_timeouts, {}) + def test_operation_timeout_config(): assert celery_executor.OPERATION_TIMEOUT == 2 diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index 2ce75ae2f8e6e..59653644e4dc3 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ 
-2302,7 +2302,7 @@ def test_change_state_for_tasks_failed_to_execute(self): ti.refresh_from_db() self.assertEqual(State.RUNNING, ti.state) - def test_reset_state_for_orphaned_tasks(self): + def test_adopt_or_reset_orphaned_tasks(self): session = settings.Session() dag = DAG( 'test_execute_helper_reset_orphaned_tasks', @@ -2335,7 +2335,7 @@ def test_reset_state_for_orphaned_tasks(self): scheduler = SchedulerJob(num_runs=0) scheduler.processor_agent = processor - scheduler.reset_state_for_orphaned_tasks() + scheduler.adopt_or_reset_orphaned_tasks() ti = dr.get_task_instance(task_id=op1.task_id, session=session) self.assertEqual(ti.state, State.NONE) @@ -3316,13 +3316,13 @@ def test_list_py_file_paths(self): detected_files.add(file_path) self.assertEqual(detected_files, expected_files) - def test_reset_orphaned_tasks_nothing(self): + def test_adopt_or_reset_orphaned_tasks_nothing(self): """Try with nothing. """ scheduler = SchedulerJob() session = settings.Session() - self.assertEqual(0, scheduler.reset_state_for_orphaned_tasks(session=session)) + self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session)) - def test_reset_orphaned_tasks_external_triggered_dag(self): + def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self): dag_id = 'test_reset_orphaned_tasks_external_triggered_dag' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') task_id = dag_id + '_task' @@ -3341,11 +3341,11 @@ def test_reset_orphaned_tasks_external_triggered_dag(self): session.merge(dr1) session.commit() - num_reset_tis = scheduler.reset_state_for_orphaned_tasks(session=session) + num_reset_tis = scheduler.adopt_or_reset_orphaned_tasks(session=session) self.assertEqual(1, num_reset_tis) - def test_reset_orphaned_tasks_backfill_dag(self): - dag_id = 'test_reset_orphaned_tasks_backfill_dag' + def test_adopt_or_reset_orphaned_tasks_backfill_dag(self): + dag_id = 'test_adopt_or_reset_orphaned_tasks_backfill_dag' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) @@ -3366,7 +3366,7 @@ def test_reset_orphaned_tasks_backfill_dag(self): session.flush() self.assertTrue(dr1.is_backfill) - self.assertEqual(0, scheduler.reset_state_for_orphaned_tasks(session=session)) + self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session)) session.rollback() def test_reset_orphaned_tasks_nonexistent_dagrun(self): @@ -3388,7 +3388,7 @@ def test_reset_orphaned_tasks_nonexistent_dagrun(self): session.merge(ti) session.flush() - self.assertEqual(0, scheduler.reset_state_for_orphaned_tasks(session=session)) + self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session)) session.rollback() def test_reset_orphaned_tasks_no_orphans(self): @@ -3412,7 +3412,7 @@ def test_reset_orphaned_tasks_no_orphans(self): session.merge(tis[0]) session.flush() - self.assertEqual(0, scheduler.reset_state_for_orphaned_tasks(session=session)) + self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session)) tis[0].refresh_from_db() self.assertEqual(State.RUNNING, tis[0].state) @@ -3439,21 +3439,21 @@ def test_reset_orphaned_tasks_non_running_dagruns(self): session.merge(tis[0]) session.flush() - self.assertEqual(0, scheduler.reset_state_for_orphaned_tasks(session=session)) + self.assertEqual(0, scheduler.adopt_or_reset_orphaned_tasks(session=session)) session.rollback() - def test_reset_orphaned_tasks_stale_scheduler_jobs(self): - dag_id = 
'test_reset_orphaned_tasks_external_triggered_dag' + def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): + dag_id = 'test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') DummyOperator(task_id='task1', dag=dag) DummyOperator(task_id='task2', dag=dag) dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - scheduler = SchedulerJob() + scheduler_job = SchedulerJob() session = settings.Session() - scheduler.state = State.RUNNING - scheduler.latest_heartbeat = timezone.utcnow() - session.add(scheduler) + scheduler_job.state = State.RUNNING + scheduler_job.latest_heartbeat = timezone.utcnow() + session.add(scheduler_job) old_job = SchedulerJob() old_job.state = State.RUNNING @@ -3470,11 +3470,12 @@ def test_reset_orphaned_tasks_stale_scheduler_jobs(self): session.merge(ti1) ti2.state = State.SCHEDULED - ti2.queued_by_job_id = scheduler.id + ti2.queued_by_job_id = scheduler_job.id session.merge(ti2) session.flush() - num_reset_tis = scheduler.reset_state_for_orphaned_tasks(session=session) + num_reset_tis = scheduler_job.adopt_or_reset_orphaned_tasks(session=session) + session.flush() self.assertEqual(1, num_reset_tis) session.refresh(ti1) diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index 98449c0071759..d59bbc96d1343 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -18,12 +18,15 @@ # import datetime import unittest +from unittest import mock +from parameterized import parameterized from sqlalchemy.exc import StatementError from airflow import settings from airflow.models import DAG from airflow.settings import Session +from airflow.utils.sqlalchemy import skip_locked from airflow.utils.state import State from airflow.utils.timezone import utcnow @@ -95,6 +98,18 @@ def test_process_bind_param_naive(self): ) dag.clear() + @parameterized.expand([ + ("postgresql", True, {'skip_locked': True}, ), + ("mysql", False, {}, ), + ("mysql", True, {'skip_locked': True}, ), + ("sqlite", False, {'skip_locked': True}, ), + ]) + def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value): + session = mock.Mock() + session.bind.dialect.name = dialect + session.bind.dialect.supports_for_update_of = supports_for_update_of + self.assertEqual(skip_locked(session=session), expected_return_value) + def tearDown(self): self.session.close() settings.engine.dispose()
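
Note on the adopt-or-reset handshake: the sketch below is a simplified stand-in, not the real Airflow classes. `TaskInstance`, `NonAdoptingExecutor`, and `adopt_or_reset` here only mirror the shape of `BaseExecutor.try_adopt_task_instances` and `SchedulerJob.adopt_or_reset_orphaned_tasks` from the diff above: the scheduler hands orphaned task instances to the executor, clears whatever the executor could not adopt (making them eligible for re-scheduling), and takes ownership of the rest.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TaskInstance:
    """Stand-in for airflow.models.TaskInstance -- only the fields this sketch needs."""
    task_id: str
    state: str = "running"
    external_executor_id: Optional[str] = None   # Celery task_id recorded when the task was queued
    queued_by_job_id: Optional[int] = None


class NonAdoptingExecutor:
    """Mirrors BaseExecutor's default behaviour: adopt nothing, hand everything back to be reset."""

    def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
        return tis


def adopt_or_reset(executor, orphaned: List[TaskInstance], current_job_id: int) -> int:
    """Ask the executor to adopt orphans; clear whatever it could not adopt."""
    to_reset = executor.try_adopt_task_instances(orphaned)
    for ti in to_reset:
        ti.state = "none"                          # cleared -> eligible for re-scheduling
        ti.queued_by_job_id = None
    for ti in orphaned:
        if ti not in to_reset:
            ti.queued_by_job_id = current_job_id   # adopted -> now owned by this scheduler
    return len(to_reset)


orphans = [TaskInstance("t1"), TaskInstance("t2", external_executor_id="abc123")]
print(adopt_or_reset(NonAdoptingExecutor(), orphans, current_job_id=42))  # -> 2
```

The CeleryExecutor override in the patch replaces `NonAdoptingExecutor` with one that builds `AsyncResult` objects from the stored `external_executor_id`, adopts those it can find state for, and returns only the rest for clearing.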
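The `skip_locked()` helper added to `airflow/utils/sqlalchemy.py` only asks for `SKIP LOCKED` where the backend supports it, so two HA schedulers querying for orphans skip rows another scheduler has already locked instead of blocking on them. Below is a minimal, self-contained sketch of that pattern; the `Task` model and the SQLite engine are illustrative stand-ins (SQLite ignores `FOR UPDATE`), not Airflow's `task_instance` table, and the real call is the `with_for_update(of=TI, **skip_locked(session=session))` shown in the scheduler_job.py hunk.

```python
from typing import Any, Dict

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session

Base = declarative_base()


class Task(Base):
    __tablename__ = 'task'
    id = Column(Integer, primary_key=True)
    state = Column(String(20))


def skip_locked(session: Session) -> Dict[str, Any]:
    """Only pass skip_locked=True where the backend supports it (i.e. not MySQL < 8 / MariaDB)."""
    dialect = session.bind.dialect
    if dialect.name != "mysql" or dialect.supports_for_update_of:
        return {'skip_locked': True}
    return {}


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# Rows already locked by another scheduler are skipped rather than blocking this query.
orphans = (
    session.query(Task)
    .filter(Task.state == 'queued')
    .with_for_update(**skip_locked(session=session))
    .all()
)
print(orphans)  # -> []
```

This is also why setup.py bumps the pin to `sqlalchemy>=1.3.18`: that release exposes `supports_for_update_of` on the MySQL dialect, which the helper relies on.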