diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 4274a24c2c6a0..7d68c4336d7d8 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -44,7 +44,7 @@
 from airflow.models import DAG, DagRun, SlaMiss, errors
 from airflow.models.taskinstance import SimpleTaskInstance
 from airflow.stats import Stats
-from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, SCHEDULED_DEPS, DepContext
+from airflow.ti_deps.dep_context import SCHEDULED_DEPS, DepContext
 from airflow.ti_deps.deps.pool_slots_available_dep import STATES_TO_COUNT_AS_RUNNING
 from airflow.utils import asciiart, helpers, timezone
 from airflow.utils.dag_processing import (
@@ -649,28 +649,10 @@ def _process_task_instances(self, dag, task_instances_list, session=None):
             run.dag = dag
             # todo: preferably the integrity check happens at dag collection time
             run.verify_integrity(session=session)
-            run.update_state(session=session)
+            ready_tis = run.update_state(session=session)
             if run.state == State.RUNNING:
-                make_transient(run)
-                active_dag_runs.append(run)
-
-        for run in active_dag_runs:
-            self.log.debug("Examining active DAG run: %s", run)
-            tis = run.get_task_instances(state=SCHEDULEABLE_STATES)
-
-            # this loop is quite slow as it uses are_dependencies_met for
-            # every task (in ti.is_runnable). This is also called in
-            # update_state above which has already checked these tasks
-            for ti in tis:
-                task = dag.get_task(ti.task_id)
-
-                # fixme: ti.task is transient but needs to be set
-                ti.task = task
-
-                if ti.are_dependencies_met(
-                        dep_context=DepContext(flag_upstream_failed=True),
-                        session=session
-                ):
+                self.log.debug("Examining active DAG run: %s", run)
+                for ti in ready_tis:
                     self.log.debug('Queuing task: %s', ti)
                     task_instances_list.append(ti.key)
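The scheduler change above is the consumer side of the refactor: `DagRun.update_state()` now returns the task instances whose dependencies are already met, so the scheduler queues them directly instead of re-running `are_dependencies_met` in a second loop. A minimal sketch of the new control flow, with hypothetical stand-in objects rather than the real scheduler classes:

```python
# Hedged sketch, not the actual Airflow API: update_state() advances the
# run's state and returns the ready TIs in the same dependency pass.
def process_dag_runs(dag_runs, task_instances_list):
    for run in dag_runs:
        run.verify_integrity()
        ready_tis = run.update_state()  # single dependency check per run
        if run.state == "running":
            for ti in ready_tis:
                task_instances_list.append(ti.key)  # queue for execution
```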
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index cc02aa1d467d8..08d41edab0ab6 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -28,7 +28,7 @@
 from airflow.exceptions import AirflowException
 from airflow.models.base import ID_LEN, Base
 from airflow.stats import Stats
-from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.dep_context import SCHEDULEABLE_STATES, DepContext
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
@@ -196,7 +196,6 @@ def get_task_instances(self, state=None, session=None):
 
         if self.dag and self.dag.partial:
             tis = tis.filter(TaskInstance.task_id.in_(self.dag.task_ids))
-
         return tis.all()
 
     @provide_session
@@ -263,49 +262,33 @@ def update_state(self, session=None):
         Determines the overall state of the DagRun based on the state
         of its TaskInstances.
 
-        :return: State
+        :return: ready_tis: the task instances that can be scheduled in the current loop
+        :rtype: list[airflow.models.TaskInstance]
         """
 
         dag = self.get_dag()
-
-        tis = self.get_task_instances(session=session)
-        self.log.debug("Updating state for %s considering %s task(s)", self, len(tis))
-
+        ready_tis = []
+        tis = [ti for ti in self.get_task_instances(session=session,
+                                                    state=State.task_states + (State.SHUTDOWN,))]
+        self.log.debug("number of tis for %s: %s task(s)", self, len(tis))
         for ti in list(tis):
-            # skip in db?
-            if ti.state == State.REMOVED:
-                tis.remove(ti)
-            else:
-                ti.task = dag.get_task(ti.task_id)
+            ti.task = dag.get_task(ti.task_id)
 
-        # pre-calculate
-        # db is faster
         start_dttm = timezone.utcnow()
-        unfinished_tasks = self.get_task_instances(
-            state=State.unfinished(),
-            session=session
-        )
+        unfinished_tasks = [t for t in tis if t.state in State.unfinished()]
+        finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]]
         none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
         none_task_concurrency = all(t.task.task_concurrency is None
                                     for t in unfinished_tasks)
         # small speed up
         if unfinished_tasks and none_depends_on_past and none_task_concurrency:
-            # todo: this can actually get pretty slow: one task costs between 0.01-015s
-            no_dependencies_met = True
-            for ut in unfinished_tasks:
-                # We need to flag upstream and check for changes because upstream
-                # failures/re-schedules can result in deadlock false positives
-                old_state = ut.state
-                deps_met = ut.are_dependencies_met(
-                    dep_context=DepContext(
-                        flag_upstream_failed=True,
-                        ignore_in_retry_period=True,
-                        ignore_in_reschedule_period=True),
-                    session=session)
-                if deps_met or old_state != ut.current_state(session=session):
-                    no_dependencies_met = False
-                    break
+            scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES]
+            self.log.debug("number of scheduleable tasks for %s: %s task(s)", self, len(scheduleable_tasks))
+            ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session)
+            self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis))
+            are_runnable_tasks = ready_tis or self._are_premature_tis(
+                unfinished_tasks, finished_tasks, session) or changed_tis
 
         duration = (timezone.utcnow() - start_dttm)
         Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration)
@@ -330,7 +313,7 @@ def update_state(self, session=None):
         # if *all tasks* are deadlocked, the run failed
         elif (unfinished_tasks and none_depends_on_past and
-              none_task_concurrency and no_dependencies_met):
+              none_task_concurrency and not are_runnable_tasks):
             self.log.info('Deadlock; marking run %s failed', self)
             self.set_state(State.FAILED)
             dag.handle_callback(self, success=False, reason='all_tasks_deadlocked',
@@ -346,7 +329,35 @@ def update_state(self, session=None):
         session.merge(self)
         session.commit()
 
-        return self.state
+        return ready_tis
+
+    def _get_ready_tis(self, scheduleable_tasks, finished_tasks, session):
+        ready_tis = []
+        changed_tis = False
+        for st in scheduleable_tasks:
+            st_old_state = st.state
+            if st.are_dependencies_met(
+                    dep_context=DepContext(
+                        flag_upstream_failed=True,
+                        finished_tasks=finished_tasks),
+                    session=session):
+                ready_tis.append(st)
+            elif st_old_state != st.current_state(session=session):
+                changed_tis = True
+        return ready_tis, changed_tis
+
+    def _are_premature_tis(self, unfinished_tasks, finished_tasks, session):
+        # there might be runnable tasks that are up for retry and for some
+        # reason (retry delay, etc.) are not ready yet, so we set the flags
+        # to count them in
+        for ut in unfinished_tasks:
+            if ut.are_dependencies_met(
+                    dep_context=DepContext(
+                        flag_upstream_failed=True,
+                        ignore_in_retry_period=True,
+                        ignore_in_reschedule_period=True,
+                        finished_tasks=finished_tasks),
+                    session=session):
+                return True
 
     def _emit_duration_stats_for_finished_state(self):
         if self.state == State.RUNNING:
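With these helpers, the deadlock test in the `elif` above reduces to one predicate over three signals computed in the same pass: `ready_tis`, the premature-TI probe, and `changed_tis`. A hedged restatement as a standalone function (illustrative, not part of the patch):

```python
# A running DagRun with unfinished tasks is deadlocked only when nothing
# is ready, nothing is merely premature (e.g. still inside retry_delay or
# a reschedule window), and no TI changed state while we were checking;
# upstream failures/re-schedules would otherwise cause false positives.
def is_deadlocked(unfinished_tasks, ready_tis, premature_tis, changed_tis):
    return bool(unfinished_tasks) and not (ready_tis or premature_tis or changed_tis)
```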
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index 3abec53132eab..7bb078c4cdffc 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -67,6 +67,8 @@ class DepContext:
     :type ignore_task_deps: bool
     :param ignore_ti_state: Ignore the task instance's previous failure/success
     :type ignore_ti_state: bool
+    :param finished_tasks: A list of all the finished tasks of this run
+    :type finished_tasks: list[airflow.models.TaskInstance]
     """
     def __init__(
             self,
@@ -77,7 +79,8 @@ def __init__(
             ignore_in_retry_period=False,
             ignore_in_reschedule_period=False,
             ignore_task_deps=False,
-            ignore_ti_state=False):
+            ignore_ti_state=False,
+            finished_tasks=None):
         self.deps = deps or set()
         self.flag_upstream_failed = flag_upstream_failed
         self.ignore_all_deps = ignore_all_deps
@@ -86,6 +89,7 @@ def __init__(
         self.ignore_in_reschedule_period = ignore_in_reschedule_period
         self.ignore_task_deps = ignore_task_deps
         self.ignore_ti_state = ignore_ti_state
+        self.finished_tasks = finished_tasks
 
 
 # In order to be able to get queued a task must have one of these states
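The new `finished_tasks` field is what lets the dep checks share work: the caller fetches the run's finished task instances once and threads them through every `are_dependencies_met` call, as `_get_ready_tis` does above. A condensed sketch of that call pattern (variable names are illustrative, and the real code builds one `DepContext` per TI; a shared one is equivalent here since the field is read-only):

```python
# finished_tasks is computed once per DagRun from the already-fetched TIs,
# then reused by every dependency check instead of a per-TI DB query.
context = DepContext(flag_upstream_failed=True, finished_tasks=finished_tasks)
ready_tis = [ti for ti in scheduleable_tasks
             if ti.are_dependencies_met(dep_context=context, session=session)]
```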
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index df43578ce849f..c8cf822a22a2c 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -17,7 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from sqlalchemy import case, func
+from collections import Counter
 
 import airflow
 from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
@@ -34,11 +34,32 @@ class TriggerRuleDep(BaseTIDep):
     IGNOREABLE = True
     IS_TASK_DEP = True
 
+    @staticmethod
+    @provide_session
+    def _get_states_count_upstream_ti(ti, finished_tasks, session):
+        """
+        Returns the state counts of the upstream tis for a specific ti, used to determine
+        whether this ti can run in this iteration
+
+        :param ti: the ti that we want to calculate deps for
+        :type ti: airflow.models.TaskInstance
+        :param finished_tasks: all the finished tasks of the dag_run
+        :type finished_tasks: list[airflow.models.TaskInstance]
+        """
+        if finished_tasks is None:
+            # this is for the strange feature of running tasks without a dag_run
+            finished_tasks = ti.task.dag.get_task_instances(
+                start_date=ti.execution_date,
+                end_date=ti.execution_date,
+                state=State.finished() + [State.UPSTREAM_FAILED],
+                session=session)
+        counter = Counter(task.state for task in finished_tasks if task.task_id in ti.task.upstream_task_ids)
+        return counter.get(State.SUCCESS, 0), counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), \
+            counter.get(State.UPSTREAM_FAILED, 0), sum(counter.values())
+
     @provide_session
     def _get_dep_statuses(self, ti, session, dep_context):
-        TI = airflow.models.TaskInstance
         TR = airflow.utils.trigger_rule.TriggerRule
 
-        # Checking that all upstream dependencies have succeeded
         if not ti.task.upstream_list:
             yield self._passing_status(
                 reason="The task instance did not have any upstream tasks.")
@@ -48,34 +69,11 @@ def _get_dep_statuses(self, ti, session, dep_context):
         if ti.task.trigger_rule == TR.DUMMY:
             yield self._passing_status(reason="The task had a dummy trigger rule set.")
             return
+        # count the states of this ti's direct upstream task instances
+        successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
+            ti=ti,
+            finished_tasks=dep_context.finished_tasks)
 
-        # TODO(unknown): this query becomes quite expensive with dags that have many
-        # tasks. It should be refactored to let the task report to the dag run and get the
-        # aggregates from there.
-        qry = (
-            session
-            .query(
-                func.coalesce(func.sum(
-                    case([(TI.state == State.SUCCESS, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.SKIPPED, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.FAILED, 1)], else_=0)), 0),
-                func.coalesce(func.sum(
-                    case([(TI.state == State.UPSTREAM_FAILED, 1)], else_=0)), 0),
-                func.count(TI.task_id),
-            )
-            .filter(
-                TI.dag_id == ti.dag_id,
-                TI.task_id.in_(ti.task.upstream_task_ids),
-                TI.execution_date == ti.execution_date,
-                TI.state.in_([
-                    State.SUCCESS, State.FAILED,
-                    State.UPSTREAM_FAILED, State.SKIPPED]),
-            )
-        )
-
-        successes, skipped, failed, upstream_failed, done = qry.first()
         yield from self._evaluate_trigger_rule(
             ti=ti,
             successes=successes,
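This is the core optimization of the PR: the old per-TI aggregate query (`SUM(CASE ...)` over the task_instance table) becomes an in-memory `Counter` over the finished tasks that were already fetched for the run. A self-contained sketch of the counting logic, with plain strings standing in for `State` members and `(task_id, state)` tuples standing in for TaskInstance rows:

```python
from collections import Counter

def upstream_state_counts(upstream_task_ids, finished_tasks):
    # Count only the states of this TI's direct upstream tasks; everything
    # here is in memory, so no extra query is issued per task instance.
    counter = Counter(state for task_id, state in finished_tasks
                      if task_id in upstream_task_ids)
    return (counter.get("success", 0), counter.get("skipped", 0),
            counter.get("failed", 0), counter.get("upstream_failed", 0),
            sum(counter.values()))

# Upstream of D is {B, C}; with B failed and C successful:
print(upstream_state_counts({"B", "C"},
                            [("A", "success"), ("B", "failed"), ("C", "success")]))
# -> (1, 0, 1, 0, 2)
```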
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 5b2523af417c3..9f3e0096655b8 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -1266,6 +1266,7 @@ def test_backfill_execute_subdag_with_removed_task(self):
 
         session = settings.Session()
         session.merge(removed_task_ti)
+        session.commit()
 
         with timeout(seconds=30):
             job.run()
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index e94c8e496014a..a82e8d407d9d2 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -2034,8 +2034,8 @@ def test_dagrun_root_fail_unfinished(self):
         ti = dr.get_task_instance('test_dagrun_unfinished', session=session)
         ti.state = State.NONE
         session.commit()
-        dr_state = dr.update_state()
-        self.assertEqual(dr_state, State.RUNNING)
+        dr.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
 
     def test_dagrun_root_after_dagrun_unfinished(self):
         """
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 43052eb78165a..8277773a34ca4 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -159,8 +159,8 @@ def test_dagrun_success_when_all_skipped(self):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.SUCCESS, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.SUCCESS, dag_run.state)
 
     def test_dagrun_success_conditions(self):
         session = settings.Session()
@@ -198,15 +198,15 @@ def test_dagrun_success_conditions(self):
         ti_op4 = dr.get_task_instance(task_id=op4.task_id)
 
         # root is successful, but unfinished tasks
-        state = dr.update_state()
-        self.assertEqual(State.RUNNING, state)
+        dr.update_state()
+        self.assertEqual(State.RUNNING, dr.state)
 
         # one has failed, but root is successful
         ti_op2.set_state(state=State.FAILED, session=session)
         ti_op3.set_state(state=State.SUCCESS, session=session)
         ti_op4.set_state(state=State.SUCCESS, session=session)
-        state = dr.update_state()
-        self.assertEqual(State.SUCCESS, state)
+        dr.update_state()
+        self.assertEqual(State.SUCCESS, dr.state)
 
     def test_dagrun_deadlock(self):
         session = settings.Session()
@@ -321,8 +321,8 @@ def on_success_callable(context):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.SUCCESS, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.SUCCESS, dag_run.state)
 
     def test_dagrun_failure_callback(self):
         def on_failure_callable(context):
@@ -352,8 +352,8 @@ def on_failure_callable(context):
         dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.FAILED, updated_dag_state)
+        dag_run.update_state()
+        self.assertEqual(State.FAILED, dag_run.state)
 
     def test_dagrun_set_state_end_date(self):
         session = settings.Session()
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 78d9e0ca0c21c..3860287c21771 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -19,12 +19,16 @@
 import unittest
 from datetime import datetime
 
-from airflow.models import TaskInstance
+from airflow import settings
+from airflow.models import DAG, TaskInstance
 from airflow.models.baseoperator import BaseOperator
+from airflow.operators.dummy_operator import DummyOperator
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
+from tests.models import DEFAULT_DATE
 
 
 class TestTriggerRuleDep(unittest.TestCase):
@@ -374,3 +378,59 @@ def test_unknown_tr(self):
 
         self.assertEqual(len(dep_statuses), 1)
         self.assertFalse(dep_statuses[0].passed)
+
+    def test_get_states_count_upstream_ti(self):
+        """
+        Tests the helper function '_get_states_count_upstream_ti' both as a unit and inside update_state
+        """
+        get_states_count_upstream_ti = TriggerRuleDep._get_states_count_upstream_ti
+        session = settings.Session()
+        now = timezone.utcnow()
+        dag = DAG(
+            'test_dagrun_with_pre_tis',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op2 = DummyOperator(task_id='B')
+            op3 = DummyOperator(task_id='C')
+            op4 = DummyOperator(task_id='D')
+            op5 = DummyOperator(task_id='E', trigger_rule=TriggerRule.ONE_FAILED)
+
+            op1.set_downstream([op2, op3])
+            op4.set_upstream([op3, op2])
+            op5.set_upstream([op2, op3, op4])
+
+        dag.clear()
+        dr = dag.create_dagrun(run_id='test_dagrun_with_pre_tis',
+                               state=State.RUNNING,
+                               execution_date=now,
+                               start_date=now)
+
+        ti_op1 = TaskInstance(task=dag.get_task(op1.task_id), execution_date=dr.execution_date)
+        ti_op2 = TaskInstance(task=dag.get_task(op2.task_id), execution_date=dr.execution_date)
+        ti_op3 = TaskInstance(task=dag.get_task(op3.task_id), execution_date=dr.execution_date)
+        ti_op4 = TaskInstance(task=dag.get_task(op4.task_id), execution_date=dr.execution_date)
+        ti_op5 = TaskInstance(task=dag.get_task(op5.task_id), execution_date=dr.execution_date)
+
+        ti_op1.set_state(state=State.SUCCESS, session=session)
+        ti_op2.set_state(state=State.FAILED, session=session)
+        ti_op3.set_state(state=State.SUCCESS, session=session)
+        ti_op4.set_state(state=State.SUCCESS, session=session)
+        ti_op5.set_state(state=State.SUCCESS, session=session)
+
+        # check the handling of tasks triggered from backfill, where no finished_tasks list is passed
+        self.assertEqual(get_states_count_upstream_ti(finished_tasks=None, ti=ti_op2, session=session),
+                         (1, 0, 0, 0, 1))
+        finished_tasks = dr.get_task_instances(state=State.finished() + [State.UPSTREAM_FAILED],
+                                               session=session)
+        self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op4,
+                                                      session=session),
+                         (1, 0, 1, 0, 2))
+        self.assertEqual(get_states_count_upstream_ti(finished_tasks=finished_tasks, ti=ti_op5,
+                                                      session=session),
+                         (2, 0, 1, 0, 3))
+
+        dr.update_state()
+        self.assertEqual(State.SUCCESS, dr.state)
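As a sanity check on the expected tuples in the test above: E's upstream set is {B, C, D} with B failed and C, D successful, so the helper should count two successes, one failure, and three finished upstream tasks. The same arithmetic, runnable on its own:

```python
from collections import Counter

# States of E's upstream tasks (B, C, D) as set in the test above.
counter = Counter(["failed", "success", "success"])
assert (counter.get("success", 0), counter.get("skipped", 0),
        counter.get("failed", 0), counter.get("upstream_failed", 0),
        sum(counter.values())) == (2, 0, 1, 0, 3)
```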