diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 452e00b98a886..5b4da82143b81 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -102,7 +102,7 @@ def dag_backfill(args, dag=None): if args.task_regex: dag = dag.partial_subset( - task_regex=args.task_regex, + task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies) run_conf = None diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 6d48907aa6d70..d4bc9619603ea 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -411,7 +411,7 @@ def task_clear(args): if args.task_regex: for idx, dag in enumerate(dags): dags[idx] = dag.partial_subset( - task_regex=args.task_regex, + task_ids_or_regex=args.task_regex, include_downstream=args.downstream, include_upstream=args.upstream) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 5384dc748c2be..732183bda8f33 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1617,6 +1617,17 @@ type: string default: ~ see_also: ":ref:`scheduler:ha:tunables`" + - name: schedule_after_task_execution + description: | + Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the + same DAG. Leaving this on will mean tasks in the same DAG execute more quickly, but might starve out + other DAGs in some circumstances. + + Default: True + example: ~ + version_added: 2.0.0 + type: boolean + default: ~ - name: statsd_on description: | Statsd (https://github.com/etsy/statsd) integration settings diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 291b157aca665..e93eaf2191fde 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -811,6 +811,13 @@ use_row_level_locking = True # Default: 20 # max_dagruns_per_loop_to_schedule = +# Should the Task supervisor process perform a "mini scheduler" to attempt to schedule more tasks of the +# same DAG. Leaving this on will mean tasks in the same DAG execute more quickly, but might starve out +# other DAGs in some circumstances. +# +# Default: True +# schedule_after_task_execution = + # Statsd (https://github.com/etsy/statsd) integration settings statsd_on = False statsd_host = localhost diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg index 5c1f1d1cbc085..c741bd58e92c1 100644 --- a/airflow/config_templates/default_test.cfg +++ b/airflow/config_templates/default_test.cfg @@ -102,6 +102,7 @@ sync_parallelism = 0 [scheduler] job_heartbeat_sec = 1 +schedule_after_task_execution = False scheduler_heartbeat_sec = 5 scheduler_health_check_threshold = 30 max_threads = 2 diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index 99e64e90ceb07..a014a9524cb69 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -1652,42 +1652,11 @@ def _schedule_dag_run( self._send_dag_callbacks_to_processor(dag_run, callback_to_run) - # Get list of TIs that do not need to executed, these are - tasks using DummyOperator and without on_execute_callback / on_success_callback - dummy_tis = [ - ti for ti in schedulable_tis - if - ( - ti.task.task_type == "DummyOperator" - and not ti.task.on_execute_callback - and not ti.task.on_success_callback - ) - ] - # This will do one query per dag run.
We "could" build up a complex # query to update all the TIs across all the execution dates and dag # IDs in a single query, but it turns out that can be _very very slow_ # see #11147/commit ee90807ac for more details - count = session.query(TI).filter( - TI.dag_id == dag_run.dag_id, - TI.execution_date == dag_run.execution_date, - TI.task_id.in_(ti.task_id for ti in schedulable_tis if ti not in dummy_tis) - ).update({TI.state: State.SCHEDULED}, synchronize_session=False) - - # Tasks using DummyOperator should not be executed, mark them as success - if dummy_tis: - session.query(TI).filter( - TI.dag_id == dag_run.dag_id, - TI.execution_date == dag_run.execution_date, - TI.task_id.in_(ti.task_id for ti in dummy_tis) - ).update({ - TI.state: State.SUCCESS, - TI.start_date: timezone.utcnow(), - TI.end_date: timezone.utcnow(), - TI.duration: 0 - }, synchronize_session=False) - - return count + return dag_run.schedule_tis(schedulable_tis, session) @provide_session def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 08487d376da2c..d85beff099024 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1013,7 +1013,8 @@ def get_task_instances(self, start_date: Optional[datetime] = None, def get_flat_relative_ids(self, upstream: bool = False, - found_descendants: Optional[Set[str]] = None) -> Set[str]: + found_descendants: Optional[Set[str]] = None, + ) -> Set[str]: """Get a flat set of relatives' ids, either upstream or downstream.""" if not self._dag: return set() @@ -1026,8 +1027,7 @@ def get_flat_relative_ids(self, if relative_id not in found_descendants: found_descendants.add(relative_id) relative_task = self._dag.task_dict[relative_id] - relative_task.get_flat_relative_ids(upstream, - found_descendants) + relative_task.get_flat_relative_ids(upstream, found_descendants) return found_descendants diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 45bdf779d2017..5c190ed0d22ff 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -68,6 +68,14 @@ if TYPE_CHECKING: from airflow.utils.task_group import TaskGroup + +# Before Py 3.7, there is no re.Pattern class +try: + from re import Pattern as PatternType # type: ignore +except ImportError: + PatternType = type(re.compile('', 0)) + + log = logging.getLogger(__name__) ScheduleInterval = Union[str, timedelta, relativedelta] @@ -1172,7 +1180,7 @@ def clear( if include_parentdag and self.is_subdag and self.parent_dag is not None: p_dag = self.parent_dag.sub_dag( - task_regex=r"^{}$".format(self.dag_id.split('.')[1]), + task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]), include_upstream=False, include_downstream=True) @@ -1245,7 +1253,7 @@ def clear( if not external_dag: raise AirflowException("Could not find dag {}".format(tii.dag_id)) downstream = external_dag.sub_dag( - task_regex=r"^{}$".format(tii.task_id), + task_ids_or_regex=r"^{}$".format(tii.task_id), include_upstream=False, include_downstream=True ) @@ -1394,36 +1402,54 @@ def sub_dag(self, *args, **kwargs): return self.partial_subset(*args, **kwargs) def partial_subset( - self, task_regex, include_downstream=False, include_upstream=True + self, + task_ids_or_regex: Union[str, PatternType, Iterable[str]], + include_downstream=False, + include_upstream=True, + include_direct_upstream=False, ): """ Returns a subset of the current dag as a deep copy of the current dag based on a regex that should match one or many tasks, and 
includes upstream and downstream neighbours based on the flag passed. + + :param task_ids_or_regex: Either a list of task_ids, or a regex to + match against task ids (as a string, or compiled regex pattern). + :type task_ids_or_regex: [str] or str or re.Pattern + :param include_downstream: Include all downstream tasks of matched + tasks, in addition to matched tasks. + :param include_upstream: Include all upstream tasks of matched tasks, + in addition to matched tasks. + :param include_direct_upstream: When ``include_upstream`` is False, include + only the tasks directly upstream of matched tasks, rather than their + whole upstream chain. """ # deep-copying self.task_dict and self._task_group takes a long time, and we don't want all # the tasks anyway, so we copy the tasks manually later task_dict = self.task_dict task_group = self._task_group self.task_dict = {} - self._task_group = None + self._task_group = None # type: ignore dag = copy.deepcopy(self) self.task_dict = task_dict self._task_group = task_group - regex_match = [ - t for t in self.tasks if re.findall(task_regex, t.task_id)] + if isinstance(task_ids_or_regex, (str, PatternType)): + matched_tasks = [ + t for t in self.tasks if re.findall(task_ids_or_regex, t.task_id)] + else: + matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex] + also_include = [] - for t in regex_match: + for t in matched_tasks: if include_downstream: also_include += t.get_flat_relatives(upstream=False) if include_upstream: also_include += t.get_flat_relatives(upstream=True) + elif include_direct_upstream: + also_include += t.upstream_list # Compiling the unique list of tasks that made the cut # Make sure to not recursively deepcopy the dag while copying the task - dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) - for t in regex_match + also_include} + dag.task_dict = {t.task_id: copy.deepcopy(t, {id(t.dag): dag}) # type: ignore + for t in matched_tasks + also_include} def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 9b0166849dc72..e9adf47315e6f 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Iterable, List, NamedTuple, Optional, Tuple, Union from sqlalchemy import ( Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_, @@ -43,6 +43,16 @@ from airflow.utils.types import DagRunType + +class TISchedulingDecision(NamedTuple): + """Type of return for DagRun.task_instance_scheduling_decisions""" + + tis: List[TI] + schedulable_tis: List[TI] + changed_tis: bool + unfinished_tasks: List[TI] + finished_tasks: List[TI] + + class DagRun(Base, LoggingMixin): """ DagRun describes an instance of a Dag.
It can be created @@ -380,27 +390,21 @@ def update_state( self.last_scheduling_decision = start_dttm dag = self.get_dag() - ready_tis: List[TI] = [] - tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,))) - self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) - for ti in tis: - ti.task = dag.get_task(ti.task_id) + info = self.task_instance_scheduling_decisions(session) + + tis = info.tis + schedulable_tis = info.schedulable_tis + changed_tis = info.changed_tis + finished_tasks = info.finished_tasks + unfinished_tasks = info.unfinished_tasks - unfinished_tasks = [t for t in tis if t.state in State.unfinished] - finished_tasks = [t for t in tis if t.state in State.finished] none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks) - if unfinished_tasks: - scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES] - self.log.debug( - "number of scheduleable tasks for %s: %s task(s)", - self, len(scheduleable_tasks)) - ready_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session) - self.log.debug("ready tis length for %s: %s task(s)", self, len(ready_tis)) - if none_depends_on_past and none_task_concurrency: - # small speed up - are_runnable_tasks = ready_tis or self._are_premature_tis( - unfinished_tasks, finished_tasks, session) or changed_tis + + if unfinished_tasks and none_depends_on_past and none_task_concurrency: + # small speed up + are_runnable_tasks = schedulable_tis or self._are_premature_tis( + unfinished_tasks, finished_tasks, session) or changed_tis duration = (timezone.utcnow() - start_dttm) Stats.timing("dagrun.dependency-check.{}".format(self.dag_id), duration) @@ -466,7 +470,35 @@ def update_state( session.merge(self) - return ready_tis, callback + return schedulable_tis, callback + + @provide_session + def task_instance_scheduling_decisions(self, session: Session = None) -> TISchedulingDecision: + + schedulable_tis: List[TI] = [] + changed_tis = False + + tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,))) + self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis)) + for ti in tis: + ti.task = self.get_dag().get_task(ti.task_id) + + unfinished_tasks = [t for t in tis if t.state in State.unfinished] + finished_tasks = [t for t in tis if t.state in State.finished] + if unfinished_tasks: + scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES] + self.log.debug( + "number of scheduleable tasks for %s: %s task(s)", + self, len(scheduleable_tasks)) + schedulable_tis, changed_tis = self._get_ready_tis(scheduleable_tasks, finished_tasks, session) + + return TISchedulingDecision( + tis=tis, + schedulable_tis=schedulable_tis, + changed_tis=changed_tis, + unfinished_tasks=unfinished_tasks, + finished_tasks=finished_tasks, + ) def _get_ready_tis( self, @@ -638,3 +670,52 @@ def get_latest_runs(cls, session=None): .all() ) return dagruns + + @provide_session + def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = None) -> int: + """ + Set the given task instances to the scheduled state. + + Each element of ``schedulable_tis`` should have its ``task`` attribute already set. + + Any DummyOperator without callbacks is instead set straight to the success state.
+ + All the TIs should belong to this DagRun, but as this code is in the hot path this is not checked -- it + is the caller's responsibility to call this function only with TIs from a single dag run. + """ + # Get list of TIs that do not need to be executed; these are + tasks using DummyOperator and without on_execute_callback / on_success_callback + dummy_tis = { + ti for ti in schedulable_tis + if + ( + ti.task.task_type == "DummyOperator" + and not ti.task.on_execute_callback + and not ti.task.on_success_callback + ) + } + + schedulable_ti_ids = [ti.task_id for ti in schedulable_tis if ti not in dummy_tis] + count = 0 + + if schedulable_ti_ids: + count += session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.execution_date == self.execution_date, + TI.task_id.in_(schedulable_ti_ids) + ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + + # Tasks using DummyOperator should not be executed; mark them as success + if dummy_tis: + count += session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.execution_date == self.execution_date, + TI.task_id.in_(ti.task_id for ti in dummy_tis) + ).update({ + TI.state: State.SUCCESS, + TI.start_date: timezone.utcnow(), + TI.end_date: timezone.utcnow(), + TI.duration: 0 + }, synchronize_session=False) + + return count diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 83d8f0eb19d59..a1af753858f94 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -34,6 +34,7 @@ import pendulum from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_ +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList @@ -61,7 +62,7 @@ from airflow.utils.net import get_hostname from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks from airflow.utils.state import State from airflow.utils.timeout import timeout @@ -1135,8 +1136,63 @@ def _run_raw_task( if not test_mode: session.add(Log(self.state, self)) session.merge(self) + session.commit() + self._run_mini_scheduler_on_child_tasks(session) + + @provide_session + @Sentry.enrich_errors + def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: + if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): + from airflow.models.dagrun import DagRun # Avoid circular import + + try: + # Re-select the row with a lock + dag_run = with_row_locks(session.query(DagRun).filter_by( + dag_id=self.dag_id, + execution_date=self.execution_date, + )).one() + + # Get a partial dag with just the specific tasks we want to + # examine.
In order for dep checks to work correctly, we + # include ourselves (so TriggerRuleDep can check the state of the + # task we just executed) + partial_dag = self.task.dag.partial_subset( + self.task.downstream_task_ids, + include_downstream=False, + include_upstream=False, + include_direct_upstream=True, + ) + + dag_run.dag = partial_dag + info = dag_run.task_instance_scheduling_decisions(session) + + skippable_task_ids = { + task_id + for task_id in partial_dag.task_ids + if task_id not in self.task.downstream_task_ids + } + + schedulable_tis = [ + ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids + ] + for schedulable_ti in schedulable_tis: + if not hasattr(schedulable_ti, "task"): + schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id) + + num = dag_run.schedule_tis(schedulable_tis) + self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) + + session.commit() + except OperationalError as e: + # Any kind of DB error here is _non-fatal_ as this block is just an optimisation. + self.log.info( + "Skipping mini scheduling run due to exception: {}".format(e.statement), + exc_info=True, + ) + session.rollback() + def _prepare_and_execute_task_with_callbacks( self, context, diff --git a/airflow/www/views.py b/airflow/www/views.py index ac585f851b682..47ebac0ee25b5 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1433,7 +1433,7 @@ def clear(self): only_failed = request.form.get('only_failed') == "true" dag = dag.sub_dag( - task_regex=r"^{0}$".format(task_id), + task_ids_or_regex=r"^{0}$".format(task_id), include_downstream=downstream, include_upstream=upstream) @@ -1706,7 +1706,7 @@ def tree(self): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_downstream=False, include_upstream=True) @@ -1876,7 +1876,7 @@ def graph(self, session=None): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_upstream=True, include_downstream=False) @@ -1980,7 +1980,7 @@ def duration(self, session=None): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_upstream=True, include_downstream=False) @@ -2091,7 +2091,7 @@ def tries(self, session=None): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_upstream=True, include_downstream=False) @@ -2161,7 +2161,7 @@ def landing_times(self, session=None): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_upstream=True, include_downstream=False) @@ -2284,7 +2284,7 @@ def gantt(self, session=None): root = request.args.get('root') if root: dag = dag.sub_dag( - task_regex=root, + task_ids_or_regex=root, include_upstream=True, include_downstream=False) diff --git a/tests/core/test_core.py b/tests/core/test_core.py index 629cf3167f423..71a22368d7461 100644 --- a/tests/core/test_core.py +++ b/tests/core/test_core.py @@ -43,6 +43,7 @@ from airflow.utils.timezone import datetime from airflow.utils.types import DagRunType from tests.test_utils.config import conf_vars +from tests.test_utils.db import clear_db_dags, clear_db_runs DEV_NULL = '/dev/null' DEFAULT_DATE = datetime(2015, 1, 1) @@ -88,6 +89,8 @@ def tearDown(self): synchronize_session=False) session.commit() session.close() + clear_db_dags() + clear_db_runs() def test_check_operators(self): @@ -97,11 +100,13 @@ def
test_check_operators(self): captain_hook.run("CREATE TABLE operator_test_table (a, b)") captain_hook.run("insert into operator_test_table values (1,2)") + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op = CheckOperator( task_id='check', sql="select count(*) from operator_test_table", conn_id=conn_id, dag=self.dag) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) op = ValueCheckOperator( @@ -158,6 +163,8 @@ def test_bash_operator(self): task_id='test_bash_operator', bash_command="echo success", dag=self.dag) + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) + op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_bash_operator_multi_byte_output(self): @@ -166,6 +173,7 @@ def test_bash_operator_multi_byte_output(self): bash_command="echo \u2600", dag=self.dag, output_encoding='utf-8') + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_bash_operator_kill(self): @@ -222,6 +230,7 @@ def test_sqlite(self): task_id='time_sqlite', sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))", dag=self.dag) + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_timeout(self): @@ -245,6 +254,7 @@ def test_py_op(templates_dict, ds, **kwargs): python_callable=test_py_op, templates_dict={'ds': "{{ ds }}"}, dag=self.dag) + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_complex_template(self): @@ -260,6 +270,7 @@ def verify_templated_field(context): }, dag=self.dag) op.execute = verify_templated_field + self.dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) def test_template_non_bool(self): @@ -285,6 +296,11 @@ def test_task_get_template(self): ti = TI( task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash + self.dag_bash.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=DEFAULT_DATE + ) ti.run(ignore_ti_state=True) context = ti.get_template_context() @@ -322,6 +338,11 @@ def test_raw_job(self): ti = TI( task=self.runme_0, execution_date=DEFAULT_DATE) ti.dag = self.dag_bash + self.dag_bash.create_dagrun( + run_type=DagRunType.MANUAL, + state=State.RUNNING, + execution_date=DEFAULT_DATE + ) ti.run(ignore_ti_state=True) def test_round_time(self): diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 55d5f1ab7dea4..383200f976b32 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -1056,7 +1056,7 @@ def test_sub_set_subdag(self): start_date=DEFAULT_DATE) executor = MockExecutor() - sub_dag = dag.sub_dag(task_regex="leave*", + sub_dag = dag.sub_dag(task_ids_or_regex="leave*", include_downstream=False, include_upstream=False) job = BackfillJob(dag=sub_dag, @@ -1226,7 +1226,7 @@ def test_subdag_clear_parentdag_downstream_clear(self): self.assertEqual(ti_downstream.state, State.SUCCESS) sdag = subdag.sub_dag( - task_regex='daily_job_subdag_task', + task_ids_or_regex='daily_job_subdag_task', 
include_downstream=True, include_upstream=False) diff --git a/tests/models/test_cleartasks.py b/tests/models/test_cleartasks.py index bd6244dcd0dec..9efe51ae150db 100644 --- a/tests/models/test_cleartasks.py +++ b/tests/models/test_cleartasks.py @@ -24,6 +24,7 @@ from airflow.operators.dummy_operator import DummyOperator from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE from tests.test_utils import db @@ -44,6 +45,12 @@ def test_clear_task_instances(self): ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + dag.create_dagrun( + execution_date=ti0.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti0.run() ti1.run() with create_session() as session: @@ -66,6 +73,13 @@ def test_clear_task_instances_without_task(self): task1 = DummyOperator(task_id='task1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti0.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti0.run() ti1.run() @@ -95,6 +109,13 @@ def test_clear_task_instances_without_dag(self): task1 = DummyOperator(task_id='task_1', owner='test', dag=dag, retries=2) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) ti1 = TI(task=task1, execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti0.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti0.run() ti1.run() @@ -117,6 +138,13 @@ def test_dag_clear(self): end_date=DEFAULT_DATE + datetime.timedelta(days=10)) task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag) ti0 = TI(task=task0, execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti0.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + # Next try to run will be try 1 self.assertEqual(ti0.try_number, 1) ti0.run() @@ -158,6 +186,12 @@ def test_dags_clear(self): ti = TI(task=DummyOperator(task_id='test_task_clear_' + str(i), owner='test', dag=dag), execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) dags.append(dag) tis.append(ti) @@ -221,6 +255,13 @@ def test_operator_clear(self): ti1 = TI(task=op1, execution_date=DEFAULT_DATE) ti2 = TI(task=op2, execution_date=DEFAULT_DATE) + + dag.create_dagrun( + execution_date=ti1.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti2.run() # Dependency not met self.assertEqual(ti2.try_number, 1) @@ -228,7 +269,7 @@ def test_operator_clear(self): op2.clear(upstream=True) ti1.run() - ti2.run() + ti2.run(ignore_ti_state=True) self.assertEqual(ti1.try_number, 2) # max_tries is 0 because there is no task instance in db for ti1 # so clear won't change the max_tries. 
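A note on the pattern running through these test changes: because ``_run_raw_task`` now commits and immediately performs the mini-scheduler pass, a task instance can only be run if its DagRun row actually exists in the database, which is why the tests above and below gain ``create_dagrun`` calls. A minimal sketch of the new prerequisite (the DAG and task ids here are illustrative, not taken from the patch):

    from airflow.models import DAG, TaskInstance as TI
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.state import State
    from airflow.utils.timezone import datetime
    from airflow.utils.types import DagRunType

    DEFAULT_DATE = datetime(2016, 1, 1)
    dag = DAG('example_dag', start_date=DEFAULT_DATE)
    task = DummyOperator(task_id='example_task', dag=dag)

    # Without this row, the follow-on scheduling check at the end of
    # TaskInstance._run_raw_task has no DagRun to re-select and lock.
    dag.create_dagrun(
        run_type=DagRunType.MANUAL,
        state=State.RUNNING,
        execution_date=DEFAULT_DATE,
    )

    TI(task=task, execution_date=DEFAULT_DATE).run(ignore_ti_state=True)
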
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 7f8dcfd5a6b05..5a7cd9ab4f1be 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -32,8 +32,9 @@ from airflow import models, settings from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException +from airflow.jobs.scheduler_job import SchedulerJob from airflow.models import ( - DAG, DagRun, Pool, RenderedTaskInstanceFields, TaskInstance as TI, TaskReschedule, Variable, + DAG, DagModel, DagRun, Pool, RenderedTaskInstanceFields, TaskInstance as TI, TaskReschedule, Variable, ) from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator @@ -84,6 +85,7 @@ class TestTaskInstance(unittest.TestCase): @staticmethod def clean_db(): + db.clear_db_dags() db.clear_db_pools() db.clear_db_runs() db.clear_db_task_fail() @@ -343,6 +345,13 @@ def test_mark_non_runnable_task_as_success(self): # TI.run() will sync from DB before validating deps. with create_session() as session: session.add(ti) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) session.commit() ti.run(mark_success=True) self.assertEqual(ti.state, State.SUCCESS) @@ -355,8 +364,13 @@ def test_run_pooling_task(self): task = DummyOperator(task_id='test_run_pooling_task_op', dag=dag, pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) ti.run() db.clear_db_pools() @@ -384,8 +398,14 @@ def test_ti_updates_with_task(self, session=None): task = DummyOperator(task_id='test_run_pooling_task_op', owner='airflow', executor_config={'foo': 'bar'}, start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) ti.run(session=session) tis = dag.get_task_instances() @@ -395,11 +415,18 @@ def test_ti_updates_with_task(self, session=None): executor_config={'bar': 'baz'}, start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task2, execution_date=timezone.utcnow()) + ti = TI(task=task2, execution_date=timezone.utcnow()) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) ti.run(session=session) tis = dag.get_task_instances() self.assertEqual({'bar': 'baz'}, tis[1].executor_config) + session.rollback() def test_run_pooling_task_with_mark_success(self): """ @@ -414,8 +441,13 @@ def test_run_pooling_task_with_mark_success(self): pool='test_pool', owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) ti.run(mark_success=True) self.assertEqual(ti.state, State.SUCCESS) @@ -435,8 +467,12 @@ def raise_skip_exception(): python_callable=raise_skip_exception, owner='airflow', start_date=timezone.datetime(2016, 2, 1, 0, 0, 0)) - ti = TI( - task=task, 
execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) ti.run() self.assertEqual(State.SKIPPED, ti.state) @@ -460,8 +496,12 @@ def run_with_error(ti): except AirflowException: pass - ti = TI( - task=task, execution_date=timezone.utcnow()) + ti = TI(task=task, execution_date=timezone.utcnow()) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) self.assertEqual(ti.try_number, 1) # first run -- up for retry @@ -632,6 +672,12 @@ def func(): self.assertEqual(ti._try_number, 0) self.assertEqual(ti.try_number, 1) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + def run_ti_and_assert(run_date, expected_start_date, expected_end_date, expected_duration, expected_state, expected_try_number, @@ -776,6 +822,12 @@ def test_depends_on_past(self): run_date = task.start_date + datetime.timedelta(days=5) + dag.create_dagrun( + execution_date=run_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti = TI(task, run_date) # depends_on_past prevents the run @@ -949,8 +1001,14 @@ def test_xcom_pull_after_success(self): owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) exec_date = timezone.utcnow() - ti = TI( - task=task, execution_date=exec_date) + ti = TI(task=task, execution_date=exec_date) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti.run(mark_success=True) ti.xcom_push(key=key, value=value) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) @@ -983,8 +1041,14 @@ def test_xcom_pull_different_execution_date(self): owner='airflow', start_date=timezone.datetime(2016, 6, 2, 0, 0, 0)) exec_date = timezone.utcnow() - ti = TI( - task=task, execution_date=exec_date) + ti = TI(task=task, execution_date=exec_date) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) + ti.run(mark_success=True) ti.xcom_push(key=key, value=value) self.assertEqual(ti.xcom_pull(task_ids='test_xcom', key=key), value) @@ -1021,6 +1085,11 @@ def test_xcom_push_flag(self): start_date=datetime.datetime(2017, 1, 1) ) ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1)) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + ) ti.run() self.assertEqual( ti.xcom_pull( @@ -1189,7 +1258,7 @@ def test_email_alert(self, mock_send_email): start_date=DEFAULT_DATE, email='to') - ti = TI(task=task, execution_date=datetime.datetime.now()) + ti = TI(task=task, execution_date=timezone.utcnow()) try: ti.run() @@ -1216,8 +1285,7 @@ def test_email_alert_with_config(self, mock_send_email): start_date=DEFAULT_DATE, email='to') - ti = TI( - task=task, execution_date=datetime.datetime.now()) + ti = TI(task=task, execution_date=timezone.utcnow()) opener = mock_open(read_data='template: {{ti.task_id}}') with patch('airflow.models.taskinstance.open', opener, create=True): @@ -1258,7 +1326,15 @@ def test_success_callback_no_race_condition(self): ti.state = State.RUNNING session = settings.Session() session.merge(ti) + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) session.commit() + callback_wrapper.wrap_task_instance(ti) 
ti._run_raw_task() self.assertTrue(callback_wrapper.callback_ran) @@ -1481,8 +1557,16 @@ def on_execute_callable(context): ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING session = settings.Session() + + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) session.merge(ti) session.commit() + ti._run_raw_task() assert called ti.refresh_from_db() @@ -1670,6 +1754,106 @@ def test_get_rendered_template_fields(self): with create_session() as session: session.query(RenderedTaskInstanceFields).delete() + def validate_ti_states(self, dag_run, ti_state_mapping, error_message): + for task_id, expected_state in ti_state_mapping.items(): + task_instance = dag_run.get_task_instance(task_id=task_id) + self.assertEqual(task_instance.state, expected_state, error_message) + + @parameterized.expand([ + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, + "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C." + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'False'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, + None, + "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED." + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'C': 'B', 'D': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + None, + "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED." + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'C', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, + None, + "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED." 
+ ), + ]) + def test_fast_follow( + self, conf, dependencies, init_state, first_run_state, second_run_state, error_message + ): + with conf_vars(conf): + session = settings.Session() + + dag = DAG( + 'test_dagrun_fast_follow', + start_date=DEFAULT_DATE + ) + + dag_model = DagModel( + dag_id=dag.dag_id, + next_dagrun=dag.start_date, + is_active=True, + ) + session.add(dag_model) + session.flush() + + python_callable = lambda: True + with dag: + task_a = PythonOperator(task_id='A', python_callable=python_callable) + task_b = PythonOperator(task_id='B', python_callable=python_callable) + task_c = PythonOperator(task_id='C', python_callable=python_callable) + if 'D' in init_state: + task_d = PythonOperator(task_id='D', python_callable=python_callable) + for upstream, downstream in dependencies.items(): + dag.set_dependency(upstream, downstream) + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + + dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING) + + task_instance_a = dag_run.get_task_instance(task_id=task_a.task_id) + task_instance_a.task = task_a + task_instance_a.set_state(init_state['A']) + + task_instance_b = dag_run.get_task_instance(task_id=task_b.task_id) + task_instance_b.task = task_b + task_instance_b.set_state(init_state['B']) + + task_instance_c = dag_run.get_task_instance(task_id=task_c.task_id) + task_instance_c.task = task_c + task_instance_c.set_state(init_state['C']) + + if 'D' in init_state: + task_instance_d = dag_run.get_task_instance(task_id=task_d.task_id) + task_instance_d.task = task_d + task_instance_d.state = init_state['D'] + + session.commit() + task_instance_a.run() + + self.validate_ti_states(dag_run, first_run_state, error_message) + + if second_run_state: + scheduler._critical_section_execute_task_instances(session=session) + task_instance_b.run() + self.validate_ti_states(dag_run, second_run_state, error_message) + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) def test_refresh_from_task(pool_override): @@ -1725,7 +1909,14 @@ def test_execute_queries_count(self, expected_query_count, mark_success): task = DummyOperator(task_id='op', dag=dag) ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING + session.merge(ti) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) with assert_queries_count(expected_query_count): ti._run_raw_task(mark_success=mark_success) @@ -1736,7 +1927,14 @@ def test_execute_queries_count_store_serialized(self): task = DummyOperator(task_id='op', dag=dag) ti = TI(task=task, execution_date=datetime.datetime.now()) ti.state = State.RUNNING + session.merge(ti) + dag.create_dagrun( + execution_date=ti.execution_date, + state=State.RUNNING, + run_type=DagRunType.SCHEDULED, + session=session, + ) with assert_queries_count(10): ti._run_raw_task() diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py index e45bdc52280fd..0a978e81e67f8 100644 --- a/tests/sensors/test_external_task_sensor.py +++ b/tests/sensors/test_external_task_sensor.py @@ -512,7 +512,7 @@ def clear_tasks(dag_bag, dag, task, start_date=DEFAULT_DATE, end_date=DEFAULT_DA """ Clear the task and its downstream tasks recursively for the dag in the given dagbag. 
""" - subdag = dag.sub_dag(task_regex="^{}$".format(task.task_id), include_downstream=True) + subdag = dag.sub_dag(task_ids_or_regex="^{}$".format(task.task_id), include_downstream=True) subdag.clear(start_date=start_date, end_date=end_date, dag_bag=dag_bag) diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py index d733224797d92..658fee3644076 100644 --- a/tests/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -18,13 +18,14 @@ import pendulum -from airflow.models import DAG, TaskInstance +from airflow.models import DAG, DagRun, TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python_operator import BranchPythonOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.not_previously_skipped_dep import NotPreviouslySkippedDep from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.types import DagRunType def test_no_parent(): @@ -73,10 +74,10 @@ def test_parent_follow_branch(): dag = DAG( "test_parent_follow_branch_dag", schedule_interval=None, start_date=start_date ) + dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op2", dag=dag) op2 = DummyOperator(task_id="op2", dag=dag) op1 >> op2 - TaskInstance(op1, start_date).run() ti2 = TaskInstance(op2, start_date) @@ -91,21 +92,24 @@ def test_parent_skip_branch(): """ A simple DAG with a BranchPythonOperator that does not follow op2. NotPreviouslySkippedDep is not met. """ - start_date = pendulum.datetime(2020, 1, 1) - dag = DAG( - "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date - ) - op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) - op2 = DummyOperator(task_id="op2", dag=dag) - op3 = DummyOperator(task_id="op3", dag=dag) - op1 >> [op2, op3] - - TaskInstance(op1, start_date).run() - ti2 = TaskInstance(op2, start_date) - with create_session() as session: + session.query(DagRun).delete() + session.query(TaskInstance).delete() + start_date = pendulum.datetime(2020, 1, 1) + dag = DAG( + "test_parent_skip_branch_dag", schedule_interval=None, start_date=start_date + ) + dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=start_date) + op1 = BranchPythonOperator(task_id="op1", python_callable=lambda: "op3", dag=dag) + op2 = DummyOperator(task_id="op2", dag=dag) + op3 = DummyOperator(task_id="op3", dag=dag) + op1 >> [op2, op3] + TaskInstance(op1, start_date).run() + ti2 = TaskInstance(op2, start_date) dep = NotPreviouslySkippedDep() + assert len(list(dep.get_dep_statuses(ti2, session, DepContext()))) == 1 + session.commit() assert not dep.is_met(ti2, session) assert ti2.state == State.SKIPPED diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py index f1ef957e1e660..df23ed1af4b36 100644 --- a/tests/utils/test_log_handlers.py +++ b/tests/utils/test_log_handlers.py @@ -30,6 +30,7 @@ from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timezone import datetime +from airflow.utils.types import DagRunType DEFAULT_DATE = datetime(2016, 1, 1) TASK_LOGGER = 'airflow.task' @@ -66,6 +67,7 @@ def test_file_task_handler(self): def task_callable(ti, **kwargs): ti.log.info("test") dag = DAG('dag_for_testing_file_task_handler', 
start_date=DEFAULT_DATE) + dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, execution_date=DEFAULT_DATE) task = PythonOperator( task_id='task_for_testing_file_log_handler', dag=dag, diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index c4f7a125dec4c..876cd43062c6d 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -365,7 +365,7 @@ def test_sub_dag_task_group(): group234 >> group6 group234 >> task7 - subdag = dag.sub_dag(task_regex="task5", include_upstream=True, include_downstream=False) + subdag = dag.sub_dag(task_ids_or_regex="task5", include_upstream=True, include_downstream=False) assert extract_node_id(task_group_to_dict(subdag.task_group)) == { 'id': None,
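
Taken together, the ``partial_subset`` changes mean callers can now pass a list of task ids, a regex string, or a pre-compiled pattern, and can restrict the upstream expansion to a single level. A usage sketch of the three forms (the three-task DAG here is hypothetical, purely for illustration):

    from airflow.models import DAG
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.timezone import datetime

    with DAG('demo', start_date=datetime(2020, 1, 1)) as dag:
        a = DummyOperator(task_id='a')
        b = DummyOperator(task_id='b')
        c = DummyOperator(task_id='c')
        a >> b >> c

    # Unchanged behaviour: a regex matched against every task_id.
    b_and_upstream = dag.partial_subset(task_ids_or_regex=r'^b$', include_upstream=True)
    assert sorted(b_and_upstream.task_dict) == ['a', 'b']

    # New: an explicit list of task ids, with no regex matching involved.
    b_and_downstream = dag.partial_subset(['b'], include_downstream=True, include_upstream=False)
    assert sorted(b_and_downstream.task_dict) == ['b', 'c']

    # New: matched tasks plus only their direct (depth-1) upstream tasks --
    # the shape the mini scheduler uses to keep its partial dag small.
    c_and_parent = dag.partial_subset(
        ['c'], include_downstream=False, include_upstream=False, include_direct_upstream=True,
    )
    assert sorted(c_and_parent.task_dict) == ['b', 'c']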
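
The two new ``DagRun`` methods are designed to compose, and the scheduler's ``_schedule_dag_run`` and the task-side mini scheduler call them the same way: ``task_instance_scheduling_decisions`` works out which TIs are ready, and ``schedule_tis`` persists that decision in at most two UPDATE statements. A sketch of the contract, assuming a ``session``, a ``dag``, and a running ``dag_run`` for it already exist:

    # task_instance_scheduling_decisions resolves each TI's task object via
    # dag_run.get_dag(), so the DAG must be attached first.
    dag_run.dag = dag
    info = dag_run.task_instance_scheduling_decisions(session)

    # One UPDATE flips ordinary ready TIs to SCHEDULED; a second flips
    # DummyOperator TIs without callbacks straight to SUCCESS. The return
    # value counts both.
    count = dag_run.schedule_tis(info.schedulable_tis, session)

Disabling ``[scheduler] schedule_after_task_execution`` (as ``default_test.cfg`` does above) skips the task-side pass entirely, so downstream TIs then only reach SCHEDULED via the main scheduler loop.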