diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py index cdfb307bdd2c9..06a8eb57d232e 100644 --- a/airflow/api/common/experimental/trigger_dag.py +++ b/airflow/api/common/experimental/trigger_dag.py @@ -89,6 +89,7 @@ def _trigger_dag( state=State.RUNNING, conf=run_conf, external_trigger=True, + dag_hash=dag_bag.dags_hash.get(dag_id, None), ) triggers.append(trigger) diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 1627900825273..ad526d5c224e2 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -1335,7 +1335,7 @@ class GroupCommand(NamedTuple): help="Start a scheduler instance", func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'), args=( - ARG_DAG_ID_OPT, ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT, + ARG_SUBDIR, ARG_NUM_RUNS, ARG_DO_PICKLE, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE ), ), diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py index a7109239380fc..f0f019a58ac12 100644 --- a/airflow/cli/commands/scheduler_command.py +++ b/airflow/cli/commands/scheduler_command.py @@ -32,7 +32,6 @@ def scheduler(args): """Starts Airflow Scheduler""" print(settings.HEADER) job = SchedulerJob( - dag_id=args.dag_id, subdir=process_subdir(args.subdir), num_runs=args.num_runs, do_pickle=args.do_pickle) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index f58ea7bfea981..4325397288ceb 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -559,6 +559,15 @@ type: string example: "connexion,sqlalchemy" default: "" + - name: max_db_retries + description: | + Number of times the code should be retried in case of DB Operational Errors. + Not all transactions will be retried as it can cause undesired state. + Currently it is only used in ``DagFileProcessor.process_file`` to retry ``dagbag.sync_to_db``. + version_added: ~ + type: int + example: ~ + default: "3" - name: secrets description: ~ options: @@ -1572,6 +1581,41 @@ type: string example: ~ default: "512" + - name: use_row_level_locking + description: | + Should the scheduler issue `SELECT ... FOR UPDATE` in relevant queries. + If this is set to False then you should not run more than a single + scheduler at once + version_added: 2.0.0 + type: boolean + example: ~ + default: "True" + - name: max_dagruns_to_create_per_loop + description: | + This changes the number of dags that are locked by each scheduler when + creating dag runs. One possible reason for setting this lower is if you + have huge dags and are running multiple schedulers, you won't want one + scheduler to do all the work. + + Default: 10 + example: ~ + version_added: 2.0.0 + type: string + default: ~ + - name: max_dagruns_per_loop_to_schedule + description: | + How many DagRuns should a scheduler examine (and lock) when scheduling + and queuing tasks. Increasing this limit will allow more throughput for + smaller DAGs but will likely slow down throughput for larger (>500 + tasks for example) DAGs. Setting this too high when using multiple + schedulers could also lead to one scheduler taking all the dag runs + leaving no work for the others.
+ + Default: 20 + example: ~ + version_added: 2.0.0 + type: string + default: ~ - name: statsd_on description: | Statsd (https://github.com/etsy/statsd) integration settings diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg index 2665749ae1428..61b9edf94a9e1 100644 --- a/airflow/config_templates/default_airflow.cfg +++ b/airflow/config_templates/default_airflow.cfg @@ -305,6 +305,11 @@ task_log_reader = task # Example: extra_loggers = connexion,sqlalchemy extra_loggers = +# Number of times the code should be retried in case of DB Operational Errors. +# Not all transactions will be retried as it can cause undesired state. +# Currently it is only used in ``DagFileProcessor.process_file`` to retry ``dagbag.sync_to_db``. +max_db_retries = 3 + [secrets] # Full class name of secrets backend to enable (will precede env vars and metastore in search path) # Example: backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend @@ -789,6 +794,29 @@ catchup_by_default = True # Set this to 0 for no limit (not advised) max_tis_per_query = 512 +# Should the scheduler issue `SELECT ... FOR UPDATE` in relevant queries. +# If this is set to False then you should not run more than a single +# scheduler at once +use_row_level_locking = True + +# This changes the number of dags that are locked by each scheduler when +# creating dag runs. One possible reason for setting this lower is if you +# have huge dags and are running multiple schedulers, you won't want one +# scheduler to do all the work. +# +# Default: 10 +# max_dagruns_to_create_per_loop = + +# How many DagRuns should a scheduler examine (and lock) when scheduling +# and queuing tasks. Increasing this limit will allow more throughput for +# smaller DAGs but will likely slow down throughput for larger (>500 +# tasks for example) DAGs. Setting this too high when using multiple +# schedulers could also lead to one scheduler taking all the dag runs +# leaving no work for the others. +# +# Default: 20 +# max_dagruns_per_loop_to_schedule = + # Statsd (https://github.com/etsy/statsd) integration settings statsd_on = False statsd_host = localhost diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py index 73a002cabe93f..4140511b86e49 100644 --- a/airflow/executors/base_executor.py +++ b/airflow/executors/base_executor.py @@ -17,11 +17,12 @@ """ Base executor - this is the base class for all the implemented executors. """ +import sys from collections import OrderedDict -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Set, Tuple from airflow.configuration import conf -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey +from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.stats import Stats from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.state import State @@ -39,8 +40,8 @@ # Task that is queued. It contains all the information that is # needed to run the task. 
# -# Tuple of: command, priority, queue name, SimpleTaskInstance -QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], Union[SimpleTaskInstance, TaskInstance]] +# Tuple of: command, priority, queue name, TaskInstance +QueuedTaskInstanceType = Tuple[CommandType, int, Optional[str], TaskInstance] # Event_buffer dict value type # Tuple of: state, info @@ -72,16 +73,16 @@ def start(self): # pragma: no cover """ def queue_command(self, - simple_task_instance: SimpleTaskInstance, + task_instance: TaskInstance, command: CommandType, priority: int = 1, queue: Optional[str] = None): """Queues command to task""" - if simple_task_instance.key not in self.queued_tasks and simple_task_instance.key not in self.running: + if task_instance.key not in self.queued_tasks and task_instance.key not in self.running: self.log.info("Adding to queue: %s", command) - self.queued_tasks[simple_task_instance.key] = (command, priority, queue, simple_task_instance) + self.queued_tasks[task_instance.key] = (command, priority, queue, task_instance) else: - self.log.error("could not queue task %s", simple_task_instance.key) + self.log.error("could not queue task %s", task_instance.key) def queue_task_instance( self, @@ -112,7 +113,7 @@ def queue_task_instance( pickle_id=pickle_id, cfg_path=cfg_path) self.queue_command( - SimpleTaskInstance(task_instance), + task_instance, command_list_to_run, priority=task_instance.task.priority_weight_total, queue=task_instance.task.queue) @@ -178,13 +179,13 @@ def trigger_tasks(self, open_slots: int) -> None: sorted_queue = self.order_queued_tasks_by_priority() for _ in range(min((open_slots, len(self.queued_tasks)))): - key, (command, _, _, simple_ti) = sorted_queue.pop(0) + key, (command, _, _, ti) = sorted_queue.pop(0) self.queued_tasks.pop(key) self.running.add(key) self.execute_async(key=key, command=command, queue=None, - executor_config=simple_ti.executor_config) + executor_config=ti.executor_config) def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None: """ @@ -282,6 +283,16 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance # Subclasses can do better! 
return tis + @property + def slots_available(self): + """ + Number of new tasks this executor instance can accept + """ + if self.parallelism: + return self.parallelism - len(self.running) - len(self.queued_tasks) + else: + return sys.maxsize + @staticmethod def validate_command(command: List[str]) -> None: """Check if the command to execute is airflow command""" diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py index 51c1e17368419..ef82c2060585f 100644 --- a/airflow/executors/celery_kubernetes_executor.py +++ b/airflow/executors/celery_kubernetes_executor.py @@ -62,17 +62,19 @@ def start(self) -> None: self.celery_executor.start() self.kubernetes_executor.start() - def queue_command(self, - simple_task_instance: SimpleTaskInstance, - command: CommandType, - priority: int = 1, - queue: Optional[str] = None): + def queue_command( + self, + task_instance: TaskInstance, + command: CommandType, + priority: int = 1, + queue: Optional[str] = None + ): """Queues command via celery or kubernetes executor""" - executor = self._router(simple_task_instance) - self.log.debug("Using executor: %s for %s", - executor.__class__.__name__, simple_task_instance.key - ) - executor.queue_command(simple_task_instance, command, priority, queue) + executor = self._router(task_instance) + self.log.debug( + "Using executor: %s for %s", executor.__class__.__name__, task_instance.key + ) + executor.queue_command(task_instance, command, priority, queue) def queue_task_instance( self, diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index dd31f36b4aea3..e729b9e2d97a3 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -18,6 +18,7 @@ # under the License. # import datetime +import itertools import logging import multiprocessing import os @@ -28,13 +29,14 @@ from collections import defaultdict from contextlib import ExitStack, redirect_stderr, redirect_stdout, suppress from datetime import timedelta -from itertools import groupby from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple +from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple +import tenacity from setproctitle import setproctitle from sqlalchemy import and_, func, not_, or_ -from sqlalchemy.orm import load_only +from sqlalchemy.exc import OperationalError +from sqlalchemy.orm import load_only, selectinload from sqlalchemy.orm.session import Session, make_transient from airflow import models, settings @@ -45,22 +47,20 @@ from airflow.models import DAG, DagModel, SlaMiss, errors from airflow.models.dagbag import DagBag from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey -from airflow.operators.dummy_operator import DummyOperator -from airflow.serialization.serialized_objects import SerializedDAG from airflow.stats import Stats -from airflow.ti_deps.dep_context import DepContext -from airflow.ti_deps.dependencies_deps import SCHEDULED_DEPS from airflow.ti_deps.dependencies_states import EXECUTION_STATES -from airflow.utils import helpers, timezone -from airflow.utils.dag_processing import ( - AbstractDagFileProcessorProcess, DagFileProcessorAgent, FailureCallbackRequest, SimpleDagBag, +from airflow.utils import timezone +from airflow.utils.callback_requests import ( + CallbackRequest, DagCallbackRequest, 
SlaCallbackRequest, TaskCallbackRequest, ) +from airflow.utils.dag_processing import AbstractDagFileProcessorProcess, DagFileProcessorAgent from airflow.utils.email import get_email_address_list, send_email from airflow.utils.log.logging_mixin import LoggingMixin, StreamLogWriter, set_context from airflow.utils.mixins import MultiprocessingStartMethodMixin -from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import skip_locked +from airflow.utils.session import create_session, provide_session +from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -78,8 +78,8 @@ class DagFileProcessorProcess(AbstractDagFileProcessorProcess, LoggingMixin, Mul :type pickle_dags: bool :param dag_ids: If specified, only look at these DAG ID's :type dag_ids: List[str] - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] """ # Counter that increments every time an instance of this class is created @@ -90,18 +90,18 @@ def __init__( file_path: str, pickle_dags: bool, dag_ids: Optional[List[str]], - failure_callback_requests: List[FailureCallbackRequest] + callback_requests: List[CallbackRequest], ): super().__init__() self._file_path = file_path self._pickle_dags = pickle_dags self._dag_ids = dag_ids - self._failure_callback_requests = failure_callback_requests + self._callback_requests = callback_requests # The process that was launched to process the given . self._process: Optional[multiprocessing.process.BaseProcess] = None # The result of Scheduler.process_file(file_path). - self._result: Optional[Tuple[List[dict], int]] = None + self._result: Optional[Tuple[int, int]] = None # Whether the process is done running. self._done = False # When the process started. @@ -125,7 +125,7 @@ def _run_file_processor( pickle_dags: bool, dag_ids: Optional[List[str]], thread_name: str, - failure_callback_requests: List[FailureCallbackRequest] + callback_requests: List[CallbackRequest], ) -> None: """ Process the given file. 
@@ -144,8 +144,8 @@ def _run_file_processor( :type dag_ids: list[str] :param thread_name: the name to use for the process that is launched :type thread_name: str - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: list[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] :return: the process that was launched :rtype: multiprocessing.Process """ @@ -178,10 +178,10 @@ def _run_file_processor( log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path) dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log) - result: Tuple[List[dict], int] = dag_file_processor.process_file( + result: Tuple[int, int] = dag_file_processor.process_file( file_path=file_path, pickle_dags=pickle_dags, - failure_callback_requests=failure_callback_requests, + callback_requests=callback_requests, ) result_channel.send(result) end_time = time.time() @@ -216,7 +216,7 @@ def start(self) -> None: self._pickle_dags, self._dag_ids, "DagFileProcessor{}".format(self._instance_id), - self._failure_callback_requests + self._callback_requests ), name="DagFileProcessor{}-Process".format(self._instance_id) ) @@ -337,10 +337,10 @@ def done(self) -> bool: return False @property - def result(self) -> Optional[Tuple[List[dict], int]]: + def result(self) -> Optional[Tuple[int, int]]: """ :return: result of running SchedulerJob.process_file() - :rtype: Optional[Tuple[List[dict], int]] + :rtype: int or None """ if not self.done: raise AirflowException("Tried to get the result before it's done!") @@ -565,302 +565,75 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None: stacktrace=stacktrace)) session.commit() - # pylint: disable=too-many-return-statements,too-many-branches @provide_session - def create_dag_run( - self, - dag: DAG, - dag_runs: Optional[List[DagRun]] = None, - session: Session = None, - ) -> Optional[DagRun]: - """ - This method checks whether a new DagRun needs to be created - for a DAG based on scheduling interval. - Returns DagRun if one is scheduled. Otherwise returns None. 
- """ - # pylint: disable=too-many-nested-blocks - if not dag.schedule_interval: - return None - - active_runs: List[DagRun] - if dag_runs is None: - active_runs = DagRun.find( - dag_id=dag.dag_id, - state=State.RUNNING, - external_trigger=False, - session=session - ) - else: - active_runs = [ - dag_run - for dag_run in dag_runs - if not dag_run.external_trigger - ] - # return if already reached maximum active runs and no timeout setting - if len(active_runs) >= dag.max_active_runs and not dag.dagrun_timeout: - return None - timed_out_runs = 0 - for dr in active_runs: - if ( - dr.start_date and dag.dagrun_timeout and - dr.start_date < timezone.utcnow() - dag.dagrun_timeout - ): - dr.state = State.FAILED - dr.end_date = timezone.utcnow() - dag.handle_callback(dr, success=False, reason='dagrun_timeout', - session=session) - timed_out_runs += 1 - session.commit() - if len(active_runs) - timed_out_runs >= dag.max_active_runs: - return None - - # this query should be replaced by find dagrun - last_scheduled_run: Optional[datetime.datetime] = ( - session.query(func.max(DagRun.execution_date)) - .filter_by(dag_id=dag.dag_id) - .filter(or_( - DagRun.external_trigger == False, # noqa: E712 pylint: disable=singleton-comparison - DagRun.run_type == DagRunType.SCHEDULED.value - )).scalar() - ) - - # don't schedule @once again - if dag.schedule_interval == '@once' and last_scheduled_run: - return None - - # don't do scheduler catchup for dag's that don't have dag.catchup = True - if not (dag.catchup or dag.schedule_interval == '@once'): - # The logic is that we move start_date up until - # one period before, so that timezone.utcnow() is AFTER - # the period end, and the job can be created... - now = timezone.utcnow() - next_start = dag.following_schedule(now) - last_start = dag.previous_schedule(now) - if next_start <= now or isinstance(dag.schedule_interval, timedelta): - new_start = last_start - else: - new_start = dag.previous_schedule(last_start) - - if dag.start_date: - if new_start >= dag.start_date: - dag.start_date = new_start - else: - dag.start_date = new_start - - next_run_date = None - if not last_scheduled_run: - # First run - task_start_dates = [t.start_date for t in dag.tasks] - if task_start_dates: - next_run_date = dag.normalize_schedule(min(task_start_dates)) - self.log.debug( - "Next run date based on tasks %s", - next_run_date - ) - else: - next_run_date = dag.following_schedule(last_scheduled_run) - - # make sure backfills are also considered - last_run = dag.get_last_dagrun(session=session) - if last_run and next_run_date: - while next_run_date <= last_run.execution_date: - next_run_date = dag.following_schedule(next_run_date) - - # don't ever schedule prior to the dag's start_date - if dag.start_date: - next_run_date = (dag.start_date if not next_run_date - else max(next_run_date, dag.start_date)) - if next_run_date == dag.start_date: - next_run_date = dag.normalize_schedule(dag.start_date) - - self.log.debug( - "Dag start date: %s. 
Next run date: %s", - dag.start_date, next_run_date - ) - - # don't ever schedule in the future or if next_run_date is None - if not next_run_date or next_run_date > timezone.utcnow(): - return None - - # this structure is necessary to avoid a TypeError from concatenating - # NoneType - period_end = None - if dag.schedule_interval == '@once': - period_end = next_run_date - elif next_run_date: - period_end = dag.following_schedule(next_run_date) - - # Don't schedule a dag beyond its end_date (as specified by the dag param) - if next_run_date and dag.end_date and next_run_date > dag.end_date: - return None - - # Don't schedule a dag beyond its end_date (as specified by the task params) - # Get the min task end date, which may come from the dag.default_args - min_task_end_date = min([t.end_date for t in dag.tasks if t.end_date], default=None) - if next_run_date and min_task_end_date and next_run_date > min_task_end_date: - return None - - if next_run_date and period_end and period_end <= timezone.utcnow(): - next_run = dag.create_dagrun( - run_type=DagRunType.SCHEDULED, - execution_date=next_run_date, - start_date=timezone.utcnow(), - state=State.RUNNING, - external_trigger=False - ) - return next_run - - return None - - @provide_session - def _process_task_instances( - self, dag: DAG, dag_runs: List[DagRun], session: Session = None - ) -> List[TaskInstanceKey]: - """ - This method schedules the tasks for a single DAG by looking at the - active DAG runs and adding task instances that should run to the - queue. - """ - # update the state of the previously active dag runs - active_dag_runs = 0 - task_instances_list = [] - for run in dag_runs: - self.log.info("Examining DAG run %s", run) - # don't consider runs that are executed in the future unless - # specified by config and schedule_interval is None - if run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: - self.log.error( - "Execution date is in future: %s", - run.execution_date - ) - continue - - if active_dag_runs >= dag.max_active_runs: - self.log.info("Number of active dag runs reached max_active_run.") - break - - # skip backfill dagruns for now as long as they are not really scheduled - if run.is_backfill: - continue - - # todo: run.dag is transient but needs to be set - run.dag = dag # type: ignore - # todo: preferably the integrity check happens at dag collection time - run.verify_integrity(session=session) - ready_tis = run.update_state(session=session) - if run.state == State.RUNNING: - active_dag_runs += 1 - self.log.debug("Examining active DAG run: %s", run) - for ti in ready_tis: - self.log.debug('Queuing task: %s', ti) - task_instances_list.append(ti.key) - return task_instances_list - - @provide_session - def _process_dags(self, dags: List[DAG], session: Session = None) -> List[TaskInstanceKey]: - """ - Iterates over the dags and processes them. Processing includes: - - 1. Create appropriate DagRun(s) in the DB. - 2. Create appropriate TaskInstance(s) in the DB. - 3. Send emails for tasks that have missed SLAs (if CHECK_SLAS config enabled). 
- - :param dags: the DAGs from the DagBag to process - :type dags: List[airflow.models.DAG] - :rtype: list[TaskInstance] - :return: A list of generated TaskInstance objects - """ - check_slas: bool = conf.getboolean('core', 'CHECK_SLAS', fallback=True) - use_job_schedule: bool = conf.getboolean('scheduler', 'USE_JOB_SCHEDULE') - - # pylint: disable=too-many-nested-blocks - tis_out: List[TaskInstanceKey] = [] - dag_ids: List[str] = [dag.dag_id for dag in dags] - dag_runs = DagRun.find(dag_id=dag_ids, state=State.RUNNING, session=session) - # As per the docs of groupby (https://docs.python.org/3/library/itertools.html#itertools.groupby) - # we need to use `list()` otherwise the result will be wrong/incomplete - dag_runs_by_dag_id: Dict[str, List[DagRun]] = { - k: list(v) for k, v in groupby(dag_runs, lambda d: d.dag_id) - } - - for dag in dags: - dag_id: str = dag.dag_id - self.log.info("Processing %s", dag_id) - dag_runs_for_dag = dag_runs_by_dag_id.get(dag_id) or [] - - # Only creates DagRun for DAGs that are not subdag since - # DagRun of subdags are created when SubDagOperator executes. - if not dag.is_subdag and use_job_schedule: - dag_run = self.create_dag_run(dag, dag_runs=dag_runs_for_dag) - if dag_run: - dag_runs_for_dag.append(dag_run) - expected_start_date = dag.following_schedule(dag_run.execution_date) - if expected_start_date: - schedule_delay = dag_run.start_date - expected_start_date - Stats.timing( - 'dagrun.schedule_delay.{dag_id}'.format(dag_id=dag.dag_id), - schedule_delay) - self.log.info("Created %s", dag_run) - - if dag_runs_for_dag: - tis_out.extend(self._process_task_instances(dag, dag_runs_for_dag)) - if check_slas: - self.manage_slas(dag) - - return tis_out - - def _find_dags_to_process(self, dags: List[DAG]) -> List[DAG]: - """ - Find the DAGs that are not paused to process. - - :param dags: specified DAGs - :return: DAGs to process - """ - if self.dag_ids: - dags = [dag for dag in dags - if dag.dag_id in self.dag_ids] - return dags - - @provide_session - def execute_on_failure_callbacks( + def execute_callbacks( self, dagbag: DagBag, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], session: Session = None ) -> None: """ Execute on failure callbacks. These objects can come from SchedulerJob or from DagFileProcessorManager. - :param failure_callback_requests: failure callbacks to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param dagbag: Dag Bag of dags + :param callback_requests: failure callbacks to execute + :type callback_requests: List[airflow.utils.callback_requests.CallbackRequest] :param session: DB session. """ - for request in failure_callback_requests: - simple_ti = request.simple_task_instance - if simple_ti.dag_id in dagbag.dags: - dag = dagbag.dags[simple_ti.dag_id] - if simple_ti.task_id in dag.task_ids: - task = dag.get_task(simple_ti.task_id) - ti = TI(task, simple_ti.execution_date) - # Get properties needed for failure handling from SimpleTaskInstance. 
- ti.start_date = simple_ti.start_date - ti.end_date = simple_ti.end_date - ti.try_number = simple_ti.try_number - ti.state = simple_ti.state - ti.test_mode = self.UNIT_TEST_MODE + for request in callback_requests: + try: + if isinstance(request, TaskCallbackRequest): + self._execute_task_callbacks(dagbag, request) + elif isinstance(request, SlaCallbackRequest): + self.manage_slas(dagbag.dags.get(request.dag_id)) + elif isinstance(request, DagCallbackRequest): + self._execute_dag_callbacks(dagbag, request, session) + except Exception: # pylint: disable=broad-except + self.log.exception( + "Error executing %s callback for file: %s", + request.__class__.__name__, + request.full_filepath + ) + + session.commit() + + @provide_session + def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session): + dag = dagbag.dags[request.dag_id] + dag_run = dag.get_dagrun(execution_date=request.execution_date, session=session) + dag.handle_callback( + dagrun=dag_run, + success=not request.is_failure_callback, + reason=request.msg, + session=session + ) + + def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest): + simple_ti = request.simple_task_instance + if simple_ti.dag_id in dagbag.dags: + dag = dagbag.dags[simple_ti.dag_id] + if simple_ti.task_id in dag.task_ids: + task = dag.get_task(simple_ti.task_id) + ti = TI(task, simple_ti.execution_date) + # Get properties needed for failure handling from SimpleTaskInstance. + ti.start_date = simple_ti.start_date + ti.end_date = simple_ti.end_date + ti.try_number = simple_ti.try_number + ti.state = simple_ti.state + ti.test_mode = self.UNIT_TEST_MODE + if request.is_failure_callback: ti.handle_failure(request.msg, ti.test_mode, ti.get_template_context()) self.log.info('Executed failure callback for %s in state %s', ti, ti.state) - session.commit() @provide_session def process_file( self, file_path: str, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], pickle_dags: bool = False, session: Session = None - ) -> Tuple[List[dict], int]: + ) -> Tuple[int, int]: """ Process a Python file containing Airflow DAGs. @@ -879,16 +652,15 @@ def process_file( :param file_path: the path to the Python file that should be executed :type file_path: str - :param failure_callback_requests: failure callback to execute - :type failure_callback_requests: List[airflow.utils.dag_processing.FailureCallbackRequest] + :param callback_requests: failure callback to execute + :type callback_requests: List[airflow.utils.dag_processing.CallbackRequest] :param pickle_dags: whether serialize the DAGs found in the file and save them to the db :type pickle_dags: bool :param session: Sqlalchemy ORM Session :type session: Session - :return: a tuple with list of SimpleDags made from the Dags found in the file and - count of import errors. 
- :rtype: Tuple[List[dict], int] + :return: number of dags found, count of import errors + :rtype: Tuple[int, int] """ self.log.info("Processing file %s for tasks to queue", file_path) @@ -897,36 +669,46 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Failed at reloading the DAG file %s", file_path) Stats.incr('dag_file_refresh_error', 1, 1) - return [], 0 + return 0, 0 if len(dagbag.dags) > 0: self.log.info("DAG(s) %s retrieved from %s", dagbag.dags.keys(), file_path) else: self.log.warning("No viable dags retrieved from %s", file_path) self.update_import_errors(session, dagbag) - return [], len(dagbag.import_errors) - - try: - self.execute_on_failure_callbacks(dagbag, failure_callback_requests) - except Exception: # pylint: disable=broad-except - self.log.exception("Error executing failure callback!") - - # Save individual DAGs in the ORM and update DagModel.last_scheduled_time - dagbag.sync_to_db() - - paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - - unpaused_dags: List[DAG] = [ - dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids - ] - - serialized_dags = self._prepare_serialized_dags(unpaused_dags, pickle_dags, session) + return 0, len(dagbag.import_errors) + + self.execute_callbacks(dagbag, callback_requests) + + # Save individual DAGs in the ORM + dagbag.read_dags_from_db = True + + # Retry 'dagbag.sync_to_db()' in case of any Operational Errors + # In case of failures, provide_session handles rollback + for attempt in tenacity.Retrying( + retry=tenacity.retry_if_exception_type(exception_types=OperationalError), + wait=tenacity.wait_random_exponential(multiplier=0.5, max=5), + stop=tenacity.stop_after_attempt(settings.MAX_DB_RETRIES), + before_sleep=tenacity.before_sleep_log(self.log, logging.DEBUG), + reraise=True + ): + with attempt: + self.log.debug( + "Running dagbag.sync_to_db with retries. Try %d of %d", + attempt.retry_state.attempt_number, + settings.MAX_DB_RETRIES + ) + dagbag.sync_to_db() - dags = self._find_dags_to_process(unpaused_dags) + if pickle_dags: + paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids) - ti_keys_to_schedule = self._process_dags(dags, session) + unpaused_dags: List[DAG] = [ + dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids + ] - self._schedule_task_instances(dagbag, ti_keys_to_schedule, session) + for dag in unpaused_dags: + dag.pickle(session) # Record import errors into the ORM try: @@ -934,85 +716,7 @@ def process_file( except Exception: # pylint: disable=broad-except self.log.exception("Error logging import errors!") - return serialized_dags, len(dagbag.import_errors) - - @provide_session - def _schedule_task_instances( - self, - dagbag: DagBag, - ti_keys_to_schedule: List[TaskInstanceKey], - session: Session = None - ) -> None: - """ - Checks whether the tasks specified by `ti_keys_to_schedule` parameter can be scheduled and - updates the information in the database, - - :param dagbag: DagBag - :type dagbag: DagBag - :param ti_keys_to_schedule: List of task instance keys which can be scheduled. 
- :type ti_keys_to_schedule: list - """ - # Refresh all task instances that will be scheduled - filter_for_tis = TI.filter_for_tis(ti_keys_to_schedule) - - refreshed_tis: List[TI] = [] - - if filter_for_tis is not None: - refreshed_tis = session.query(TI).filter(filter_for_tis).with_for_update().all() - - for ti in refreshed_tis: - # Add task to task instance - dag: DAG = dagbag.dags[ti.dag_id] - ti.task = dag.get_task(ti.task_id) - - # We check only deps needed to set TI to SCHEDULED state here. - # Deps needed to set TI to QUEUED state will be batch checked later - # by the scheduler for better performance. - dep_context = DepContext(deps=SCHEDULED_DEPS, ignore_task_deps=True) - - # Only schedule tasks that have their dependencies met, e.g. to avoid - # a task that recently got its state changed to RUNNING from somewhere - # other than the scheduler from getting its state overwritten. - if ti.are_dependencies_met( - dep_context=dep_context, - session=session, - verbose=True - ): - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - # If the task is dummy, then mark it as done automatically - if isinstance(ti.task, DummyOperator) \ - and not ti.task.on_execute_callback \ - and not ti.task.on_success_callback: - ti.state = State.SUCCESS - ti.start_date = ti.end_date = timezone.utcnow() - ti.duration = 0 - - # Also save this task instance to the DB. - self.log.info("Creating / updating %s in ORM", ti) - session.merge(ti) - # commit batch - session.commit() - - @provide_session - def _prepare_serialized_dags( - self, dags: List[DAG], pickle_dags: bool, session: Session = None - ) -> List[dict]: - """ - Convert DAGS to SimpleDags. If necessary, it also Pickle the DAGs - - :param dags: List of DAGs - :return: List of SimpleDag - :rtype: List[dict] - """ - serialized_dags: List[dict] = [] - # Pickle the DAGs (if necessary) and put them into a SimpleDagBag - for dag in dags: - if pickle_dags: - dag.pickle(session) - serialized_dags.append(SerializedDAG.to_dict(dag)) - return serialized_dags + return len(dagbag.dags), len(dagbag.import_errors) class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes @@ -1030,9 +734,13 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes :param subdir: directory containing Python files with Airflow DAG definitions, or a specific path to a file :type subdir: str - :param num_runs: The number of times to try to schedule each DAG file. - -1 for unlimited times. + :param num_runs: The number of times to run the scheduling loop. If you + have a large number of DAG files this could complete before each file + has been parsed. -1 for unlimited times. :type num_runs: int + :param num_times_parse_dags: The number of times to try to parse each DAG file. + -1 for unlimited times. 
+ :type num_times_parse_dags: int :param processor_poll_interval: The number of seconds to wait between polls of running processors :type processor_poll_interval: int @@ -1048,23 +756,20 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes def __init__( self, - dag_id: Optional[str] = None, - dag_ids: Optional[List[str]] = None, subdir: str = settings.DAGS_FOLDER, num_runs: int = conf.getint('scheduler', 'num_runs'), + num_times_parse_dags: int = -1, processor_poll_interval: float = conf.getfloat('scheduler', 'processor_poll_interval'), do_pickle: bool = False, log: Any = None, *args, **kwargs): - # for BaseJob compatibility - self.dag_id = dag_id - self.dag_ids = [dag_id] if dag_id else [] - if dag_ids: - self.dag_ids.extend(dag_ids) - self.subdir = subdir self.num_runs = num_runs + # In specific tests, we want to stop the parse loop after the _files_ have been parsed a certain + # number of times. This is only to support testing, and isn't something a user is likely to want to + # configure -- they'll want num_runs + self.num_times_parse_dags = num_times_parse_dags self._processor_poll_interval = processor_poll_interval self.do_pickle = do_pickle @@ -1081,6 +786,8 @@ def __init__( self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query') self.processor_agent: Optional[DagFileProcessorAgent] = None + self.dagbag = DagBag(read_dags_from_db=True) + def register_exit_signals(self) -> None: """ Register signals that stop child processes @@ -1121,13 +828,12 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool: @provide_session def _change_state_for_tis_without_dagrun( self, - simple_dag_bag: SimpleDagBag, old_states: List[str], new_state: str, session: Session = None ) -> None: """ - For all DAG IDs in the SimpleDagBag, look for task instances in the + For all DAG IDs in the DagBag, look for task instances in the old_states and set them to new_state if the corresponding DagRun does not exist or exists but is not in the running state. 
This normally should not happen, but it can if the state of DagRuns are @@ -1137,17 +843,12 @@ def _change_state_for_tis_without_dagrun( :type old_states: list[airflow.utils.state.State] :param new_state: set TaskInstances to this state :type new_state: airflow.utils.state.State - :param simple_dag_bag: TaskInstances associated with DAGs in the - simple_dag_bag and with states in the old_states will be examined - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag """ tis_changed = 0 query = session \ .query(models.TaskInstance) \ - .outerjoin(models.DagRun, and_( - models.TaskInstance.dag_id == models.DagRun.dag_id, - models.TaskInstance.execution_date == models.DagRun.execution_date)) \ - .filter(models.TaskInstance.dag_id.in_(simple_dag_bag.dag_ids)) \ + .outerjoin(models.TaskInstance.dag_run) \ + .filter(models.TaskInstance.dag_id.in_(list(self.dagbag.dag_ids))) \ .filter(models.TaskInstance.state.in_(old_states)) \ .filter(or_( # pylint: disable=comparison-with-callable @@ -1156,7 +857,7 @@ def _change_state_for_tis_without_dagrun( # We need to do this for mysql as well because it can cause deadlocks # as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516 if self.using_sqlite or self.using_mysql: - tis_to_change: List[TI] = query.with_for_update().all() + tis_to_change: List[TI] = with_row_locks(query).all() for ti in tis_to_change: ti.set_state(new_state, session=session) tis_changed += 1 @@ -1183,7 +884,7 @@ def _change_state_for_tis_without_dagrun( models.TaskInstance.execution_date == subq.c.execution_date) \ .update(ti_prop_update, synchronize_session=False) - session.commit() + session.flush() if tis_changed > 0: self.log.warning( @@ -1221,39 +922,54 @@ def __get_concurrency_maps( # pylint: disable=too-many-locals,too-many-statements @provide_session - def _find_executable_task_instances( - self, - simple_dag_bag: SimpleDagBag, - session: Session = None - ) -> List[TI]: + def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]: """ Finds TIs that are ready for execution with respect to pool limits, dag concurrency, executor state, and priority. - :param simple_dag_bag: TaskInstances associated with DAGs in the - simple_dag_bag will be fetched from the DB and executed - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + :param max_tis: Maximum number of TIs to queue in this loop. + :type max_tis: int :return: list[airflow.models.TaskInstance] """ executable_tis: List[TI] = [] + # Get the pool settings. We get a lock on the pool rows, treating this as a "critical section" + # Throws an exception if lock cannot be obtained, rather than blocking + pools = models.Pool.slots_stats(lock_rows=True, session=session) + + # If the pools are full, there is no point doing anything! 
+ # If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL + pool_slots_free = max(0, sum(pool['open'] for pool in pools.values())) + + if pool_slots_free == 0: + self.log.debug("All pools are full!") + return executable_tis + + max_tis = min(max_tis, pool_slots_free) + # Get all task instances associated with scheduled # DagRuns which are not backfilled, in the given states, # and the dag is not paused - task_instances_to_examine: List[TI] = ( + query = ( session .query(TI) - .filter(TI.dag_id.in_(simple_dag_bag.dag_ids)) - .outerjoin( - DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date) - ) - .filter(or_(DR.run_id.is_(None), DR.run_type != DagRunType.BACKFILL_JOB.value)) - .outerjoin(DM, DM.dag_id == TI.dag_id) - .filter(or_(DM.dag_id.is_(None), not_(DM.is_paused))) + .outerjoin(TI.dag_run) + .filter(or_(DR.run_id.is_(None), + DR.run_type != DagRunType.BACKFILL_JOB.value)) + .join(TI.dag_model) + .filter(not_(DM.is_paused)) .filter(TI.state == State.SCHEDULED) - .all() + .options(selectinload('dag_model')) + .limit(max_tis) ) - Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) + + task_instances_to_examine: List[TI] = with_row_locks( + query, + of=TI, + **skip_locked(session=session), + ).all() + # TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything. + # Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine)) if len(task_instances_to_examine) == 0: self.log.debug("No tasks to consider for execution.") @@ -1267,9 +983,6 @@ def _find_executable_task_instances( task_instance_str ) - # Get the pool settings - pools: Dict[str, models.Pool] = {p.pool: p for p in session.query(models.Pool).all()} - pool_to_task_instances: DefaultDict[str, List[models.Pool]] = defaultdict(list) for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) @@ -1296,7 +1009,7 @@ def _find_executable_task_instances( ) continue - open_slots = pools[pool].open_slots(session=session) + open_slots = pools[pool]["open"] num_ready = len(task_instances) self.log.info( @@ -1324,10 +1037,9 @@ def _find_executable_task_instances( # Check to make sure that the task concurrency of the DAG hasn't been # reached. 
dag_id = task_instance.dag_id - serialized_dag = simple_dag_bag.get_dag(dag_id) current_dag_concurrency = dag_concurrency_map[dag_id] - dag_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency + dag_concurrency_limit = task_instance.dag_model.concurrency self.log.info( "DAG %s has %s/%s running and queued tasks", dag_id, current_dag_concurrency, dag_concurrency_limit @@ -1341,27 +1053,23 @@ def _find_executable_task_instances( continue task_concurrency_limit: Optional[int] = None - if serialized_dag.has_task(task_instance.task_id): - task_concurrency_limit = serialized_dag.get_task( - task_instance.task_id).task_concurrency - - if task_concurrency_limit is not None: - current_task_concurrency = task_concurrency_map[ - (task_instance.dag_id, task_instance.task_id) - ] - - if current_task_concurrency >= task_concurrency_limit: - self.log.info("Not executing %s since the task concurrency for" - " this task has been reached.", task_instance) - continue - - if self.executor.has_task(task_instance): - self.log.debug( - "Not handling task %s as the executor reports it is running", - task_instance.key - ) - num_tasks_in_executor += 1 - continue + if task_instance.dag_model.has_task_concurrency_limits: + # Many dags don't have a task_concurrency, so where we can avoid loading the full + # serialized DAG the better. + serialized_dag = self.dagbag.get_dag(dag_id, session=session) + if serialized_dag.has_task(task_instance.task_id): + task_concurrency_limit = serialized_dag.get_task( + task_instance.task_id).task_concurrency + + if task_concurrency_limit is not None: + current_task_concurrency = task_concurrency_map[ + (task_instance.dag_id, task_instance.task_id) + ] + + if current_task_concurrency >= task_concurrency_limit: + self.log.info("Not executing %s since the task concurrency for" + " this task has been reached.", task_instance) + continue if task_instance.pool_slots > open_slots: self.log.info("Not executing %s since it requires %s slots " @@ -1387,116 +1095,63 @@ def _find_executable_task_instances( [repr(x) for x in executable_tis]) self.log.info( "Setting the following tasks to queued state:\n\t%s", task_instance_str) - # so these dont expire on commit - for ti in executable_tis: - copy_dag_id = ti.dag_id - copy_execution_date = ti.execution_date - copy_task_id = ti.task_id - make_transient(ti) - ti.dag_id = copy_dag_id - ti.execution_date = copy_execution_date - ti.task_id = copy_task_id - return executable_tis - - @provide_session - def _change_state_for_executable_task_instances( - self, task_instances: List[TI], session: Session = None - ) -> List[SimpleTaskInstance]: - """ - Changes the state of task instances in the list with one of the given states - to QUEUED atomically, and returns the TIs changed in SimpleTaskInstance format. 
- - :param task_instances: TaskInstances to change the state of - :type task_instances: list[airflow.models.TaskInstance] - :rtype: list[airflow.models.taskinstance.SimpleTaskInstance] - """ - if len(task_instances) == 0: - session.commit() - return [] - - tis_to_set_to_queued: List[TI] = ( - session - .query(TI) - .filter(TI.filter_for_tis(task_instances)) - .filter(TI.state == State.SCHEDULED) - .with_for_update() - .all() - ) - - if len(tis_to_set_to_queued) == 0: - self.log.info("No tasks were able to have their state changed to queued.") - session.commit() - return [] # set TIs to queued state - filter_for_tis = TI.filter_for_tis(tis_to_set_to_queued) + filter_for_tis = TI.filter_for_tis(executable_tis) session.query(TI).filter(filter_for_tis).update( + # TODO[ha]: should we use func.now()? How does that work with DB timezone on mysql when it's not + # UTC? {TI.state: State.QUEUED, TI.queued_dttm: timezone.utcnow(), TI.queued_by_job_id: self.id}, synchronize_session=False ) - session.commit() - # Generate a list of SimpleTaskInstance for the use of queuing - # them in the executor. - simple_task_instances = [SimpleTaskInstance(ti) for ti in tis_to_set_to_queued] - - task_instance_str = "\n\t".join([repr(x) for x in tis_to_set_to_queued]) - self.log.info("Setting the following %s tasks to queued state:\n\t%s", - len(tis_to_set_to_queued), task_instance_str) - return simple_task_instances + for ti in executable_tis: + make_transient(ti) + return executable_tis def _enqueue_task_instances_with_queued_state( self, - simple_dag_bag: SimpleDagBag, - simple_task_instances: List[SimpleTaskInstance] + task_instances: List[TI] ) -> None: """ Takes task_instances, which should have been set to queued, and enqueues them with the executor. - :param simple_task_instances: TaskInstances to enqueue - :type simple_task_instances: list[SimpleTaskInstance] - :param simple_dag_bag: Should contains all of the task_instances' dags - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + :param task_instances: TaskInstances to enqueue + :type task_instances: list[TaskInstance] """ # actually enqueue them - for simple_task_instance in simple_task_instances: - serialized_dag = simple_dag_bag.get_dag(simple_task_instance.dag_id) + for ti in task_instances: command = TI.generate_command( - simple_task_instance.dag_id, - simple_task_instance.task_id, - simple_task_instance.execution_date, + ti.dag_id, + ti.task_id, + ti.execution_date, local=True, mark_success=False, ignore_all_deps=False, ignore_depends_on_past=False, ignore_task_deps=False, ignore_ti_state=False, - pool=simple_task_instance.pool, - file_path=serialized_dag.full_filepath, - pickle_id=serialized_dag.pickle_id, + pool=ti.pool, + file_path=ti.dag_model.fileloc, + pickle_id=ti.dag_model.pickle_id, ) - priority = simple_task_instance.priority_weight - queue = simple_task_instance.queue + priority = ti.priority_weight + queue = ti.queue self.log.info( "Sending %s to executor with priority %s and queue %s", - simple_task_instance.key, priority, queue + ti.key, priority, queue ) self.executor.queue_command( - simple_task_instance, + ti, command, priority=priority, queue=queue, ) - @provide_session - def _execute_task_instances( - self, - simple_dag_bag: SimpleDagBag, - session: Session = None - ) -> int: + def _critical_section_execute_task_instances(self, session: Session) -> int: """ Attempts to execute TaskInstances that should be executed by the scheduler. @@ -1506,23 +1161,21 @@ def _execute_task_instances( 2. 
Change the state for the TIs above atomically. 3. Enqueue the TIs in the executor. - :param simple_dag_bag: TaskInstances associated with DAGs in the - simple_dag_bag will be fetched from the DB and executed - :type simple_dag_bag: airflow.utils.dag_processing.SimpleDagBag + HA note: This function is a "critical section" meaning that only a single scheduler process can execute + this function at the same time. This is achieved by doing ``SELECT ... from pool FOR UPDATE``. For DBs + that support NOWAIT, a "blocked" scheduler will skip this and continue on with other tasks (creating + new DAG runs, progressing TIs from None to SCHEDULED etc.); for DBs that don't support this (such as + MariaDB or MySQL 5.x), the other schedulers will wait for the lock before continuing. + + :param session: + :type session: sqlalchemy.orm.Session :return: Number of task instance with state changed. """ - executable_tis = self._find_executable_task_instances(simple_dag_bag, session=session) - - def query(result: int, items: List[TI]) -> int: - simple_tis_with_state_changed = \ - self._change_state_for_executable_task_instances(items, session=session) - self._enqueue_task_instances_with_queued_state( - simple_dag_bag, - simple_tis_with_state_changed) - session.commit() - return result + len(simple_tis_with_state_changed) + max_tis = min(self.max_tis_per_query, self.executor.slots_available) + queued_tis = self._executable_task_instances_to_queued(max_tis, session=session) - return helpers.reduce_in_chunks(query, executable_tis, 0, self.max_tis_per_query) + self._enqueue_task_instances_with_queued_state(queued_tis) + return len(queued_tis) @provide_session def _change_state_for_tasks_failed_to_execute(self, session: Session = None): @@ -1547,7 +1200,7 @@ def _change_state_for_tasks_failed_to_execute(self, session: Session = None): for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys()]) ti_query = session.query(TI).filter(or_(*filter_for_ti_state_change)) - tis_to_set_to_scheduled: List[TI] = ti_query.with_for_update().all() + tis_to_set_to_scheduled: List[TI] = with_row_locks(ti_query).all() if not tis_to_set_to_scheduled: return @@ -1564,14 +1217,14 @@ def _change_state_for_tasks_failed_to_execute(self, session: Session = None): self.log.info("Set the following tasks to scheduled state:\n\t%s", task_instance_str) @provide_session - def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Session = None) -> None: + def _process_executor_events(self, session: Session = None) -> int: """ Respond to executor events. 
""" if not self.processor_agent: raise ValueError("Processor agent is not started.") ti_primary_key_to_try_number_map: Dict[Tuple[str, str, datetime.datetime], int] = {} - event_buffer = self.executor.get_event_buffer(simple_dag_bag.dag_ids) + event_buffer = self.executor.get_event_buffer() tis_with_right_state: List[TaskInstanceKey] = [] # Report execution @@ -1591,11 +1244,11 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio # Return if no finished tasks if not tis_with_right_state: - return + return len(event_buffer) # Check state of finished tasks filter_for_tis = TI.filter_for_tis(tis_with_right_state) - tis: List[TI] = session.query(TI).filter(filter_for_tis).all() + tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all() for ti in tis: try_number = ti_primary_key_to_try_number_map[ti.key.primary] buffer_key = ti.key.with_try_number(try_number) @@ -1612,20 +1265,23 @@ def _process_executor_events(self, simple_dag_bag: SimpleDagBag, session: Sessio msg = "Executor reports task instance %s finished (%s) although the " \ "task says its %s. (Info: %s) Was the task killed externally?" self.log.error(msg, ti, state, ti.state, info) - serialized_dag = simple_dag_bag.get_dag(ti.dag_id) - self.processor_agent.send_callback_to_execute( - full_filepath=serialized_dag.full_filepath, - task_instance=ti, + request = TaskCallbackRequest( + full_filepath=ti.dag_model.fileloc, + simple_task_instance=SimpleTaskInstance(ti), msg=msg % (ti, state, ti.state, info), ) + self.processor_agent.send_callback_to_execute(request) + + return len(event_buffer) + def _execute(self) -> None: self.log.info("Starting the scheduler") # DAGs can be pickled for easier remote execution by some executors pickle_dags = self.do_pickle and self.executor_class not in UNPICKLEABLE_EXECUTORS - self.log.info("Processing each file at most %s times", self.num_runs) + self.log.info("Processing each file at most %s times", self.num_times_parse_dags) # When using sqlite, we do not use async_mode # so the scheduler job and DAG parser don't access the DB at the same time. 
@@ -1635,10 +1291,10 @@ def _execute(self) -> None: processor_timeout = timedelta(seconds=processor_timeout_seconds) self.processor_agent = DagFileProcessorAgent( dag_directory=self.subdir, - max_runs=self.num_runs, + max_runs=self.num_times_parse_dags, processor_factory=type(self)._create_dag_file_processor, processor_timeout=processor_timeout, - dag_ids=self.dag_ids, + dag_ids=[], pickle_dags=pickle_dags, async_mode=async_mode, ) @@ -1684,7 +1340,7 @@ def _execute(self) -> None: @staticmethod def _create_dag_file_processor( file_path: str, - failure_callback_requests: List[FailureCallbackRequest], + callback_requests: List[CallbackRequest], dag_ids: Optional[List[str]], pickle_dags: bool ) -> DagFileProcessorProcess: @@ -1695,7 +1351,7 @@ def _create_dag_file_processor( file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, - failure_callback_requests=failure_callback_requests + callback_requests=callback_requests ) def _run_scheduler_loop(self) -> None: @@ -1719,8 +1375,7 @@ def _run_scheduler_loop(self) -> None: raise ValueError("Processor agent is not started.") is_unit_test: bool = conf.getboolean('core', 'unit_test_mode') - # For the execute duration, parse and schedule DAGs - while True: + for loop_count in itertools.count(start=1): loop_start_time = time.time() if self.using_sqlite: @@ -1730,15 +1385,14 @@ def _run_scheduler_loop(self) -> None: self.log.debug("Waiting for processors to finish since we're using sqlite") self.processor_agent.wait_until_finished() - serialized_dags = self.processor_agent.harvest_serialized_dags() + with create_session() as session: + num_queued_tis = self._do_scheduling(session) - self.log.debug("Harvested %d SimpleDAGs", len(serialized_dags)) + self.executor.heartbeat() + session.expunge_all() + num_finished_events = self._process_executor_events(session=session) - # Send tasks for execution if available - simple_dag_bag = SimpleDagBag(serialized_dags) - - if not self._validate_and_run_task_instances(simple_dag_bag=simple_dag_bag): - continue + self.processor_agent.heartbeat() # Heartbeat the scheduler periodically self.heartbeat(only_if_necessary=True) @@ -1749,62 +1403,353 @@ def _run_scheduler_loop(self) -> None: loop_duration = loop_end_time - loop_start_time self.log.debug("Ran scheduling loop in %.2f seconds", loop_duration) - if not is_unit_test: + if not is_unit_test and not num_queued_tis and not num_finished_events: + # If the scheduler is doing things, don't sleep. This means when there is work to do, the + # scheduler will run "as quickly as possible", but when there is nothing to do, it can sleep, dropping CPU + # usage when "idle" time.sleep(self._processor_poll_interval) + if loop_count >= self.num_runs > 0: + self.log.info( + "Exiting scheduler loop as requested number of runs (%d - got to %d) has been reached", + self.num_runs, loop_count, + ) + break if self.processor_agent.done: self.log.info( - "Exiting scheduler loop as all files have been processed %d times", self.num_runs + "Exiting scheduler loop as requested DAG parse count (%d) has been reached after %d " + "scheduler loops", + self.num_times_parse_dags, loop_count, ) break - def _validate_and_run_task_instances(self, simple_dag_bag: SimpleDagBag) -> bool: - if simple_dag_bag.serialized_dags: + def _do_scheduling(self, session) -> int: + """ + This function is where the main scheduling decisions take place. 
It: + + - Creates any necessary DAG runs by examining the next_dagrun_create_after column of DagModel + + Since creating Dag Runs is a relatively time consuming process, we select only 10 dags by default + (configurable via ``scheduler.max_dagruns_to_create_per_loop`` setting) - putting this higher will + mean one scheduler could spend a chunk of time creating dag runs, and not ever get around to + scheduling tasks. + + - Finds the "next n oldest" running DAG Runs to examine for scheduling (n=20 by default, configurable + via ``scheduler.max_dagruns_per_loop_to_schedule`` config setting) and tries to progress state (TIs + to SCHEDULED, or DagRuns to SUCCESS/FAILURE etc) + + By "next oldest", we mean the DAG runs that haven't been examined/scheduled for the longest time. + + The reason we don't select all dagruns at once is that the rows are selected with row locks, meaning + that only one scheduler can "process them", even if it is waiting behind other dags. Increasing this + limit will allow more throughput for smaller DAGs but will likely slow down throughput for larger + (>500 tasks) DAGs + + - Then, via a Critical Section (locking the rows of the Pool model) we queue tasks, and then send them + to the executor. + + See docs of _critical_section_execute_task_instances for more. + + :return: Number of TIs enqueued in this iteration + :rtype: int + """ + # Put a check in place to make sure we don't commit unexpectedly + with prohibit_commit(session) as guard: + + if settings.USE_JOB_SCHEDULE: + query = DagModel.dags_needing_dagruns(session) + self._create_dag_runs(query.all(), session) + + # commit the session - Release the write lock on DagModel table. + guard.commit() + # END: create dagruns + + dag_runs = DagRun.next_dagruns_to_examine(session) + + # Bulk fetch the currently active dag runs for the dags we are + # examining, rather than making one query per DagRun + + # TODO: This query is probably horribly inefficient (though there is an + # index on (dag_id,state)). It is to deal with the case when a user + # clears more than max_active_runs older tasks -- we don't want the + # scheduler to suddenly go and start running tasks from all of the + # runs. (AIRFLOW-137/GH #1442) + # + # The longer term fix would be to have `clear` do this, and put DagRuns + # into the queued state, then take DRs out of queued before creating + # any new ones + # TODO[HA]: Why is this on TI, not on DagRun?? + currently_active_runs = dict(session.query( + TI.dag_id, + func.count(TI.execution_date.distinct()), + ).filter( + TI.dag_id.in_(list({dag_run.dag_id for dag_run in dag_runs})), + TI.state.notin_(State.finished()) + ).group_by(TI.dag_id).all()) + + for dag_run in dag_runs: + self._schedule_dag_run(dag_run, currently_active_runs.get(dag_run.dag_id, 0), session) + + guard.commit() + + # Without this, the session has an invalid view of the DB + session.expunge_all() + # END: schedule TIs + + # TODO[HA]: Do we need to do it every time? 
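+ # Handle the case where a DagRun's state was set (perhaps manually) to a non-running state: mark + # UP_FOR_RETRY task instances as FAILED, and reset QUEUED/SCHEDULED/UP_FOR_RESCHEDULE/SENSING + # task instances to NONE so they are not re-run for a DagRun that is no longer running.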
try: - self._process_and_execute_tasks(simple_dag_bag) - except Exception as e: # pylint: disable=broad-except - self.log.error("Error queuing tasks") - self.log.exception(e) - return False - - # Call heartbeats - self.log.debug("Heartbeating the executor") - self.executor.heartbeat() - - self._change_state_for_tasks_failed_to_execute() - - # Process events from the executor - self._process_executor_events(simple_dag_bag) - return True - - def _process_and_execute_tasks(self, simple_dag_bag: SimpleDagBag) -> None: - # Handle cases where a DAG run state is set (perhaps manually) to - # a non-running state. Handle task instances that belong to - # DAG runs in those states - # If a task instance is up for retry but the corresponding DAG run - # isn't running, mark the task instance as FAILED so we don't try - # to re-run it. - self._change_state_for_tis_without_dagrun( - simple_dag_bag=simple_dag_bag, - old_states=[State.UP_FOR_RETRY], - new_state=State.FAILED - ) - # If a task instance is scheduled or queued or up for reschedule, - # but the corresponding DAG run isn't running, set the state to - # NONE so we don't try to re-run it. - self._change_state_for_tis_without_dagrun( - simple_dag_bag=simple_dag_bag, - old_states=[State.QUEUED, - State.SCHEDULED, - State.UP_FOR_RESCHEDULE, - State.SENSING], - new_state=State.NONE + self._change_state_for_tis_without_dagrun( + old_states=[State.UP_FOR_RETRY], + new_state=State.FAILED, + session=session + ) + + self._change_state_for_tis_without_dagrun( + old_states=[State.QUEUED, + State.SCHEDULED, + State.UP_FOR_RESCHEDULE, + State.SENSING], + new_state=State.NONE, + session=session + ) + + guard.commit() + except OperationalError as e: + if is_lock_not_available_error(error=e): + self.log.debug("Lock held by another Scheduler") + session.rollback() + else: + raise + + try: + if self.executor.slots_available <= 0: + # We know we can't do anything here, so don't even try! + self.log.debug("Executor full, skipping critical section") + return 0 + + timer = Stats.timer('scheduler.critical_section_duration') + timer.start() + + # Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor) + num_queued_tis = self._critical_section_execute_task_instances(session=session) + + # Make sure we only sent this metric if we obtained the lock, otherwise we'll skew the + # metric, way down + timer.stop(send=True) + except OperationalError as e: + timer.stop(send=False) + + if is_lock_not_available_error(error=e): + self.log.debug("Critical section lock held by another Scheduler") + Stats.incr('scheduler.critical_section_busy') + session.rollback() + return 0 + raise + + return num_queued_tis + + def _create_dag_runs(self, dag_models: Iterable[DagModel], session: Session) -> None: + """ + Unconditionally create a DAG run for the given DAG, and update the dag_model's fields to control + if/when the next DAGRun should be created + """ + for dag_model in dag_models: + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + dag_hash = self.dagbag.dags_hash.get(dag.dag_id, None) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag_model.next_dagrun, + start_date=timezone.utcnow(), + state=State.RUNNING, + external_trigger=False, + session=session, + dag_hash=dag_hash + ) + + self._update_dag_next_dagruns(dag_models, session) + + # TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in + # memory for larger dags? 
or expunge_all() + + def _update_dag_next_dagruns(self, dag_models: Iterable[DagModel], session: Session) -> None: + """ + Bulk update the next_dagrun and next_dagrun_create_after for all the dags. + + We batch the select queries to get info about all the dags at once + """ + # Check max_active_runs, to see if we are _now_ at the limit for any of + # these dag? (we've just created a DagRun for them after all) + active_runs_of_dags = dict(session.query(DagRun.dag_id, func.count('*')).filter( + DagRun.dag_id.in_([o.dag_id for o in dag_models]), + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False), + ).group_by(DagRun.dag_id).all()) + + for dag_model in dag_models: + dag = self.dagbag.get_dag(dag_model.dag_id, session=session) + active_runs_of_dag = active_runs_of_dags.get(dag.dag_id, 0) + if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: + self.log.info( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + dag_model.next_dagrun_create_after = None + else: + dag_model.next_dagrun, dag_model.next_dagrun_create_after = \ + dag.next_dagrun_info(dag_model.next_dagrun) + + def _schedule_dag_run(self, dag_run: DagRun, currently_active_runs: int, session: Session) -> int: + """ + Make scheduling decisions about an individual dag run + + ``currently_active_runs`` is passed in so that a batch query can be + used to ask this for all dag runs in the batch, to avoid an n+1 query. + + :param dag_run: The DagRun to schedule + :param currently_active_runs: Number of currently active runs of this DAG + :return: Number of tasks scheduled + """ + dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session) + + if not dag: + self.log.error( + "Couldn't find dag %s in DagBag/DB!", dag_run.dag_id + ) + return 0 + + if ( + dag_run.start_date and dag.dagrun_timeout and + dag_run.start_date < timezone.utcnow() - dag.dagrun_timeout + ): + dag_run.state = State.FAILED + dag_run.end_date = timezone.utcnow() + self.log.info("Run %s of %s has timed-out", dag_run.run_id, dag_run.dag_id) + session.flush() + + # Work out if we should allow creating a new DagRun now? + self._update_dag_next_dagruns([session.query(DagModel).get(dag_run.dag_id)], session) + + callback_to_execute = DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=dag.dag_id, + execution_date=dag_run.execution_date, + is_failure_callback=True, + msg='timed_out' + ) + + # Send SLA & DAG Success/Failure Callbacks to be executed + self._send_dag_callbacks_to_processor(dag_run, callback_to_execute) + + return 0 + + if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates: + self.log.error( + "Execution date is in future: %s", + dag_run.execution_date + ) + return 0 + + if dag.max_active_runs: + if currently_active_runs >= dag.max_active_runs: + self.log.info( + "DAG %s already has %d active runs, not queuing any more tasks", + dag.dag_id, + currently_active_runs, + ) + return 0 + + self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session) + # TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else? 
+ schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False) + + self._send_dag_callbacks_to_processor(dag_run, callback_to_run) + + # Get list of TIs that do not need to executed, these are + # tasks using DummyOperator and without on_execute_callback / on_success_callback + dummy_tis = [ + ti for ti in schedulable_tis + if + ( + ti.task.task_type == "DummyOperator" + and not ti.task.on_execute_callback + and not ti.task.on_success_callback + ) + ] + + # This will do one query per dag run. We "could" build up a complex + # query to update all the TIs across all the execution dates and dag + # IDs in a single query, but it turns out that can be _very very slow_ + # see #11147/commit ee90807ac for more details + count = session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in schedulable_tis if ti not in dummy_tis) + ).update({TI.state: State.SCHEDULED}, synchronize_session=False) + + # Tasks using DummyOperator should not be executed, mark them as success + if dummy_tis: + session.query(TI).filter( + TI.dag_id == dag_run.dag_id, + TI.execution_date == dag_run.execution_date, + TI.task_id.in_(ti.task_id for ti in dummy_tis) + ).update({ + TI.state: State.SUCCESS, + TI.start_date: timezone.utcnow(), + TI.end_date: timezone.utcnow(), + TI.duration: 0 + }, synchronize_session=False) + + return count + + @provide_session + def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None): + """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow""" + latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session) + if dag_run.dag_hash == latest_version: + self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id) + return + + dag_run.dag_hash = latest_version + + # Refresh the DAG + dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session) + + # Verify integrity also takes care of session.flush + dag_run.verify_integrity(session=session) + + def _send_dag_callbacks_to_processor( + self, + dag_run: DagRun, + callback: Optional[DagCallbackRequest] = None + ): + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + + dag = dag_run.get_dag() + self._send_sla_callbacks_to_processor(dag) + if callback: + self.processor_agent.send_callback_to_execute(callback) + + def _send_sla_callbacks_to_processor(self, dag: DAG): + """Sends SLA Callbacks to DagFileProcessor if tasks have SLAs set and check_slas=True""" + if not settings.CHECK_SLAS: + return + + if not any(isinstance(ti.sla, timedelta) for ti in dag.tasks): + self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag) + return + + if not self.processor_agent: + raise ValueError("Processor agent is not started.") + + self.processor_agent.send_sla_callback_request_to_execute( + full_filepath=dag.fileloc, + dag_id=dag.dag_id ) - self._execute_task_instances(simple_dag_bag) @provide_session def _emit_pool_metrics(self, session: Session = None) -> None: - pools = models.Pool.slots_stats(session) + pools = models.Pool.slots_stats(session=session) for pool_name, slot_stats in pools.items(): Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"]) Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) @@ -1835,7 +1780,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): Stats.incr(self.__class__.__name__.lower() + '_end', 
num_failed) resettable_states = [State.SCHEDULED, State.QUEUED, State.RUNNING] - tis_to_reset_or_adopt = ( + query = ( session.query(TI).filter(TI.state.in_(resettable_states)) # outerjoin is because we didn't use to have queued_by_job # set, so we need to pick up anything pre upgrade. This (and the @@ -1848,10 +1793,10 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): # pylint: disable=comparison-with-callable DagRun.state == State.RUNNING) .options(load_only(TI.dag_id, TI.task_id, TI.execution_date)) - # Lock these rows, so that another scheduler can't try and adopt these too - .with_for_update(of=TI, **skip_locked(session=session)) - .all() ) + + # Lock these rows, so that another scheduler can't try and adopt these too + tis_to_reset_or_adopt = with_row_locks(query, of=TI, **skip_locked(session=session)).all() to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt) reset_tis_message = [] @@ -1871,4 +1816,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None): self.log.info("Reset the following %s orphaned TaskInstances:\n\t%s", len(to_reset), task_instance_str) + # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller + # decide when to commit + session.flush() return len(to_reset) diff --git a/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py new file mode 100644 index 0000000000000..2d617105c02d6 --- /dev/null +++ b/airflow/migrations/versions/98271e7606e2_add_scheduling_decision_to_dagrun_and_.py @@ -0,0 +1,88 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Add scheduling_decision to DagRun and DAG + +Revision ID: 98271e7606e2 +Revises: bef4f3d11e8b +Create Date: 2020-10-01 12:13:32.968148 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import mysql + +# revision identifiers, used by Alembic. 
+revision = '98271e7606e2' +down_revision = 'bef4f3d11e8b' +branch_labels = None +depends_on = None + + +def upgrade(): + """Apply Add scheduling_decision to DagRun and DAG""" + conn = op.get_bind() # pylint: disable=no-member + is_mysql = bool(conn.dialect.name == "mysql") + timestamp = sa.TIMESTAMP(timezone=True) if not is_mysql else mysql.TIMESTAMP(fsp=6, timezone=True) + + with op.batch_alter_table('dag_run', schema=None) as batch_op: + batch_op.add_column(sa.Column('last_scheduling_decision', timestamp, nullable=True)) + batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False) + batch_op.add_column(sa.Column('dag_hash', sa.String(32), nullable=True)) + + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.add_column(sa.Column('next_dagrun', timestamp, nullable=True)) + batch_op.add_column(sa.Column('next_dagrun_create_after', timestamp, nullable=True)) + # Create with nullable and no default, then ALTER to set values, to avoid table level lock + batch_op.add_column(sa.Column('concurrency', sa.Integer(), nullable=True)) + batch_op.add_column(sa.Column('has_task_concurrency_limits', sa.Boolean(), nullable=True)) + + batch_op.create_index('idx_next_dagrun_create_after', ['next_dagrun_create_after'], unique=False) + + try: + from airflow.configuration import conf + concurrency = conf.getint('core', 'dag_concurrency', fallback=16) + except: # noqa + concurrency = 16 + + # Set it to true here as it makes us take the slow/more complete path, and when it's next parsed by the + # DagParser it will get set to correct value. + op.execute( + "UPDATE dag SET concurrency={}, has_task_concurrency_limits=true where concurrency IS NULL".format( + concurrency + ) + ) + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.alter_column('concurrency', type_=sa.Integer(), nullable=False) + batch_op.alter_column('has_task_concurrency_limits', type_=sa.Boolean(), nullable=False) + + +def downgrade(): + """Unapply Add scheduling_decision to DagRun and DAG""" + with op.batch_alter_table('dag_run', schema=None) as batch_op: + batch_op.drop_index('idx_last_scheduling_decision') + batch_op.drop_column('last_scheduling_decision') + batch_op.drop_column('dag_hash') + + with op.batch_alter_table('dag', schema=None) as batch_op: + batch_op.drop_index('idx_next_dagrun_create_after') + batch_op.drop_column('next_dagrun_create_after') + batch_op.drop_column('next_dagrun') + batch_op.drop_column('concurrency') + batch_op.drop_column('has_task_concurrency_limits') diff --git a/airflow/models/dag.py b/airflow/models/dag.py index b3465476e81bf..b50f2a184b0ed 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -28,7 +28,8 @@ from collections import OrderedDict from datetime import datetime, timedelta from typing import ( - TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast, + TYPE_CHECKING, Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, + cast, ) import jinja2 @@ -57,7 +58,7 @@ from airflow.utils.helpers import validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import Interval, UtcDateTime +from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -156,8 +157,7 @@ class DAG(BaseDag, LoggingMixin): :type max_active_runs: int :param 
dagrun_timeout: specify how long a DagRun should be up before timing out / failing, so that new DagRuns can be created. The timeout - is only enforced for scheduled DagRuns, and only once the - # of active DagRuns == max_active_runs. + is only enforced for scheduled DagRuns. :type dagrun_timeout: datetime.timedelta :param sla_miss_callback: specify a function to call when reporting SLA timeouts. @@ -467,6 +467,109 @@ def previous_schedule(self, dttm): elif self.normalized_schedule_interval is not None: return timezone.convert_to_utc(dttm - self.normalized_schedule_interval) + def next_dagrun_info( + self, + date_last_automated_dagrun: Optional[pendulum.DateTime], + ) -> Tuple[Optional[pendulum.DateTime], Optional[pendulum.DateTime]]: + """ + Get information about the next DagRun of this dag after ``date_last_automated_dagrun`` -- the + execution date, and the earliest it could be scheduled + + :param date_last_automated_dagrun: The max(execution_date) of existing + "automated" DagRuns for this dag (scheduled or backfill, but not + manual) + """ + if (self.schedule_interval == "@once" and date_last_automated_dagrun) or \ + self.schedule_interval is None: + # Manual trigger, or already created the run for @once, can short circuit + return (None, None) + next_execution_date = self.next_dagrun_after_date(date_last_automated_dagrun) + + if next_execution_date is None: + return (None, None) + + if self.schedule_interval == "@once": + # For "@once" it can be created "now" + return (next_execution_date, next_execution_date) + + return (next_execution_date, self.following_schedule(next_execution_date)) + + def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]): + """ + Get the next execution date after the given ``date_last_automated_dagrun``, according to + schedule_interval, start_date, end_date etc. This doesn't check max active run or any other + "concurrency" type limits, it only performs calculations based on the various date and interval fields + of this dag and it's tasks. + + :param date_last_automated_dagrun: The execution_date of the last scheduler or + backfill triggered run for this dag + :type date_last_automated_dagrun: pendulum.Pendulum + """ + if not self.schedule_interval or self.is_subdag: + return None + + # don't schedule @once again + if self.schedule_interval == '@once' and date_last_automated_dagrun: + return None + + # don't do scheduler catchup for dag's that don't have dag.catchup = True + if not (self.catchup or self.schedule_interval == '@once'): + # The logic is that we move start_date up until + # one period before, so that timezone.utcnow() is AFTER + # the period end, and the job can be created... 
+ now = timezone.utcnow() + next_start = self.following_schedule(now) + last_start = self.previous_schedule(now) + if next_start <= now or isinstance(self.schedule_interval, timedelta): + new_start = last_start + else: + new_start = self.previous_schedule(last_start) + + if self.start_date: + if new_start >= self.start_date: + self.start_date = new_start + else: + self.start_date = new_start + + next_run_date = None + if not date_last_automated_dagrun: + # First run + task_start_dates = [t.start_date for t in self.tasks] + if task_start_dates: + next_run_date = self.normalize_schedule(min(task_start_dates)) + self.log.debug("Next run date based on tasks %s", next_run_date) + else: + next_run_date = self.following_schedule(date_last_automated_dagrun) + + if date_last_automated_dagrun and next_run_date: + while next_run_date <= date_last_automated_dagrun: + next_run_date = self.following_schedule(next_run_date) + + # don't ever schedule prior to the dag's start_date + if self.start_date: + next_run_date = self.start_date if not next_run_date else max(next_run_date, self.start_date) + if next_run_date == self.start_date: + next_run_date = self.normalize_schedule(self.start_date) + + self.log.debug( + "Dag start date: %s. Next run date: %s", + self.start_date, next_run_date + ) + + # Don't schedule a dag beyond its end_date (as specified by the dag param) + if next_run_date and self.end_date and next_run_date > self.end_date: + return None + + # Don't schedule a dag beyond its end_date (as specified by the task params) + # Get the min task end date, which may come from the dag.default_args + task_end_dates = [t.end_date for t in self.tasks if t.end_date] + if task_end_dates and next_run_date: + min_task_end_date = min(task_end_dates) + if next_run_date > min_task_end_date: + return None + + return next_run_date + def get_run_dates(self, start_date, end_date=None): """ Returns a list of dates between the interval received as parameter using this @@ -609,10 +712,7 @@ def owner(self) -> str: @property def allow_future_exec_dates(self) -> bool: - return conf.getboolean( - 'scheduler', - 'allow_trigger_in_future', - fallback=False) and self.schedule_interval is None + return settings.ALLOW_FUTURE_EXEC_DATES and self.schedule_interval is None @provide_session def get_concurrency_reached(self, session=None) -> bool: @@ -1237,12 +1337,14 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k not in ('user_defined_macros', 'user_defined_filters', 'params'): + if k not in ('user_defined_macros', 'user_defined_filters', 'params', '_log'): setattr(result, k, copy.deepcopy(v, memo)) result.user_defined_macros = self.user_defined_macros result.user_defined_filters = self.user_defined_filters result.params = self.params + if hasattr(self, '_log'): + result._log = self._log return result def sub_dag(self, task_regex, include_downstream=False, @@ -1506,15 +1608,18 @@ def cli(self): args.func(args, self) @provide_session - def create_dagrun(self, - state, - execution_date=None, - run_id=None, - start_date=None, - external_trigger=False, - conf=None, - run_type=None, - session=None): + def create_dagrun( + self, + state, + execution_date=None, + run_id=None, + start_date=None, + external_trigger=False, + conf=None, + run_type=None, + session=None, + dag_hash=None + ): """ Creates a dag run from this dag including the tasks associated with this dag. Returns the dag run. 
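To make the new next_dagrun / next_dagrun_create_after bookkeeping above concrete, here is a small standalone sketch (plain Python, not the Airflow API) of the calculation for the simplest case of a fixed timedelta schedule_interval: the next execution date is one interval after the last automated run, and the run only becomes creatable once the data interval it covers has elapsed. Cron schedules, catchup=False and "@once" are handled by next_dagrun_info / next_dagrun_after_date in the diff itself; the helper name next_run_window below is illustrative only.

from datetime import datetime, timedelta, timezone
from typing import Optional, Tuple


def next_run_window(
    last_automated_run: Optional[datetime],
    start_date: datetime,
    interval: timedelta,
) -> Tuple[datetime, datetime]:
    """Return (execution_date, earliest_create_after) for a timedelta schedule.

    Mirrors the idea behind DagModel.next_dagrun / next_dagrun_create_after:
    a run with execution_date X only becomes creatable once X + interval has
    passed, i.e. once the data interval it covers is complete.
    """
    if last_automated_run is None:
        # First run: covers the interval starting at start_date
        execution_date = start_date
    else:
        execution_date = last_automated_run + interval
    return execution_date, execution_date + interval


if __name__ == "__main__":
    start = datetime(2020, 10, 1, tzinfo=timezone.utc)
    # No runs yet: execution_date == start_date, creatable from start_date + 1 day
    print(next_run_window(None, start, timedelta(days=1)))
    # After the 2020-10-02 run, the next execution_date is 2020-10-03,
    # creatable from 2020-10-04 onwards
    print(next_run_window(datetime(2020, 10, 2, tzinfo=timezone.utc), start, timedelta(days=1)))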
@@ -1535,6 +1640,8 @@ def create_dagrun(self, :type conf: dict :param session: database session :type session: sqlalchemy.orm.session.Session + :param dag_hash: Hash of Serialized DAG + :type dag_hash: str """ if run_id and not run_type: if not isinstance(run_id, str): @@ -1558,10 +1665,10 @@ def create_dagrun(self, conf=conf, state=state, run_type=run_type.value, + dag_hash=dag_hash ) session.add(run) - - session.commit() + session.flush() run.dag = self @@ -1573,34 +1680,40 @@ def create_dagrun(self, @classmethod @provide_session - def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): + def bulk_sync_to_db(cls, dags: Collection["DAG"], session=None): + """This method is deprecated in favor of bulk_write_to_db""" + warnings.warn( + "This method is deprecated and will be removed in a future version. Please use bulk_write_to_db", + DeprecationWarning, + stacklevel=2, + ) + return cls.bulk_write_to_db(dags, session) + + @classmethod + @provide_session + def bulk_write_to_db(cls, dags: Collection["DAG"], session=None): """ - Save attributes about list of DAG to the DB. Note that this method - can be called for both DAGs and SubDAGs. A SubDag is actually a - SubDagOperator. + Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including + calculated fields. + + Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. :param dags: the DAG objects to save to the DB :type dags: List[airflow.models.dag.DAG] - :param sync_time: The time that the DAG should be marked as sync'ed - :type sync_time: datetime :return: None """ if not dags: return - if sync_time is None: - sync_time = timezone.utcnow() log.info("Sync %s DAGs", len(dags)) dag_by_ids = {dag.dag_id: dag for dag in dags} dag_ids = set(dag_by_ids.keys()) - orm_dags = ( + query = ( session .query(DagModel) .options(joinedload(DagModel.tags, innerjoin=False)) - .filter(DagModel.dag_id.in_(dag_ids)) - .with_for_update(of=DagModel) - .all() - ) + .filter(DagModel.dag_id.in_(dag_ids))) + orm_dags = with_row_locks(query, of=DagModel).all() existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags} missing_dag_ids = dag_ids.difference(existing_dag_ids) @@ -1615,6 +1728,23 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): session.add(orm_dag) orm_dags.append(orm_dag) + # Get the latest dag run for each existing dag as a single query (avoid n+1 query) + most_recent_dag_runs = dict(session.query(DagRun.dag_id, func.max_(DagRun.execution_date)).filter( + DagRun.dag_id.in_(existing_dag_ids), + or_( + DagRun.run_type == DagRunType.BACKFILL_JOB.value, + DagRun.run_type == DagRunType.SCHEDULED.value, + DagRun.external_trigger.is_(True), + ), + ).group_by(DagRun.dag_id).all()) + + # Get number of active dagruns for all dags we are processing as a single query. 
+ num_active_runs = dict(session.query(DagRun.dag_id, func.count('*')).filter( + DagRun.dag_id.in_(existing_dag_ids), + DagRun.state == State.RUNNING, # pylint: disable=comparison-with-callable + DagRun.external_trigger.is_(False) + ).group_by(DagRun.dag_id).all()) + for orm_dag in sorted(orm_dags, key=lambda d: d.dag_id): dag = dag_by_ids[orm_dag.dag_id] if dag.is_subdag: @@ -1627,10 +1757,20 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): orm_dag.fileloc = dag.fileloc orm_dag.owners = dag.owner orm_dag.is_active = True - orm_dag.last_scheduler_run = sync_time orm_dag.default_view = dag.default_view orm_dag.description = dag.description orm_dag.schedule_interval = dag.schedule_interval + orm_dag.concurrency = dag.concurrency + orm_dag.has_task_concurrency_limits = any( + t.task_concurrency is not None for t in dag.tasks + ) + + orm_dag.calculate_dagrun_date_fields( + dag, + most_recent_dag_runs.get(dag.dag_id), + num_active_runs.get(dag.dag_id, 0), + ) + for orm_tag in list(orm_dag.tags): if orm_tag.name not in orm_dag.tags: session.delete(orm_tag) @@ -1646,23 +1786,23 @@ def bulk_sync_to_db(cls, dags: Collection["DAG"], sync_time=None, session=None): if settings.STORE_DAG_CODE: DagCode.bulk_sync_to_db([dag.fileloc for dag in orm_dags]) - session.commit() + # Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller + # decide when to commit + session.flush() for dag in dags: - cls.bulk_sync_to_db(dag.subdags, sync_time=sync_time, session=session) + cls.bulk_write_to_db(dag.subdags, session=session) @provide_session - def sync_to_db(self, sync_time=None, session=None): + def sync_to_db(self, session=None): """ Save attributes about this DAG to the DB. Note that this method can be called for both DAGs and SubDAGs. A SubDag is actually a SubDagOperator. - :param sync_time: The time that the DAG should be marked as sync'ed - :type sync_time: datetime :return: None """ - self.bulk_sync_to_db([self], sync_time, session) + self.bulk_write_to_db([self], session) def get_default_view(self): """This is only there for backward compatible jinja2 templates""" @@ -1824,10 +1964,34 @@ class DagModel(Base): # Tags for view filter tags = relationship('DagTag', cascade='all,delete-orphan', backref=backref('dag')) + concurrency = Column(Integer, nullable=False) + + has_task_concurrency_limits = Column(Boolean, nullable=False) + + # The execution_date of the next dag run + next_dagrun = Column(UtcDateTime) + # Earliest time at which this ``next_dagrun`` can be created + next_dagrun_create_after = Column(UtcDateTime) + __table_args__ = ( Index('idx_root_dag_id', root_dag_id, unique=False), + Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False), + ) + + NUM_DAGS_PER_DAGRUN_QUERY = conf.getint( + 'scheduler', + 'max_dagruns_to_create_per_loop', + fallback=10 ) + def __init__(self, **kwargs): + super().__init__(**kwargs) + if self.concurrency is None: + self.concurrency = conf.getint('core', 'dag_concurrency') + if self.has_task_concurrency_limits is None: + # Be safe -- this will be updated later once the DAG is parsed + self.has_task_concurrency_limits = True + def __repr__(self): return "".format(self=self) @@ -1939,6 +2103,63 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=None): session.rollback() raise + @classmethod + def dags_needing_dagruns(cls, session: Session): + """ + Return (and lock) a list of Dag objects that are due to create a new DagRun. 
+ + This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query, + you should ensure that any scheduling decisions are made in a single transaction -- as soon as the + transaction is committed it will be unlocked. + """ + # TODO[HA]: Bake this query, it is run _A lot_ + # We limit so that _one_ scheduler doesn't try to do all the creation + # of dag runs + query = session.query(cls).filter( + cls.is_paused.is_(False), + cls.is_active.is_(True), + cls.next_dagrun_create_after <= func.now(), + ).order_by( + cls.next_dagrun_create_after + ).limit(cls.NUM_DAGS_PER_DAGRUN_QUERY) + + return with_row_locks(query, of=cls, **skip_locked(session=session)) + + def calculate_dagrun_date_fields( + self, + dag: DAG, + most_recent_dag_run: Optional[pendulum.DateTime], + active_runs_of_dag: int + ) -> None: + """ + Calculate ``next_dagrun`` and `next_dagrun_create_after`` + + :param dag: The DAG object + :param most_recent_dag_run: DateTime of most recent run of this dag, or none if not yet scheduled. + :param active_runs_of_dag: Number of currently active runs of this dag + """ + self.next_dagrun, self.next_dagrun_create_after = dag.next_dagrun_info(most_recent_dag_run) + + if dag.max_active_runs and active_runs_of_dag >= dag.max_active_runs: + # Since this happens every time the dag is parsed it would be quite spammy at info + log.debug( + "DAG %s is at (or above) max_active_runs (%d of %d), not creating any more runs", + dag.dag_id, active_runs_of_dag, dag.max_active_runs + ) + self.next_dagrun_create_after = None + + log.info("Setting next_dagrun for %s to %s", dag.dag_id, self.next_dagrun) + + +STATICA_HACK = True +globals()['kcah_acitats'[::-1].upper()] = False +if STATICA_HACK: # pragma: no cover + # Let pylint know about these relationships, without introducing an import cycle + from sqlalchemy.orm import relationship + + from airflow.models.serialized_dag import SerializedDagModel + DagModel.serialized_dag = relationship(SerializedDagModel) + class DagContext: """ diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index ed709f916d948..a2fc4bd4bdf93 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -30,6 +30,7 @@ from typing import Dict, List, NamedTuple, Optional from croniter import CroniterBadCronError, CroniterBadDateError, CroniterNotAlphaError, croniter +from sqlalchemy.orm import Session from tabulate import tabulate from airflow import settings @@ -42,6 +43,7 @@ from airflow.utils.dag_cycle_tester import test_cycle from airflow.utils.file import correct_maybe_zipped, list_py_file_paths, might_contain_dag from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import provide_session from airflow.utils.timeout import timeout @@ -115,6 +117,8 @@ def __init__( self.read_dags_from_db = read_dags_from_db # Only used by read_dags_from_db=True self.dags_last_fetched: Dict[str, datetime] = {} + # Only used by SchedulerJob to compare the dag_hash to identify change in DAGs + self.dags_hash: Dict[str, str] = {} self.dagbag_import_error_tracebacks = conf.getboolean('core', 'dagbag_import_error_tracebacks') self.dagbag_import_error_traceback_depth = conf.getint('core', 'dagbag_import_error_traceback_depth') @@ -144,7 +148,8 @@ def store_serialized_dags(self) -> bool: def dag_ids(self) -> List[str]: return list(self.dags.keys()) - def get_dag(self, dag_id): + @provide_session + def get_dag(self, dag_id, session: Session = None): """ Gets the DAG out of the dictionary, and refreshes it if expired @@ 
-159,7 +164,7 @@ def get_dag(self, dag_id): from airflow.models.serialized_dag import SerializedDagModel if dag_id not in self.dags: # Load from DB if not (yet) in the bag - self._add_dag_from_db(dag_id=dag_id) + self._add_dag_from_db(dag_id=dag_id, session=session) return self.dags.get(dag_id) # If DAG is in the DagBag, check the following @@ -171,9 +176,12 @@ def get_dag(self, dag_id): dag_id in self.dags_last_fetched and timezone.utcnow() > self.dags_last_fetched[dag_id] + min_serialized_dag_fetch_secs ): - sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime(dag_id=dag_id) + sd_last_updated_datetime = SerializedDagModel.get_last_updated_datetime( + dag_id=dag_id, + session=session, + ) if sd_last_updated_datetime > self.dags_last_fetched[dag_id]: - self._add_dag_from_db(dag_id=dag_id) + self._add_dag_from_db(dag_id=dag_id, session=session) return self.dags.get(dag_id) @@ -183,16 +191,16 @@ def get_dag(self, dag_id): if dag_id in self.dags: dag = self.dags[dag_id] if dag.is_subdag: - root_dag_id = dag.parent_dag.dag_id + root_dag_id = dag.parent_dag.dag_id # type: ignore # If DAG Model is absent, we can't check last_expired property. Is the DAG not yet synchronized? - orm_dag = DagModel.get_current(root_dag_id) + orm_dag = DagModel.get_current(root_dag_id, session=session) if not orm_dag: return self.dags.get(dag_id) # If the dag corresponding to root_dag_id is absent or expired is_missing = root_dag_id not in self.dags - is_expired = (orm_dag.last_expired and dag.last_loaded < orm_dag.last_expired) + is_expired = (orm_dag.last_expired and dag and dag.last_loaded < orm_dag.last_expired) if is_missing or is_expired: # Reprocess source file found_dags = self.process_file( @@ -205,10 +213,10 @@ def get_dag(self, dag_id): del self.dags[dag_id] return self.dags.get(dag_id) - def _add_dag_from_db(self, dag_id: str): + def _add_dag_from_db(self, dag_id: str, session: Session): """Add DAG to DagBag from DB""" from airflow.models.serialized_dag import SerializedDagModel - row = SerializedDagModel.get(dag_id) + row = SerializedDagModel.get(dag_id, session) if not row: raise ValueError(f"DAG '{dag_id}' not found in serialized_dag table") @@ -217,6 +225,7 @@ def _add_dag_from_db(self, dag_id: str): self.dags[subdag.dag_id] = subdag self.dags[dag.dag_id] = dag self.dags_last_fetched[dag.dag_id] = timezone.utcnow() + self.dags_hash[dag.dag_id] = row.dag_hash def process_file(self, filepath, only_if_updated=True, safe_mode=True): """ @@ -514,7 +523,8 @@ def dagbag_report(self): """) return report - def sync_to_db(self): + @provide_session + def sync_to_db(self, session: Optional[Session] = None): """ Save attributes about list of DAG to the DB. 
""" @@ -522,9 +532,9 @@ def sync_to_db(self): from airflow.models.dag import DAG from airflow.models.serialized_dag import SerializedDagModel self.log.debug("Calling the DAG.bulk_sync_to_db method") - DAG.bulk_sync_to_db(self.dags.values()) + DAG.bulk_write_to_db(self.dags.values(), session=session) # Write Serialized DAGs to DB if DAG Serialization is turned on # Even though self.read_dags_from_db is False - if settings.STORE_SERIALIZED_DAGS: + if settings.STORE_SERIALIZED_DAGS or self.read_dags_from_db: self.log.debug("Calling the SerializedDagModel.bulk_sync_to_db method") - SerializedDagModel.bulk_sync_to_db(self.dags.values()) + SerializedDagModel.bulk_sync_to_db(self.dags.values(), session=session) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 429e98dbfa244..524f25ac53733 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -26,6 +26,8 @@ from sqlalchemy.orm import backref, relationship, synonym from sqlalchemy.orm.session import Session +from airflow import settings +from airflow.configuration import conf as airflow_conf from airflow.exceptions import AirflowException from airflow.models.base import ID_LEN, Base from airflow.models.taskinstance import TaskInstance as TI @@ -33,10 +35,10 @@ from airflow.stats import Stats from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES -from airflow.utils import timezone +from airflow.utils import callback_requests, timezone from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime +from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, with_row_locks from airflow.utils.state import State from airflow.utils.types import DagRunType @@ -59,6 +61,9 @@ class DagRun(Base, LoggingMixin): external_trigger = Column(Boolean, default=True) run_type = Column(String(50), nullable=False) conf = Column(PickleType) + # When a scheduler last attempted to schedule TIs for this DagRun + last_scheduling_decision = Column(UtcDateTime) + dag_hash = Column(String(32)) dag = None @@ -66,6 +71,7 @@ class DagRun(Base, LoggingMixin): Index('dag_id_state', dag_id, _state), UniqueConstraint('dag_id', 'execution_date'), UniqueConstraint('dag_id', 'run_id'), + Index('idx_last_scheduling_decision', last_scheduling_decision), ) task_instances = relationship( @@ -75,6 +81,12 @@ class DagRun(Base, LoggingMixin): backref=backref('dag_run', uselist=False), ) + DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint( + 'scheduler', + 'max_dagruns_per_loop_to_schedule', + fallback=20, + ) + def __init__( self, dag_id: Optional[str] = None, @@ -84,7 +96,8 @@ def __init__( external_trigger: Optional[bool] = None, conf: Optional[Any] = None, state: Optional[str] = None, - run_type: Optional[str] = None + run_type: Optional[str] = None, + dag_hash: Optional[str] = None, ): self.dag_id = dag_id self.run_id = run_id @@ -94,6 +107,7 @@ def __init__( self.conf = conf or {} self.state = state self.run_type = run_type + self.dag_hash = dag_hash super().__init__() def __repr__(self): @@ -139,6 +153,46 @@ def refresh_from_db(self, session: Session = None): self.id = dr.id self.state = dr.state + @classmethod + def next_dagruns_to_examine( + cls, + session: Session, + max_number: Optional[int] = None, + ): + """ + Return the next DagRuns that the scheduler should attempt to schedule. + + This will return zero or more DagRun rows that are row-level-locked with a "SELECT ... 
FOR UPDATE" + query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as + the transaction is committed it will be unlocked. + + :rtype: list[airflow.models.DagRun] + """ + from airflow.models.dag import DagModel + + if max_number is None: + max_number = cls.DEFAULT_DAGRUNS_TO_EXAMINE + + # TODO: Bake this query, it is run _A lot_ + query = session.query(cls).filter( + cls.state == State.RUNNING, + cls.run_type != DagRunType.BACKFILL_JOB.value + ).join( + DagModel, + DagModel.dag_id == cls.dag_id, + ).filter( + DagModel.is_paused.is_(False), + DagModel.is_active.is_(True), + ).order_by( + nulls_first(cls.last_scheduling_decision, session=session), + cls.execution_date, + ) + + if not settings.ALLOW_FUTURE_EXEC_DATES: + query = query.filter(DagRun.execution_date <= func.now()) + + return with_row_locks(query.limit(max_number), of=cls, **skip_locked(session=session)) + @staticmethod @provide_session def find( @@ -301,16 +355,29 @@ def get_previous_scheduled_dagrun(self, session: Session = None): ).first() @provide_session - def update_state(self, session: Session = None) -> List[TI]: + def update_state( + self, + session: Session = None, + execute_callbacks: bool = True + ) -> Tuple[List[TI], Optional[callback_requests.DagCallbackRequest]]: """ Determines the overall state of the DagRun based on the state of its TaskInstances. :param session: Sqlalchemy ORM Session :type session: Session - :return: ready_tis: the tis that can be scheduled in the current loop - :rtype ready_tis: list[airflow.models.TaskInstance] + :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked + directly (default: true) or recorded as a pending request in the ``callback`` property + :type execute_callbacks: bool + :return: Tuple containing tis that can be scheduled in the current loop & `callback` that + needs to be executed """ + # Callback to execute in case of Task Failures + callback: Optional[callback_requests.DagCallbackRequest] = None + + start_dttm = timezone.utcnow() + self.last_scheduling_decision = start_dttm + dag = self.get_dag() ready_tis: List[TI] = [] tis = list(self.get_task_instances(session=session, state=State.task_states + (State.SHUTDOWN,))) @@ -318,12 +385,10 @@ def update_state(self, session: Session = None) -> List[TI]: for ti in tis: ti.task = dag.get_task(ti.task_id) - start_dttm = timezone.utcnow() unfinished_tasks = [t for t in tis if t.state in State.unfinished()] finished_tasks = [t for t in tis if t.state in State.finished() + [State.UPSTREAM_FAILED]] none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) - none_task_concurrency = all(t.task.task_concurrency is None - for t in unfinished_tasks) + none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks) if unfinished_tasks: scheduleable_tasks = [ut for ut in unfinished_tasks if ut.state in SCHEDULEABLE_STATES] self.log.debug( @@ -348,7 +413,16 @@ def update_state(self, session: Session = None) -> List[TI]: ): self.log.error('Marking run %s failed', self) self.set_state(State.FAILED) - dag.handle_callback(self, success=False, reason='task_failure', session=session) + if execute_callbacks: + dag.handle_callback(self, success=False, reason='task_failure', session=session) + else: + callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=True, + msg='task_failure' + ) # if all leafs succeeded and no unfinished 
tasks, the run succeeded elif not unfinished_tasks and all( @@ -356,15 +430,32 @@ def update_state(self, session: Session = None) -> List[TI]: ): self.log.info('Marking run %s successful', self) self.set_state(State.SUCCESS) - dag.handle_callback(self, success=True, reason='success', session=session) + if execute_callbacks: + dag.handle_callback(self, success=True, reason='success', session=session) + else: + callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=False, + msg='success' + ) # if *all tasks* are deadlocked, the run failed elif (unfinished_tasks and none_depends_on_past and none_task_concurrency and not are_runnable_tasks): self.log.error('Deadlock; marking run %s failed', self) self.set_state(State.FAILED) - dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', - session=session) + if execute_callbacks: + dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session) + else: + callback = callback_requests.DagCallbackRequest( + full_filepath=dag.fileloc, + dag_id=self.dag_id, + execution_date=self.execution_date, + is_failure_callback=True, + msg='all_tasks_deadlocked' + ) # finally, if the roots aren't done, the dag is still running else: @@ -372,11 +463,9 @@ def update_state(self, session: Session = None) -> List[TI]: self._emit_duration_stats_for_finished_state() - # todo: determine we want to use with_for_update to make sure to lock the run session.merge(self) - session.commit() - return ready_tis + return ready_tis, callback def _get_ready_tis( self, @@ -492,12 +581,13 @@ def verify_integrity(self, session: Session = None): session.add(ti) try: - session.commit() + session.flush() except IntegrityError as err: self.log.info(str(err)) self.log.info('Hit IntegrityError while creating the TIs for ' f'{dag.dag_id} - {self.execution_date}.') self.log.info('Doing session rollback.') + # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() @staticmethod diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 1fbc3ad0aae23..60fa9264e510b 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -26,6 +26,7 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.typing_compat import TypedDict from airflow.utils.session import provide_session +from airflow.utils.sqlalchemy import nowait, with_row_locks from airflow.utils.state import State @@ -81,17 +82,31 @@ def get_default_pool(session: Session = None): @staticmethod @provide_session - def slots_stats(session: Session = None) -> Dict[str, PoolStats]: + def slots_stats( + *, + lock_rows: bool = False, + session: Session = None, + ) -> Dict[str, PoolStats]: """ Get Pool stats (Number of Running, Queued, Open & Total tasks) + If ``lock_rows`` is True, and the database engine in use supports the ``NOWAIT`` syntax, then a + non-blocking lock will be attempted -- if the lock is not available then SQLAlchemy will throw an + OperationalError. 
+ + :param lock_rows: Should we attempt to obtain a row-level lock on all the Pool rows returns :param session: SQLAlchemy ORM Session """ from airflow.models.taskinstance import TaskInstance # Avoid circular import pools: Dict[str, PoolStats] = {} - pool_rows: Iterable[Tuple[str, int]] = session.query(Pool.pool, Pool.slots).all() + query = session.query(Pool.pool, Pool.slots) + + if lock_rows: + query = with_row_locks(query, **nowait(session)) + + pool_rows: Iterable[Tuple[str, int]] = query.all() for (pool_name, total_slots) in pool_rows: pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index 12e4c8c65cef7..e2174bfd97017 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -25,12 +25,13 @@ import sqlalchemy_jsonfield from sqlalchemy import BigInteger, Column, Index, String, and_ -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, backref, relationship from sqlalchemy.sql import exists from airflow.models.base import ID_LEN, Base from airflow.models.dag import DAG, DagModel from airflow.models.dagcode import DagCode +from airflow.models.dagrun import DagRun from airflow.serialization.serialized_objects import SerializedDAG from airflow.settings import MIN_SERIALIZED_DAG_UPDATE_INTERVAL, json from airflow.utils import timezone @@ -73,6 +74,22 @@ class SerializedDagModel(Base): Index('idx_fileloc_hash', fileloc_hash, unique=False), ) + dag_runs = relationship( + DagRun, + primaryjoin=dag_id == DagRun.dag_id, + foreign_keys=dag_id, + backref=backref('serialized_dag', uselist=False, innerjoin=True), + ) + + dag_model = relationship( + DagModel, + primaryjoin=dag_id == DagModel.dag_id, # type: ignore + foreign_keys=dag_id, + uselist=False, + innerjoin=True, + backref=backref('serialized_dag', uselist=False, innerjoin=True), + ) + def __init__(self, dag: DAG): self.dag_id = dag.dag_id self.fileloc = dag.full_filepath @@ -247,3 +264,18 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> date :type session: Session """ return session.query(cls.last_updated).filter(cls.dag_id == dag_id).scalar() + + @classmethod + @provide_session + def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str: + """ + Get the latest DAG version for a given DAG ID. 
+ + :param dag_id: DAG ID + :type dag_id: str + :param session: ORM Session + :type session: Session + :return: DAG Hash + :rtype: str + """ + return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 335b4dd4a573f..4c6f30a08cd7e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -34,7 +34,7 @@ import pendulum from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_ -from sqlalchemy.orm import reconstructor +from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList @@ -237,6 +237,14 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904 Index('ti_job_id', job_id), ) + dag_model = relationship( + "DagModel", + primaryjoin="TaskInstance.dag_id == DagModel.dag_id", + foreign_keys=dag_id, + uselist=False, + innerjoin=True, + ) + def __init__(self, task, execution_date: datetime, state: Optional[str] = None): super().__init__() self.dag_id = task.dag_id @@ -1733,12 +1741,15 @@ def xcom_push( dag_id=self.dag_id, execution_date=execution_date or self.execution_date) + @provide_session def xcom_pull( # pylint: disable=inconsistent-return-statements - self, - task_ids: Optional[Union[str, Iterable[str]]] = None, - dag_id: Optional[str] = None, - key: str = XCOM_RETURN_KEY, - include_prior_dates: bool = False) -> Any: + self, + task_ids: Optional[Union[str, Iterable[str]]] = None, + dag_id: Optional[str] = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = None + ) -> Any: """ Pull XComs that optionally meet certain criteria. @@ -1767,6 +1778,8 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements execution_date are returned. If True, XComs from previous dates are returned as well. :type include_prior_dates: bool + :param session: Sqlalchemy ORM Session + :type session: Session """ if dag_id is None: dag_id = self.dag_id @@ -1776,7 +1789,8 @@ def xcom_pull( # pylint: disable=inconsistent-return-statements key=key, dag_ids=dag_id, task_ids=task_ids, - include_prior_dates=include_prior_dates + include_prior_dates=include_prior_dates, + session=session ).with_entities(XCom.value) # Since we're only fetching the values field, and not the diff --git a/airflow/settings.py b/airflow/settings.py index 0fe1b93fa9c1d..479694d92b070 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -373,3 +373,17 @@ def initialize(): 'execute_tasks_new_python_interpreter', fallback=False, ) + +ALLOW_FUTURE_EXEC_DATES = conf.getboolean('scheduler', 'allow_trigger_in_future', fallback=False) + +# Whether or not to check each dagrun against defined SLAs +CHECK_SLAS = conf.getboolean('core', 'check_slas', fallback=True) + +# Number of times, the code should be retried in case of DB Operational Errors +# Retries are done using tenacity. Not all transactions should be retried as it can cause +# undesired state. 
+# Currently used in the following places: +# `DagFileProcessor.process_file` to retry `dagbag.sync_to_db` +MAX_DB_RETRIES = conf.getint('core', 'max_db_retries', fallback=3) + +USE_JOB_SCHEDULE = conf.getboolean('scheduler', 'use_job_schedule', fallback=True) diff --git a/airflow/stats.py b/airflow/stats.py index 5913f765534f5..3d4b0875c6276 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -30,6 +30,24 @@ log = logging.getLogger(__name__) +class TimerProtocol(Protocol): + """Type protocol for StatsLogger.timer""" + + def __enter__(self): + ... + + def __exit__(self, exc_type, exc_value, traceback): + ... + + def start(self): + """Start the timer""" + ... + + def stop(self, send=True): + """Stop, and (by default) submit the timer to statsd""" + ... + + class StatsLogger(Protocol): """This class is only used for TypeChecking (for IDEs, mypy, pylint, etc)""" @@ -49,6 +67,26 @@ def gauge(cls, stat: str, value: float, rate: int = 1, delta: bool = False) -> N def timing(cls, stat: str, dt) -> None: """Stats timing""" + @classmethod + def timer(cls, *args, **kwargs) -> TimerProtocol: + """Timer metric that can be cancelled""" + + +class DummyTimer: + """No-op timer""" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return self + + def start(self): + """Start the timer""" + + def stop(self, send=True): # pylint: disable=unused-argument + """Stop, and (by default) submit the timer to statsd""" + class DummyStatsLogger: """If no StatsLogger is configured, DummyStatsLogger is used as a fallback""" @@ -69,6 +107,11 @@ def gauge(cls, stat, value, rate=1, delta=False): def timing(cls, stat, dt): """Stats timing""" + @classmethod + def timer(cls, *args, **kwargs): + """Timer metric that can be cancelled""" + return DummyTimer() + # Only characters in the character set are considered valid # for the stat_name if stat_name_default_handler is used. 
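The Stats.timer() hook added above is what lets the scheduler time its critical section and discard the sample when the Pool-row lock could not be obtained (timer.stop(send=False)). Below is a minimal sketch of that calling pattern using a stand-in timer class rather than a real statsd client; Stats is not imported, and acquire_pool_lock, do_queueing and the printed metric name are placeholders for illustration.

import time


class SketchTimer:
    """Stand-in for the TimerProtocol above: measures a duration and only
    "submits" it when stop(send=True) is called."""

    def start(self):
        self._start = time.monotonic()
        return self

    def stop(self, send=True):
        elapsed = time.monotonic() - self._start
        if send:
            print(f"scheduler.critical_section_duration={elapsed * 1000:.1f}ms")
        # send=False: the sample is dropped, e.g. because the lock was busy


def try_critical_section(acquire_pool_lock, do_queueing) -> int:
    """Illustrative only: time the queuing work, but cancel the metric if the
    row-level lock on the Pool table is held by another scheduler."""
    timer = SketchTimer().start()
    if not acquire_pool_lock():
        timer.stop(send=False)  # don't skew the metric with lock-busy attempts
        return 0
    queued = do_queueing()
    timer.stop(send=True)
    return queued


if __name__ == "__main__":
    print(try_critical_section(lambda: True, lambda: 3))   # lock obtained, metric sent
    print(try_critical_section(lambda: False, lambda: 3))  # lock busy, metric cancelled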
@@ -171,6 +214,13 @@ def timing(self, stat, dt): return self.statsd.timing(stat, dt) return None + @validate_stat + def timer(self, stat, *args, **kwargs): + """Timer metric that can be cancelled""" + if self.allow_list_validator.test(stat): + return self.statsd.timer(stat, *args, **kwargs) + return DummyTimer() + class SafeDogStatsdLogger: """DogStatsd Logger""" @@ -211,6 +261,14 @@ def timing(self, stat, dt, tags=None): return self.dogstatsd.timing(metric=stat, value=dt, tags=tags) return None + @validate_stat + def timer(self, stat, *args, tags=None, **kwargs): + """Timer metric that can be cancelled""" + if self.allow_list_validator.test(stat): + tags = tags or [] + return self.dogstatsd.timer(stat, *args, tags=tags, **kwargs) + return DummyTimer() + class _Stats(type): instance: Optional[StatsLogger] = None diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 409d73a89844c..4ecef93ad847b 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -51,7 +51,7 @@ def _get_dep_statuses( continue prev_result = ti.xcom_pull( - task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY + task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session ) if prev_result is None: diff --git a/airflow/utils/callback_requests.py b/airflow/utils/callback_requests.py new file mode 100644 index 0000000000000..fe8017c721fb8 --- /dev/null +++ b/airflow/utils/callback_requests.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from datetime import datetime +from typing import Optional + +from airflow.models.taskinstance import SimpleTaskInstance + + +class CallbackRequest: + """ + Base Class with information about the callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param msg: Additional Message that can be used for logging + """ + + def __init__(self, full_filepath: str, msg: Optional[str] = None): + self.full_filepath = full_filepath + self.msg = msg + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.__dict__) + + +class TaskCallbackRequest(CallbackRequest): + """ + A Class with information about the success/failure TI callback to be executed. Currently, only failure + callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess. 
+ + :param full_filepath: File Path to use to run the callback + :param simple_task_instance: Simplified Task Instance representation + :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback + :param msg: Additional Message that can be used for logging to determine failure/zombie + """ + + def __init__( + self, + full_filepath: str, + simple_task_instance: SimpleTaskInstance, + is_failure_callback: Optional[bool] = True, + msg: Optional[str] = None + ): + super().__init__(full_filepath=full_filepath, msg=msg) + self.simple_task_instance = simple_task_instance + self.is_failure_callback = is_failure_callback + + +class DagCallbackRequest(CallbackRequest): + """ + A Class with information about the success/failure DAG callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param dag_id: DAG ID + :param execution_date: Execution Date for the DagRun + :param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback + :param msg: Additional Message that can be used for logging + """ + + def __init__( + self, + full_filepath: str, + dag_id: str, + execution_date: datetime, + is_failure_callback: Optional[bool] = True, + msg: Optional[str] = None + ): + super().__init__(full_filepath=full_filepath, msg=msg) + self.dag_id = dag_id + self.execution_date = execution_date + self.is_failure_callback = is_failure_callback + + +class SlaCallbackRequest(CallbackRequest): + """ + A class with information about the SLA callback to be executed. + + :param full_filepath: File Path to use to run the callback + :param dag_id: DAG ID + """ + + def __init__(self, full_filepath: str, dag_id: str): + super().__init__(full_filepath) + self.dag_id = dag_id diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 9363c6b052c16..6581164a06590 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -30,7 +30,7 @@ from datetime import datetime, timedelta from importlib import import_module from multiprocessing.connection import Connection as MultiprocessingConnection -from typing import Callable, Dict, KeysView, List, NamedTuple, Optional, Tuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union, cast from setproctitle import setproctitle # pylint: disable=no-name-in-module from sqlalchemy import or_ @@ -38,14 +38,12 @@ import airflow.models from airflow.configuration import conf -from airflow.dag.base_dag import BaseDagBag -from airflow.exceptions import AirflowException from airflow.models import errors -from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance -from airflow.serialization.serialized_objects import SerializedDAG +from airflow.models.taskinstance import SimpleTaskInstance from airflow.settings import STORE_DAG_CODE, STORE_SERIALIZED_DAGS from airflow.stats import Stats from airflow.utils import timezone +from airflow.utils.callback_requests import CallbackRequest, SlaCallbackRequest, TaskCallbackRequest from airflow.utils.file import list_py_file_paths from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.mixins import MultiprocessingStartMethodMixin @@ -54,45 +52,6 @@ from airflow.utils.state import State -class SimpleDagBag(BaseDagBag): - """ - A collection of SimpleDag objects with some convenience methods. - """ - - def __init__(self, serialized_dags: List[SerializedDAG]): - """ - Constructor. 
- - :param serialized_dags: SimpleDag objects that should be in this - :type serialized_dags: list[dict] - """ - self.serialized_dags = serialized_dags - self.dag_id_to_simple_dag: Dict[str, SerializedDAG] = {} - - for serialized_dag in serialized_dags: - self.dag_id_to_simple_dag[serialized_dag.dag_id] = serialized_dag - - @property - def dag_ids(self) -> KeysView[str]: - """ - :return: IDs of all the DAGs in this - :rtype: list[unicode] - """ - return self.dag_id_to_simple_dag.keys() - - def get_dag(self, dag_id: str) -> SerializedDAG: - """ - :param dag_id: DAG ID - :type dag_id: unicode - :return: if the given DAG ID exists in the bag, return the BaseDag - corresponding to that ID. Otherwise, throw an Exception - :rtype: SerializedDAG - """ - if dag_id not in self.dag_id_to_simple_dag: - raise AirflowException("Unknown DAG ID {}".format(dag_id)) - return self.dag_id_to_simple_dag[dag_id] - - class AbstractDagFileProcessorProcess(metaclass=ABCMeta): """ Processes a DAG file. See SchedulerJob.process_file() for more details. @@ -149,12 +108,12 @@ def done(self) -> bool: @property @abstractmethod - def result(self) -> Optional[Tuple[List[dict], int]]: + def result(self) -> Optional[Tuple[int, int]]: """ A list of simple dags found, and the number of import errors :return: result of running SchedulerJob.process_file() if availlablle. Otherwise, none - :rtype: Optional[Tuple[List[dict], int]] + :rtype: Optional[Tuple[int, int]] """ raise NotImplementedError() @@ -211,14 +170,6 @@ class DagParsingSignal(enum.Enum): END_MANAGER = 'end_manager' -class FailureCallbackRequest(NamedTuple): - """A message with information about the callback to be executed.""" - - full_filepath: str - simple_task_instance: SimpleTaskInstance - msg: str - - class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): """ Agent for DAG file processing. It is responsible for all DAG parsing @@ -236,7 +187,7 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin): :type max_runs: int :param processor_factory: function that creates processors for DAG definition files. Arguments are (dag_definition_path, log_file_path) - :type processor_factory: ([str, List[FailureCallbackRequest], Optional[List[str]], bool]) -> ( + :type processor_factory: ([str, List[CallbackRequest], Optional[List[str]], bool]) -> ( AbstractDagFileProcessorProcess ) :param processor_timeout: How long to wait before timing out a DAG file processor @@ -254,7 +205,7 @@ def __init__( dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest], Optional[List[str]], bool], + [str, List[CallbackRequest], Optional[List[str]], bool], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -280,7 +231,6 @@ def __init__( self._all_files_processed = True self._parent_signal_conn: Optional[MultiprocessingConnection] = None - self._collected_dag_buffer: List = [] self._last_parsing_stat_received_at: float = time.monotonic() @@ -334,27 +284,35 @@ def run_single_parsing_loop(self) -> None: # when harvest_serialized_dags calls _heartbeat_manager. pass - def send_callback_to_execute( - self, full_filepath: str, task_instance: TaskInstance, msg: str - ) -> None: + def send_callback_to_execute(self, request: CallbackRequest) -> None: """ Sends information about the callback to be executed by DagFileProcessor. + :param request: Callback request to be executed. 
+ :type request: CallbackRequest + """ + if not self._parent_signal_conn: + raise ValueError("Process not started.") + try: + self._parent_signal_conn.send(request) + except ConnectionError: + # If this died cos of an error then we will noticed and restarted + # when harvest_serialized_dags calls _heartbeat_manager. + pass + + def send_sla_callback_request_to_execute(self, full_filepath: str, dag_id: str) -> None: + """ + Sends information about the SLA callback to be executed by DagFileProcessor. + :param full_filepath: DAG File path :type full_filepath: str - :param task_instance: Task Instance for which the callback is to be executed. - :type task_instance: airflow.models.taskinstance.TaskInstance - :param msg: Message sent in callback. - :type msg: str + :param dag_id: DAG ID + :type dag_id: str """ if not self._parent_signal_conn: raise ValueError("Process not started.") try: - request = FailureCallbackRequest( - full_filepath=full_filepath, - simple_task_instance=SimpleTaskInstance(task_instance), - msg=msg - ) + request = SlaCallbackRequest(full_filepath=full_filepath, dag_id=dag_id) self._parent_signal_conn.send(request) except ConnectionError: # If this died cos of an error then we will noticed and restarted @@ -382,7 +340,7 @@ def _run_processor_manager( dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest]], + [str, List[CallbackRequest]], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -421,11 +379,9 @@ def _run_processor_manager( processor_manager.start() - def harvest_serialized_dags(self) -> List[SerializedDAG]: + def heartbeat(self) -> None: """ - Harvest DAG parsing results from result queue and sync metadata from stat queue. - - :return: List of parsing result in SerializedDAG format. 
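With this change the agent sends typed ``CallbackRequest`` objects over the parent/child pipe instead of the old ``FailureCallbackRequest`` tuple. A minimal sketch of constructing the request types from ``airflow/utils/callback_requests.py`` and handing them to a running agent -- the ``ti`` (a ``TaskInstance``), ``execution_date`` and ``processor_agent`` objects are assumed to exist, and all paths and IDs are made up for illustration:

.. code:: python

    # Sketch only: "ti", "execution_date" and "processor_agent" (a started
    # DagFileProcessorAgent) are assumed; file paths and dag_ids are hypothetical.
    from airflow.models.taskinstance import SimpleTaskInstance
    from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest

    task_request = TaskCallbackRequest(
        full_filepath="/dags/example.py",
        simple_task_instance=SimpleTaskInstance(ti),
        is_failure_callback=True,
        msg="Detected as zombie",
    )
    dag_request = DagCallbackRequest(
        full_filepath="/dags/example.py",
        dag_id="example_dag",
        execution_date=execution_date,
        is_failure_callback=False,
    )

    # Both kinds of request go through the same send_callback_to_execute() channel;
    # SLA callbacks additionally have the convenience wrapper shown above.
    processor_agent.send_callback_to_execute(task_request)
    processor_agent.send_callback_to_execute(dag_request)
    processor_agent.send_sla_callback_request_to_execute(
        full_filepath="/dags/example.py", dag_id="example_dag"
    )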
+ Check if the DagFileProcessorManager process is alive, and process any pending messages """ if not self._parent_signal_conn: raise ValueError("Process not started.") @@ -436,20 +392,16 @@ def harvest_serialized_dags(self) -> List[SerializedDAG]: except (EOFError, ConnectionError): break self._process_message(result) - serialized_dags = self._collected_dag_buffer - self._collected_dag_buffer = [] # If it died unexpectedly restart the manager process self._heartbeat_manager() - return serialized_dags - def _process_message(self, message): self.log.debug("Received message of type %s", type(message).__name__) if isinstance(message, DagParsingStat): self._sync_metadata(message) else: - self._collected_dag_buffer.append(SerializedDAG.from_dict(message)) + raise RuntimeError(f"Unexpected message recieved of type {type(message).__name__}") def _heartbeat_manager(self): """ @@ -520,6 +472,10 @@ def end(self): if not self._process: self.log.warning('Ending without manager process.') return + # Give the Manager some time to cleanly shut down, but not too long, as + # it's better to finish sooner than wait for (non-critical) work to + # finish + self._process.join(timeout=1.0) reap_process_group(self._process.pid, logger=self.log) self._parent_signal_conn.close() @@ -557,7 +513,7 @@ def __init__(self, dag_directory: str, max_runs: int, processor_factory: Callable[ - [str, List[FailureCallbackRequest]], + [str, List[CallbackRequest]], AbstractDagFileProcessorProcess ], processor_timeout: timedelta, @@ -620,11 +576,13 @@ def __init__(self, self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval') # Mapping file name and callbacks requests - self._callback_to_execute: Dict[str, List[FailureCallbackRequest]] = defaultdict(list) + self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list) self._log = logging.getLogger('airflow.processor_manager') - self.waitables = {self._signal_conn: self._signal_conn} + self.waitables: Dict[Any, Union[MultiprocessingConnection, AbstractDagFileProcessorProcess]] = { + self._signal_conn: self._signal_conn, + } def register_exit_signals(self): """ @@ -697,7 +655,7 @@ def _run_parsing_loop(self): elif agent_signal == DagParsingSignal.AGENT_RUN_ONCE: # continue the loop to parse dags pass - elif isinstance(agent_signal, FailureCallbackRequest): + elif isinstance(agent_signal, CallbackRequest): self._add_callback_to_queue(agent_signal) else: raise ValueError(f"Invalid message {type(agent_signal)}") @@ -724,11 +682,9 @@ def _run_parsing_loop(self): if not processor: continue - serialized_dags = self._collect_results_from_processor(processor) + self._collect_results_from_processor(processor) self.waitables.pop(sentinel) self._processors.pop(processor.file_path) - for serialized_dag in serialized_dags: - self._signal_conn.send(serialized_dag) self._refresh_dag_dir() self._find_zombies() # pylint: disable=no-value-for-parameter @@ -756,9 +712,7 @@ def _run_parsing_loop(self): self.wait_until_finished() # Collect anything else that has finished, but don't kick off any more processors - serialized_dags = self.collect_results() - for serialized_dag in serialized_dags: - self._signal_conn.send(serialized_dag) + self.collect_results() self._print_stat() @@ -783,7 +737,7 @@ def _run_parsing_loop(self): else: poll_time = 0.0 - def _add_callback_to_queue(self, request: FailureCallbackRequest): + def _add_callback_to_queue(self, request: CallbackRequest): self._callback_to_execute[request.full_filepath].append(request) # Callback has a higher 
priority over DAG Run scheduling if request.full_filepath in self._file_path_queue: @@ -1039,22 +993,23 @@ def wait_until_finished(self): while not processor.done: time.sleep(0.1) - def _collect_results_from_processor(self, processor): + def _collect_results_from_processor(self, processor) -> None: self.log.debug("Processor for %s finished", processor.file_path) Stats.decr('dag_processing.processes') last_finish_time = timezone.utcnow() if processor.result is not None: - dags, count_import_errors = processor.result + num_dags, count_import_errors = processor.result else: self.log.error( "Processor for %s exited with return code %s.", processor.file_path, processor.exit_code ) - dags, count_import_errors = [], -1 + count_import_errors = -1 + num_dags = 0 stat = DagFileStat( - num_dags=len(dags), + num_dags=num_dags, import_errors=count_import_errors, last_finish_time=last_finish_time, last_duration=(last_finish_time - processor.start_time).total_seconds(), @@ -1062,26 +1017,19 @@ def _collect_results_from_processor(self, processor): ) self._file_stats[processor.file_path] = stat - return dags - - def collect_results(self): + def collect_results(self) -> None: """ Collect the result from any finished DAG processors - - :return: a list of dicts that were produced by processors that - have finished since the last time this was called - :rtype: list[dict] """ - # Collect all the DAGs that were found in the processed files - serialized_dags = [] - ready = multiprocessing.connection.wait(self.waitables.keys() - [self._signal_conn], timeout=0) for sentinel in ready: - processor = self.waitables[sentinel] + if sentinel is self._signal_conn: + continue + processor = cast(AbstractDagFileProcessorProcess, self.waitables[sentinel]) self.waitables.pop(processor.waitable_handle) self._processors.pop(processor.file_path) - serialized_dags += self._collect_results_from_processor(processor) + self._collect_results_from_processor(processor) self.log.debug("%s/%s DAG parsing processes running", len(self._processors), self._parallelism) @@ -1089,8 +1037,6 @@ def collect_results(self): self.log.debug("%s file paths queued for processing", len(self._file_path_queue)) - return serialized_dags - def start_new_processes(self): """ Start more processors if we have enough slots and files to process @@ -1196,7 +1142,7 @@ def _find_zombies(self, session): self._last_zombie_query_time = timezone.utcnow() for ti, file_loc in zombies: - request = FailureCallbackRequest( + request = TaskCallbackRequest( full_filepath=file_loc, simple_task_instance=SimpleTaskInstance(ti), msg="Detected as zombie", diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index d5f8f13a5bbbe..5f7069da7b8e7 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -23,6 +23,8 @@ import pendulum from dateutil import relativedelta +from sqlalchemy import event, nullsfirst +from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import Session from sqlalchemy.types import DateTime, Text, TypeDecorator @@ -141,3 +143,125 @@ def skip_locked(session: Session) -> Dict[str, Any]: return {'skip_locked': True} else: return {} + + +def nowait(session: Session) -> Dict[str, Any]: + """ + Return kwargs for passing to `with_for_update()` suitable for the current DB engine version. + + We do this as we document the fact that on DB engines that don't support this construct, we do not + support/recommend running HA scheduler. 
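The ``skip_locked``/``nowait`` helpers return keyword arguments meant to be unpacked directly into ``with_for_update()``. A rough sketch of the intended call pattern -- the query, the filter, and the ``session`` object are illustrative, not the scheduler's actual query:

.. code:: python

    # Sketch: "session" is an open SQLAlchemy session. On engines without NOWAIT
    # support, nowait(session) returns {} and this degrades to a plain FOR UPDATE.
    from airflow.models.dagrun import DagRun
    from airflow.utils.sqlalchemy import nowait
    from airflow.utils.state import State

    locked_runs = (
        session.query(DagRun)
        .filter(DagRun.state == State.RUNNING)
        .with_for_update(**nowait(session))
        .all()
    )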
If a user ignores this and tries anyway everything will still + work, just slightly slower in some circumstances. + + Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct + + See https://jira.mariadb.org/browse/MDEV-13115 + """ + dialect = session.bind.dialect + + if dialect.name != "mysql" or dialect.supports_for_update_of: + return {'nowait': True} + else: + return {} + + +def nulls_first(col, session: Session) -> Dict[str, Any]: + """ + Adds a nullsfirst construct to the column ordering. Currently only Postgres supports it. + In MySQL & Sqlite NULL values are considered lower than any non-NULL value, therefore, NULL values + appear first when the order is ASC (ascending) + """ + if session.bind.dialect.name == "postgresql": + return nullsfirst(col) + else: + return col + + +USE_ROW_LEVEL_LOCKING: bool = conf.getboolean('scheduler', 'use_row_level_locking', fallback=True) + + +def with_row_locks(query, **kwargs): + """ + Apply with_for_update to an SQLAlchemy query, if row level locking is in use. + + :param query: An SQLAlchemy Query object + :param **kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc) + :return: updated query + """ + if USE_ROW_LEVEL_LOCKING: + return query.with_for_update(**kwargs) + else: + return query + + +class CommitProhibitorGuard: + """ + Context manager class that powers prohibit_commit + """ + + expected_commit = False + + def __init__(self, session: Session): + self.session = session + + def _validate_commit(self, _): + if self.expected_commit: + self.expected_commit = False + return + raise RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!") + + def __enter__(self): + event.listen(self.session.bind, 'commit', self._validate_commit) + return self + + def __exit__(self, *exc_info): + event.remove(self.session.bind, 'commit', self._validate_commit) + + def commit(self): + """ + Commit the session. + + This is the required way to commit when the guard is in scope + """ + self.expected_commit = True + self.session.commit() + + +def prohibit_commit(session): + """ + Return a context manager that will disallow any commit that isn't done via the context manager. + + The aim of this is to ensure that transaction lifetime is strictly controlled which is especially + important in the core scheduler loop. Any commit on the session that is _not_ via this context manager + will result in RuntimeError + + Example usage: + + .. code:: python + + with prohibit_commit(session) as guard: + # ... do something with sesison + guard.commit() + + # This would throw an error + # session.commit() + """ + return CommitProhibitorGuard(session) + + +def is_lock_not_available_error(error: OperationalError): + """Check if the Error is about not being able to acquire lock""" + # DB specific error codes: + # Postgres: 55P03 + # MySQL: 3572, 'Statement aborted because lock(s) could not be acquired immediately and NOWAIT + # is set.' + # MySQL: 1205, 'Lock wait timeout exceeded; try restarting transaction + # (when NOWAIT isn't available) + db_err_code = getattr(error.orig, 'pgcode', None) or error.orig.args[0] + + # We could test if error.orig is an instance of + # psycopg2.errors.LockNotAvailable/_mysql_exceptions.OperationalError, but that involves + # importing it. 
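``is_lock_not_available_error`` lets callers distinguish "another scheduler already holds the lock" from genuine database failures. A hedged sketch of how it is meant to be used around a ``NOWAIT`` query -- the ``query`` and ``session`` objects are assumed to exist:

.. code:: python

    # Sketch: if another scheduler holds the row lock, treat it as "nothing to do
    # this loop" rather than as a hard failure.
    from sqlalchemy.exc import OperationalError

    from airflow.utils.sqlalchemy import is_lock_not_available_error

    try:
        rows = query.with_for_update(nowait=True).all()
    except OperationalError as e:
        session.rollback()
        if not is_lock_not_available_error(error=e):
            raise
        rows = []  # lock is busy elsewhere -- skip this iteration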
This doesn't + if db_err_code in ('55P03', 1205, 3572): + return True + return False diff --git a/airflow/www/views.py b/airflow/www/views.py index 6a7c4dcf4ac7e..7ceab347cb12a 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -1282,6 +1282,7 @@ def trigger(self, session=None): state=State.RUNNING, conf=run_conf, external_trigger=True, + dag_hash=current_app.dag_bag.dags_hash.get(dag_id, None), ) flash( diff --git a/docs/logging-monitoring/metrics.rst b/docs/logging-monitoring/metrics.rst index bdff897b654ba..2a3c9f4387d59 100644 --- a/docs/logging-monitoring/metrics.rst +++ b/docs/logging-monitoring/metrics.rst @@ -89,6 +89,9 @@ Name Description ``scheduler.tasks.starving`` Number of tasks that cannot be scheduled because of no open slot in pool ``scheduler.orphaned_tasks.cleared`` Number of Orphaned tasks cleared by the Scheduler ``scheduler.orphaned_tasks.adopted`` Number of Orphaned tasks adopted by the Scheduler +``scheduler.critical_section_busy`` Count of times a scheduler process tried to get a lock on the critical + section (needed to send tasks to the executor) and found it locked by + another process. ``sla_email_notification_failure`` Number of failed SLA miss email notification attempts ``ti.start..`` Number of started task in a given dag. Similar to _start but for task ``ti.finish...`` Number of completed task in a given dag. Similar to _end but for task @@ -124,14 +127,16 @@ Name Description Timers ------ -=========================================== ================================================= +=========================================== ================================================================= Name Description -=========================================== ================================================= +=========================================== ================================================================= ``dagrun.dependency-check.`` Milliseconds taken to check DAG dependencies ``dag...duration`` Milliseconds taken to finish a task ``dag_processing.last_duration.`` Milliseconds taken to load the given DAG file ``dagrun.duration.success.`` Milliseconds taken for a DagRun to reach success state ``dagrun.duration.failed.`` Milliseconds taken for a DagRun to reach failed state -``dagrun.schedule_delay.`` Milliseconds of delay between the scheduled DagRun - start date and the actual DagRun start date -=========================================== ================================================= +``dagrun.schedule_delay.`` Milliseconds of delay between the scheduled DagRun start date and + the actual DagRun start date +``scheduler.critical_section_duration`` Milliseconds spent in the critical section of scheduler loop -- + only a single scheduler can enter this loop at a time +=========================================== ================================================================= diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 08e798d0d961d..9315dd3ebb94b 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1102,6 +1102,7 @@ reqs resetdb resourceVersion resumable +resultset rfc ricard rideable diff --git a/tests/api/client/test_local_client.py b/tests/api/client/test_local_client.py index 87c8bb155667d..d3260a2cba016 100644 --- a/tests/api/client/test_local_client.py +++ b/tests/api/client/test_local_client.py @@ -18,7 +18,7 @@ import json import unittest -from unittest.mock import patch +from unittest.mock import ANY, patch from freezegun import freeze_time @@ -71,7 +71,8 @@ def 
test_trigger_dag(self, mock): execution_date=EXECDATE_NOFRACTIONS, state=State.RUNNING, conf=None, - external_trigger=True) + external_trigger=True, + dag_hash=ANY) mock.reset_mock() # execution date with microseconds cutoff @@ -80,7 +81,8 @@ def test_trigger_dag(self, mock): execution_date=EXECDATE_NOFRACTIONS, state=State.RUNNING, conf=None, - external_trigger=True) + external_trigger=True, + dag_hash=ANY) mock.reset_mock() # run id @@ -90,7 +92,8 @@ def test_trigger_dag(self, mock): execution_date=EXECDATE_NOFRACTIONS, state=State.RUNNING, conf=None, - external_trigger=True) + external_trigger=True, + dag_hash=ANY) mock.reset_mock() # test conf @@ -100,7 +103,8 @@ def test_trigger_dag(self, mock): execution_date=EXECDATE_NOFRACTIONS, state=State.RUNNING, conf=json.loads(conf), - external_trigger=True) + external_trigger=True, + dag_hash=ANY) mock.reset_mock() def test_delete_dag(self): diff --git a/tests/api/common/experimental/test_trigger_dag.py b/tests/api/common/experimental/test_trigger_dag.py index e7b1d8f06f969..eddcb3e0eab42 100644 --- a/tests/api/common/experimental/test_trigger_dag.py +++ b/tests/api/common/experimental/test_trigger_dag.py @@ -148,6 +148,7 @@ def test_trigger_dag_with_valid_start_date(self, dag_bag_mock): dag = DAG(dag_id, default_args={'start_date': timezone.datetime(2016, 9, 5, 10, 10, 0)}) dag_bag_mock.dags = [dag_id] dag_bag_mock.get_dag.return_value = dag + dag_bag_mock.dags_hash = {} dag_run = DagRun() triggers = _trigger_dag( @@ -174,6 +175,9 @@ def test_trigger_dag_with_conf(self, conf, expected_conf, dag_bag_mock): dag_bag_mock.dags = [dag_id] dag_bag_mock.get_dag.return_value = dag dag_run = DagRun() + + dag_bag_mock.dags_hash = {} + triggers = _trigger_dag( dag_id, dag_bag_mock, diff --git a/tests/dags/test_zip.zip b/tests/dags/test_zip.zip index 20c75a2c86729..42753d803d264 100644 Binary files a/tests/dags/test_zip.zip and b/tests/dags/test_zip.zip differ diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py index 184eec97a8539..3682e34548937 100644 --- a/tests/jobs/test_backfill_job.py +++ b/tests/jobs/test_backfill_job.py @@ -25,7 +25,7 @@ import pytest import sqlalchemy -from mock import Mock, patch +from mock import patch from parameterized import parameterized from airflow import settings @@ -35,7 +35,6 @@ TaskConcurrencyLimitReached, ) from airflow.jobs.backfill_job import BackfillJob -from airflow.jobs.scheduler_job import DagFileProcessor from airflow.models import DAG, DagBag, Pool, TaskInstance as TI from airflow.models.dagrun import DagRun from airflow.operators.dummy_operator import DummyOperator @@ -145,11 +144,12 @@ def test_trigger_controller_dag(self): target_dag = self.dagbag.get_dag('example_trigger_target_dag') target_dag.sync_to_db() - dag_file_processor = DagFileProcessor(dag_ids=[], log=Mock()) - task_instances_list = dag_file_processor._process_task_instances( - target_dag, - dag_runs=DagRun.find(dag_id='example_trigger_target_dag') - ) + # dag_file_processor = DagFileProcessor(dag_ids=[], log=Mock()) + task_instances_list = [] + # task_instances_list = dag_file_processor._process_task_instances( + # target_dag, + # dag_runs=DagRun.find(dag_id='example_trigger_target_dag') + # ) self.assertFalse(task_instances_list) job = BackfillJob( @@ -160,10 +160,11 @@ def test_trigger_controller_dag(self): ) job.run() - task_instances_list = dag_file_processor._process_task_instances( - target_dag, - dag_runs=DagRun.find(dag_id='example_trigger_target_dag') - ) + task_instances_list = [] + # task_instances_list = 
dag_file_processor._process_task_instances( + # target_dag, + # dag_runs=DagRun.find(dag_id='example_trigger_target_dag') + # ) self.assertTrue(task_instances_list) diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index cfbdbab21b244..6610b9050a1f2 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -29,9 +29,10 @@ import psutil import pytest import six -from freezegun import freeze_time from mock import MagicMock, patch from parameterized import parameterized +from sqlalchemy import func +from sqlalchemy.exc import OperationalError import airflow.example_dags import airflow.smart_sensor_dags @@ -43,12 +44,13 @@ from airflow.jobs.scheduler_job import DagFileProcessor, SchedulerJob from airflow.models import DAG, DagBag, DagModel, Pool, SlaMiss, TaskInstance, errors from airflow.models.dagrun import DagRun +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance, TaskInstanceKey from airflow.operators.bash import BashOperator from airflow.operators.dummy_operator import DummyOperator from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils import timezone -from airflow.utils.dag_processing import FailureCallbackRequest, SimpleDagBag +from airflow.utils.callback_requests import DagCallbackRequest, TaskCallbackRequest from airflow.utils.dates import days_ago from airflow.utils.file import list_py_file_paths from airflow.utils.session import create_session, provide_session @@ -57,8 +59,8 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars, env_vars from tests.test_utils.db import ( - clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_sla_miss, - set_default_pool_slots, + clear_db_dags, clear_db_errors, clear_db_jobs, clear_db_pools, clear_db_runs, clear_db_serialized_dags, + clear_db_sla_miss, set_default_pool_slots, ) from tests.test_utils.mock_executor import MockExecutor @@ -105,6 +107,7 @@ def clean_db(): clear_db_sla_miss() clear_db_errors() clear_db_jobs() + clear_db_serialized_dags() def setUp(self): self.clean_db() @@ -125,8 +128,7 @@ def create_test_dag(self, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + timed dag.clear() dag.is_subdag = False with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False + orm_dag = DagModel(dag_id=dag.dag_id, is_paused=False) session.merge(orm_dag) session.commit() return dag @@ -196,58 +198,6 @@ def test_dag_file_processor_sla_miss_callback_invalid_sla(self): dag_file_processor.manage_slas(dag=dag, session=session) sla_callback.assert_not_called() - def test_scheduler_executor_overflow(self): - """ - Test that tasks that are set back to scheduled and removed from the executor - queue in the case of an overflow. 
- """ - executor = MockExecutor(do_update=True, parallelism=3) - - with create_session() as session: - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), - include_examples=False, - include_smart_sensor=False) - dag = self.create_test_dag() - dag.clear() - dagbag.bag_dag(dag=dag, root_dag=dag) - dag = self.create_test_dag() - dag.clear() - task = DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - tis = [] - for i in range(1, 10): - ti = TaskInstance(task, DEFAULT_DATE + timedelta(days=i)) - ti.state = State.SCHEDULED - tis.append(ti) - session.merge(ti) - - # scheduler._process_dags(simple_dag_bag) - @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) - @mock.patch('airflow.jobs.scheduler_job.SchedulerJob._change_state_for_tis_without_dagrun') - def do_schedule(mock_dagbag, mock_change_state): - # Use a empty file since the above mock will return the - # expected DAGs. Also specify only a single file so that it doesn't - # try to schedule the above DAG repeatedly. - with conf_vars({('core', 'mp_start_method'): 'fork'}): - scheduler = SchedulerJob(num_runs=1, - executor=executor, - subdir=os.path.join(settings.DAGS_FOLDER, - "no_dags.py")) - scheduler.heartrate = 0 - scheduler.run() - - do_schedule() # pylint: disable=no-value-for-parameter - for ti in tis: - ti.refresh_from_db() - self.assertEqual(len(executor.queued_tasks), 0) - - successful_tasks = [ti for ti in tis if ti.state == State.SUCCESS] - scheduled_tasks = [ti for ti in tis if ti.state == State.SCHEDULED] - self.assertEqual(3, len(successful_tasks)) - self.assertEqual(6, len(scheduled_tasks)) - def test_dag_file_processor_sla_miss_callback_sent_notification(self): """ Test that the dag file processor does not call the sla_miss_callback when a @@ -420,70 +370,6 @@ def test_dag_file_processor_sla_miss_deleted_task(self): dag_file_processor = DagFileProcessor(dag_ids=[], log=mock_log) dag_file_processor.manage_slas(dag=dag, session=session) - def test_dag_file_processor_dagrun_once(self): - """ - Test if the dag file proccessor does not create multiple dagruns - if a dag is scheduled with @once and a start_date - """ - dag = DAG( - 'test_scheduler_dagrun_once', - start_date=timezone.datetime(2015, 1, 1), - schedule_interval="@once") - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - @freeze_time(timezone.datetime(2020, 1, 5)) - def test_dag_file_processor_dagrun_with_timedelta_schedule_and_catchup_false(self): - """ - Test that the dag file processor does not create multiple dagruns - if a dag is scheduled with 'timedelta' and catchup=False - """ - dag = DAG( - 'test_scheduler_dagrun_once_with_timedelta_and_catchup_false', - start_date=timezone.datetime(2015, 1, 1), - schedule_interval=timedelta(days=1), - catchup=False) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 1, 4)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - @freeze_time(timezone.datetime(2020, 5, 4)) - def test_dag_file_processor_dagrun_with_timedelta_schedule_and_catchup_true(self): - """ - Test that the dag file processor creates multiple dagruns - if a dag is 
scheduled with 'timedelta' and catchup=True - """ - dag = DAG( - 'test_scheduler_dagrun_once_with_timedelta_and_catchup_true', - start_date=timezone.datetime(2020, 5, 1), - schedule_interval=timedelta(days=1), - catchup=True) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 1)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 2)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2020, 5, 3)) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - @parameterized.expand([ [State.NONE, None, None], [State.UP_FOR_RETRY, timezone.utcnow() - datetime.timedelta(minutes=30), @@ -499,10 +385,12 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ dag = DAG( dag_id='test_scheduler_process_execute_task', start_date=DEFAULT_DATE) - dag_task1 = DummyOperator( + BashOperator( task_id='dummy', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -510,24 +398,28 @@ def test_dag_file_processor_process_task_instances(self, state, start_date, end_ dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None with create_session() as session: - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = state - ti.start_date = start_date - ti.end_date = end_date + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, 0, session) + assert count == 1 - self.assertEqual( - [(dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)], - mock_list - ) + session.refresh(ti) + assert ti.state == State.SCHEDULED @parameterized.expand([ [State.NONE, None, None], @@ -546,11 +438,13 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( dag = DAG( dag_id='test_scheduler_process_execute_task_with_task_concurrency', start_date=DEFAULT_DATE) - dag_task1 = DummyOperator( + BashOperator( task_id='dummy', task_concurrency=2, dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo Hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -558,23 +452,28 @@ def test_dag_file_processor_process_task_instances_with_task_concurrency( dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + 
assert dr is not None with create_session() as session: - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = state - ti.start_date = start_date - ti.end_date = end_date + ti = dr.get_task_instances(session=session)[0] + ti.state = state + ti.start_date = start_date + ti.end_date = end_date - ti_to_schedule = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, 0, session) + assert count == 1 - assert ti_to_schedule == [ - (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER), - ] + session.refresh(ti) + assert ti.state == State.SCHEDULED @parameterized.expand([ [State.NONE, None, None], @@ -595,14 +494,18 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, 'depends_on_past': True, }, ) - dag_task1 = DummyOperator( + BashOperator( task_id='dummy1', dag=dag, - owner='airflow') - dag_task2 = DummyOperator( + owner='airflow', + bash_command='echo hi' + ) + BashOperator( task_id='dummy2', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo hi' + ) with create_session() as session: orm_dag = DagModel(dag_id=dag.dag_id) @@ -610,10 +513,16 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + assert dr is not None with create_session() as session: tis = dr.get_task_instances(session=session) @@ -622,274 +531,56 @@ def test_dag_file_processor_process_task_instances_depends_on_past(self, state, ti.start_date = start_date ti.end_date = end_date - ti_to_schedule = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) - - assert sorted(ti_to_schedule) == [ - (dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER), - (dag.dag_id, dag_task2.task_id, DEFAULT_DATE, TRY_NUMBER), - ] - - def test_dag_file_processor_do_not_schedule_removed_task(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - dr = DagRun.find(run_id=dr.run_id)[0] - # Re-create the DAG, but remove the task - dag = DAG( - dag_id='test_scheduler_do_not_schedule_removed_task', - start_date=DEFAULT_DATE) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) + count = scheduler._schedule_dag_run(dr, 0, session) + assert count == 2 - self.assertEqual([], mock_list) + session.refresh(tis[0]) + session.refresh(tis[1]) + assert tis[0].state == State.SCHEDULED + assert tis[1].state == State.SCHEDULED - def test_dag_file_processor_do_not_schedule_too_early(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_too_early', - start_date=timezone.datetime(2200, 1, 1)) - 
DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[]) - self.assertEqual([], mock_list) - - def test_dag_file_processor_do_not_schedule_without_tasks(self): - dag = DAG( - dag_id='test_scheduler_do_not_schedule_without_tasks', - start_date=DEFAULT_DATE) - - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear(session=session) - dag.start_date = None - dr = dag_file_processor.create_dag_run(dag, session=session) - self.assertIsNone(dr) - - def test_dag_file_processor_do_not_run_finished(self): - dag = DAG( - dag_id='test_scheduler_do_not_run_finished', - start_date=DEFAULT_DATE) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - tis = dr.get_task_instances(session=session) - for ti in tis: - ti.state = State.SUCCESS - - session.commit() - session.close() - - mock_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr]) - - self.assertEqual([], mock_list) - - def test_dag_file_processor_add_new_task(self): + def test_scheduler_job_add_new_task(self): """ Test if a task instance will be added if the dag is updated """ - dag = DAG( - dag_id='test_scheduler_add_new_task', - start_date=DEFAULT_DATE) + dag = DAG(dag_id='test_scheduler_add_new_task', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo test') - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + dag = scheduler.dagbag.get_dag('test_scheduler_add_new_task', session=session) + scheduler._create_dag_runs([orm_dag], session) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] tis = dr.get_task_instances() self.assertEqual(len(tis), 1) - DummyOperator( - task_id='dummy2', - dag=dag, - owner='airflow') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + BashOperator(task_id='dummy2', dag=dag, owner='airflow', bash_command='echo test') + SerializedDagModel.write_dag(dag=dag) - dag_file_processor._process_task_instances(dag, 
dag_runs=[dr]) + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) + session.flush() + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] tis = dr.get_task_instances() self.assertEqual(len(tis), 2) - def test_dag_file_processor_verify_max_active_runs(self): - """ - Test if a a dagrun will not be scheduled if max_dag_runs has been reached - """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs', - start_date=DEFAULT_DATE) - dag.max_active_runs = 1 - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(dr) - - def test_dag_file_processor_fail_dagrun_timeout(self): - """ - Test if a a dagrun wil be set failed if timeout - """ - dag = DAG( - dag_id='test_scheduler_fail_dagrun_timeout', - start_date=DEFAULT_DATE) - dag.dagrun_timeout = datetime.timedelta(seconds=60) - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) - session.merge(dr) - session.commit() - - dr2 = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr2) - - dr.refresh_from_db(session=session) - self.assertEqual(dr.state, State.FAILED) - - def test_dag_file_processor_verify_max_active_runs_and_dagrun_timeout(self): - """ - Test if a a dagrun will not be scheduled if max_dag_runs - has been reached and dagrun_timeout is not reached - - Test if a a dagrun will be scheduled if max_dag_runs has - been reached but dagrun_timeout is also reached - """ - dag = DAG( - dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', - start_date=DEFAULT_DATE) - dag.max_active_runs = 1 - dag.dagrun_timeout = datetime.timedelta(seconds=60) - - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - - # Should not be scheduled as DagRun has not timedout and max_active_runs is reached - new_dr = dag_file_processor.create_dag_run(dag) - self.assertIsNone(new_dr) - - # Should be scheduled as dagrun_timeout has passed - dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) - session.merge(dr) - session.commit() - new_dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(new_dr) - - def test_dag_file_processor_max_active_runs_respected_after_clear(self): + def test_runs_respected_after_clear(self): """ Test if _process_task_instances only schedules ti's up to max_active_runs (related to issue AIRFLOW-137) @@ 
-899,10 +590,12 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): start_date=DEFAULT_DATE) dag.max_active_runs = 3 - dag_task1 = DummyOperator( + BashOperator( task_id='dummy', dag=dag, - owner='airflow') + owner='airflow', + bash_command='echo Hi' + ) session = settings.Session() orm_dag = DagModel(dag_id=dag.dag_id) @@ -911,15 +604,33 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): session.close() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + scheduler.dagbag.bag_dag(dag, root_dag=dag) dag.clear() + date = DEFAULT_DATE + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + # First create up to 3 dagruns in RUNNING state. - dr1 = dag_file_processor.create_dag_run(dag) assert dr1 is not None - dr2 = dag_file_processor.create_dag_run(dag) assert dr2 is not None - dr3 = dag_file_processor.create_dag_run(dag) assert dr3 is not None assert len(DagRun.find(dag_id=dag.dag_id, state=State.RUNNING, session=session)) == 3 @@ -928,171 +639,13 @@ def test_dag_file_processor_max_active_runs_respected_after_clear(self): # and schedule them in, so we can check how many # tasks are put on the task_instances_list (should be one, not 3) - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=[dr1, dr2, dr3]) - - self.assertEqual([(dag.dag_id, dag_task1.task_id, DEFAULT_DATE, TRY_NUMBER)], task_instances_list) - - def test_find_dags_to_run_includes_subdags(self): - dag = self.dagbag.get_dag('test_subdag_operator') - self.assertGreater(len(dag.subdags), 0) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dags = dag_file_processor._find_dags_to_process(self.dagbag.dags.values()) - - self.assertIn(dag, dags) - for subdag in dag.subdags: - self.assertIn(subdag, dags) - - def test_dag_catchup_option(self): - """ - Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date - """ - - def setup_dag(dag_id, schedule_interval, start_date, catchup): - default_args = { - 'owner': 'airflow', - 'depends_on_past': False, - 'start_date': start_date - } - dag = DAG(dag_id, - schedule_interval=schedule_interval, - max_active_runs=1, - catchup=catchup, - default_args=default_args) - - op1 = DummyOperator(task_id='t1', dag=dag) - op2 = DummyOperator(task_id='t2', dag=dag) - op2.set_upstream(op1) - op3 = DummyOperator(task_id='t3', dag=dag) - op3.set_upstream(op2) - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - session.close() - - return SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - now = timezone.utcnow() - six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( - minute=0, second=0, microsecond=0) - half_an_hour_ago = now - datetime.timedelta(minutes=30) - two_hours_ago = now - datetime.timedelta(hours=2) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - - dag1 = setup_dag(dag_id='dag_with_catchup', - schedule_interval='* * * * *', - 
start_date=six_hours_ago_to_the_hour, - catchup=True) - default_catchup = conf.getboolean('scheduler', 'catchup_by_default') - self.assertEqual(default_catchup, True) - self.assertEqual(dag1.catchup, True) - - dag2 = setup_dag(dag_id='dag_without_catchup_ten_minute', - schedule_interval='*/10 * * * *', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag2) - # We had better get a dag run - self.assertIsNotNone(dr) - # The DR should be scheduled in the last half an hour, not 6 hours ago - self.assertGreater(dr.execution_date, half_an_hour_ago) - # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, timezone.utcnow()) - - dag3 = setup_dag(dag_id='dag_without_catchup_hourly', - schedule_interval='@hourly', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag3) - # We had better get a dag run - self.assertIsNotNone(dr) - # The DR should be scheduled in the last 2 hours, not 6 hours ago - self.assertGreater(dr.execution_date, two_hours_ago) - # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, timezone.utcnow()) - - dag4 = setup_dag(dag_id='dag_without_catchup_once', - schedule_interval='@once', - start_date=six_hours_ago_to_the_hour, - catchup=False) - dr = dag_file_processor.create_dag_run(dag4) - self.assertIsNotNone(dr) - - def test_dag_file_processor_auto_align(self): - """ - Test if the schedule_interval will be auto aligned with the start_date - such that if the start_date coincides with the schedule the first - execution_date will be start_date, otherwise it will be start_date + - interval. - """ - dag = DAG( - dag_id='test_scheduler_auto_align_1', - start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="4 5 * * *" - ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2016, 1, 2, 5, 4)) - - dag = DAG( - dag_id='test_scheduler_auto_align_2', - start_date=timezone.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="10 10 * * *" - ) - DummyOperator( - task_id='dummy', - dag=dag, - owner='airflow') - - session = settings.Session() - orm_dag = DagModel(dag_id=dag.dag_id) - session.merge(orm_dag) - session.commit() - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - - dag.clear() - - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, timezone.datetime(2016, 1, 1, 10, 10)) - - def test_process_dags_not_create_dagrun_for_subdags(self): - dag = self.dagbag.get_dag('test_subdag_operator') - - scheduler = DagFileProcessor(dag_ids=[dag.dag_id], log=mock.MagicMock()) - scheduler._process_task_instances = mock.MagicMock() - scheduler.manage_slas = mock.MagicMock() - - scheduler._process_dags([dag] + dag.subdags) - with create_session() as session: - sub_dagruns = ( - session.query(DagRun).filter(DagRun.dag_id == dag.subdags[0].dag_id).count() - ) - - self.assertEqual(0, sub_dagruns) - - parent_dagruns = ( - session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).count() - ) - - self.assertGreater(parent_dagruns, 0) + num_scheduled = scheduler._schedule_dag_run(dr1, 0, 
session) + assert num_scheduled == 1 + num_scheduled = scheduler._schedule_dag_run(dr2, 1, session) + assert num_scheduled == 0 + num_scheduled = scheduler._schedule_dag_run(dr3, 1, session) + assert num_scheduled == 0 @patch.object(TaskInstance, 'handle_failure') def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): @@ -1109,13 +662,13 @@ def test_execute_on_failure_callbacks(self, mock_ti_handle_failure): session.commit() requests = [ - FailureCallbackRequest( + TaskCallbackRequest( full_filepath="A", simple_task_instance=SimpleTaskInstance(ti), msg="Message" ) ] - dag_file_processor.execute_on_failure_callbacks(dagbag, requests) + dag_file_processor.execute_callbacks(dagbag, requests) mock_ti_handle_failure.assert_called_once_with( "Message", conf.getboolean('core', 'unit_test_mode'), @@ -1139,7 +692,7 @@ def test_process_file_should_failure_callback(self): session.commit() requests = [ - FailureCallbackRequest( + TaskCallbackRequest( full_filepath=dag.full_filepath, simple_task_instance=SimpleTaskInstance(ti), msg="Message" @@ -1154,56 +707,52 @@ def test_process_file_should_failure_callback(self): self.assertEqual("Callback fired", content) os.remove(callback_file.name) - def test_should_parse_only_unpaused_dags(self): - dag_file = os.path.join( - os.path.dirname(os.path.realpath(__file__)), '../dags/test_multiple_dags.py' - ) + @mock.patch("airflow.jobs.scheduler_job.DagBag") + def test_process_file_should_retry_sync_to_db(self, mock_dagbag): + """Test that dagbag.sync_to_db is retried on OperationalError""" dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dagbag = DagBag(dag_folder=dag_file, include_examples=False) - dagbag.sync_to_db() - with create_session() as session: - session.query(TaskInstance).delete() - ( - session.query(DagModel) - .filter(DagModel.dag_id == "test_multiple_dags__dag_1") - .update({DagModel.is_paused: True}, synchronize_session=False) - ) - serialized_dags, import_errors_count = dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] - ) + mock_dagbag.return_value.dags = {'example_dag': mock.ANY} + op_error = OperationalError(statement=mock.ANY, params=mock.ANY, orig=mock.ANY) - dags = [SerializedDAG.from_dict(serialized_dag) for serialized_dag in serialized_dags] + # Mock error for the first 2 tries and a successful third try + side_effect = [op_error, op_error, mock.ANY] - with create_session() as session: - tis = session.query(TaskInstance).all() + mock_sync_to_db = mock.Mock(side_effect=side_effect) + mock_dagbag.return_value.sync_to_db = mock_sync_to_db - self.assertEqual(0, import_errors_count) - self.assertEqual(['test_multiple_dags__dag_2'], [dag.dag_id for dag in dags]) - self.assertEqual({'test_multiple_dags__dag_2'}, {ti.dag_id for ti in tis}) + dag_file_processor.process_file("/dev/null", callback_requests=mock.MagicMock()) + mock_sync_to_db.assert_has_calls([mock.call(), mock.call(), mock.call()]) def test_should_mark_dummy_task_as_success(self): dag_file = os.path.join( os.path.dirname(os.path.realpath(__file__)), '../dags/test_only_dummy_tasks.py' ) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - with create_session() as session: - session.query(TaskInstance).delete() - session.query(DagModel).delete() - dagbag = DagBag(dag_folder=dag_file, include_examples=False) - dagbag.sync_to_db() + # Write DAGs to dag and serialized_dag table + with mock.patch("airflow.models.dagbag.settings.STORE_SERIALIZED_DAGS", return_value=True): + dagbag = 
DagBag(dag_folder=dag_file, include_examples=False) + dagbag.sync_to_db() - serialized_dags, import_errors_count = dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] - ) + scheduler_job = SchedulerJob() + scheduler_job.processor_agent = mock.MagicMock() + dag = scheduler_job.dagbag.get_dag("test_only_dummy_tasks") - dags = [SerializedDAG.from_dict(serialized_dag) for serialized_dag in serialized_dags] + # Create DagRun + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + scheduler_job._create_dag_runs([orm_dag], session) + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Schedule TaskInstances + scheduler_job._schedule_dag_run(dr, 0, session) with create_session() as session: tis = session.query(TaskInstance).all() - self.assertEqual(0, import_errors_count) + dags = scheduler_job.dagbag.dags.values() self.assertEqual(['test_only_dummy_tasks'], [dag.dag_id for dag in dags]) self.assertEqual(5, len(tis)) self.assertEqual({ @@ -1224,9 +773,7 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(end_date) self.assertIsNone(duration) - dag_file_processor.process_file( - file_path=dag_file, failure_callback_requests=[] - ) + scheduler_job._schedule_dag_run(dr, 0, session) with create_session() as session: tis = session.query(TaskInstance).all() @@ -1250,128 +797,6 @@ def test_should_mark_dummy_task_as_success(self): self.assertIsNone(duration) -@pytest.mark.heisentests -class TestDagFileProcessorQueriesCount(unittest.TestCase): - """ - These tests are designed to detect changes in the number of queries for different DAG files. - - Each test has saved queries count in the table/spreadsheets. If you make a change that affected the number - of queries, please update the tables. - - These tests allow easy detection when a change is made that affects the performance of the - DagFileProcessor. 
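The retry behaviour exercised by ``test_process_file_should_retry_sync_to_db`` above amounts to a small retry loop around ``dagbag.sync_to_db()``. A rough sketch, not the actual implementation -- the retry count and the ``dagbag`` object are assumed:

.. code:: python

    # Sketch of "retry sync_to_db on transient DB errors"; the real logic lives in
    # DagFileProcessor.process_file and may differ in detail.
    from sqlalchemy.exc import OperationalError

    max_retries = 3  # assumed value
    for attempt in range(1, max_retries + 1):
        try:
            dagbag.sync_to_db()
            break
        except OperationalError:
            if attempt == max_retries:
                raise
            # transient operational error -- try again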
- """ - - def setUp(self) -> None: - clear_db_runs() - clear_db_pools() - clear_db_dags() - clear_db_sla_miss() - clear_db_errors() - - @parameterized.expand( - [ - # pylint: disable=bad-whitespace - # expected, dag_count, task_count, start_ago, schedule_interval, shape - # One DAG with one task per DAG file - ([ 1, 1, 1, 1], 1, 1, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 1, 1, "1d", "None", "linear"), # noqa - ([ 9, 5, 5, 5], 1, 1, "1d", "@once", "no_structure"), # noqa - ([ 9, 5, 5, 5], 1, 1, "1d", "@once", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "no_structure"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "binary_tree"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "star"), # noqa - ([ 9, 12, 15, 18], 1, 1, "1d", "30m", "grid"), # noqa - # One DAG with five tasks per DAG file - ([ 1, 1, 1, 1], 1, 5, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 1, 5, "1d", "None", "linear"), # noqa - ([ 9, 5, 5, 5], 1, 5, "1d", "@once", "no_structure"), # noqa - ([10, 6, 6, 6], 1, 5, "1d", "@once", "linear"), # noqa - ([ 9, 12, 15, 18], 1, 5, "1d", "30m", "no_structure"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "linear"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "binary_tree"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "star"), # noqa - ([10, 14, 18, 22], 1, 5, "1d", "30m", "grid"), # noqa - # 10 DAGs with 10 tasks per DAG file - ([ 1, 1, 1, 1], 10, 10, "1d", "None", "no_structure"), # noqa - ([ 1, 1, 1, 1], 10, 10, "1d", "None", "linear"), # noqa - ([81, 41, 41, 41], 10, 10, "1d", "@once", "no_structure"), # noqa - ([91, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa - ([81, 111, 111, 111], 10, 10, "1d", "30m", "no_structure"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "linear"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "binary_tree"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "star"), # noqa - ([91, 131, 131, 131], 10, 10, "1d", "30m", "grid"), # noqa - # pylint: enable=bad-whitespace - ] - ) - def test_process_dags_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape - ): - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": start_ago, - "PERF_SCHEDULE_INTERVAL": schedule_interval, - "PERF_SHAPE": shape, - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True', - }): - dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, - include_examples=False, - include_smart_sensor=False) - processor = DagFileProcessor([], mock.MagicMock()) - for expected_query_count in expected_query_counts: - with assert_queries_count(expected_query_count): - processor._process_dags(dagbag.dags.values()) - - @parameterized.expand( - [ - # pylint: disable=bad-whitespace - # expected, dag_count, task_count, start_ago, schedule_interval, shape - # One DAG with two tasks per DAG file - ([ 5, 5, 5, 5], 1, 1, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 1, 1, "1d", "None", "linear"), # noqa - ([15, 9, 9, 9], 1, 1, "1d", "@once", "no_structure"), # noqa - ([15, 9, 9, 9], 1, 1, "1d", "@once", "linear"), # noqa - ([15, 18, 21, 24], 1, 1, "1d", "30m", "no_structure"), # noqa - ([15, 18, 21, 24], 1, 1, "1d", "30m", "linear"), # noqa - # One DAG with five tasks per DAG file - ([ 5, 5, 5, 5], 1, 5, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 1, 5, "1d", "None", "linear"), # noqa - ([15, 9, 9, 9], 1, 5, "1d", "@once", 
"no_structure"), # noqa - ([16, 10, 10, 10], 1, 5, "1d", "@once", "linear"), # noqa - ([15, 18, 21, 24], 1, 5, "1d", "30m", "no_structure"), # noqa - ([16, 20, 24, 28], 1, 5, "1d", "30m", "linear"), # noqa - # 10 DAGs with 10 tasks per DAG file - ([ 5, 5, 5, 5], 10, 10, "1d", "None", "no_structure"), # noqa - ([ 5, 5, 5, 5], 10, 10, "1d", "None", "linear"), # noqa - ([87, 45, 45, 45], 10, 10, "1d", "@once", "no_structure"), # noqa - ([97, 55, 55, 55], 10, 10, "1d", "@once", "linear"), # noqa - ([87, 117, 117, 117], 10, 10, "1d", "30m", "no_structure"), # noqa - ([97, 137, 137, 137], 10, 10, "1d", "30m", "linear"), # noqa - # pylint: enable=bad-whitespace - ] - ) - def test_process_file_queries_count( - self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape - ): - with mock.patch.dict("os.environ", { - "PERF_DAGS_COUNT": str(dag_count), - "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": start_ago, - "PERF_SCHEDULE_INTERVAL": schedule_interval, - "PERF_SHAPE": shape, - }), conf_vars({ - ('scheduler', 'use_job_schedule'): 'True' - }): - processor = DagFileProcessor([], mock.MagicMock()) - for expected_query_count in expected_query_counts: - with assert_queries_count(expected_query_count): - processor.process_file(ELASTIC_DAG_FILE, []) - - @pytest.mark.usefixtures("disable_load_example") class TestSchedulerJob(unittest.TestCase): @@ -1425,15 +850,11 @@ def run_single_scheduler_loop_with_no_dags(self, dags_folder): """ scheduler = SchedulerJob( executor=self.null_exec, - dag_id='this_dag_doesnt_exist', # We don't want to actually run anything - num_runs=1, + num_times_parse_dags=1, subdir=os.path.join(dags_folder)) scheduler.heartrate = 0 scheduler.run() - def _make_simple_dag_bag(self, dags): - return SimpleDagBag([SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) for dag in dags]) - def test_no_orphan_process_will_be_left(self): empty_dir = mkdtemp() current_process = psutil.Process() @@ -1449,8 +870,9 @@ def test_no_orphan_process_will_be_left(self): old_children) self.assertFalse(current_children) + @mock.patch('airflow.jobs.scheduler_job.TaskCallbackRequest') @mock.patch('airflow.jobs.scheduler_job.Stats.incr') - def test_process_executor_events(self, mock_stats_incr): + def test_process_executor_events(self, mock_stats_incr, mock_task_callback): dag_id = "test_process_executor_events" dag_id2 = "test_process_executor_events_2" task_id_1 = 'dummy_task' @@ -1461,41 +883,36 @@ def test_process_executor_events(self, mock_stats_incr): DummyOperator(dag=dag2, task_id=task_id_1) dag.fileloc = "/test_path1/" dag2.fileloc = "/test_path1/" - dagbag1 = self._make_simple_dag_bag([dag]) - dagbag2 = self._make_simple_dag_bag([dag2]) - scheduler = SchedulerJob() + executor = MockExecutor(do_update=False) + task_callback = mock.MagicMock() + mock_task_callback.return_value = task_callback + scheduler = SchedulerJob(executor=executor) + scheduler.processor_agent = mock.MagicMock() + session = settings.Session() + dag.sync_to_db(session=session) + dag2.sync_to_db(session=session) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.QUEUED session.merge(ti1) session.commit() - executor = MockExecutor(do_update=False) executor.event_buffer[ti1.key] = State.FAILED, None - scheduler.executor = executor - - scheduler.processor_agent = mock.MagicMock() - # dag bag does not contain dag_id - scheduler._process_executor_events(simple_dag_bag=dagbag2) - ti1.refresh_from_db() - self.assertEqual(ti1.state, State.QUEUED) - 
scheduler.processor_agent.send_callback_to_execute.assert_not_called() - - # dag bag does contain dag_id - scheduler._process_executor_events(simple_dag_bag=dagbag1) + scheduler._process_executor_events(session=session) ti1.refresh_from_db() self.assertEqual(ti1.state, State.QUEUED) - scheduler.processor_agent.send_callback_to_execute.assert_called_once_with( + mock_task_callback.assert_called_once_with( full_filepath='/test_path1/', - task_instance=mock.ANY, + simple_task_instance=mock.ANY, msg='Executor reports task instance ' ' ' 'finished (failed) although the task says its queued. (Info: None) ' 'Was the task killed externally?' ) + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(task_callback) scheduler.processor_agent.reset_mock() # ti in success state @@ -1504,7 +921,7 @@ def test_process_executor_events(self, mock_stats_incr): session.commit() executor.event_buffer[ti1.key] = State.SUCCESS, None - scheduler._process_executor_events(simple_dag_bag=dagbag1) + scheduler._process_executor_events(session=session) ti1.refresh_from_db() self.assertEqual(ti1.state, State.SUCCESS) scheduler.processor_agent.send_callback_to_execute.assert_not_called() @@ -1517,16 +934,13 @@ def test_process_executor_events_uses_inmemory_try_number(self): task_id = "task_id" try_number = 42 - scheduler = SchedulerJob() executor = MagicMock() + scheduler = SchedulerJob(executor=executor) + scheduler.processor_agent = MagicMock() event_buffer = { TaskInstanceKey(dag_id, task_id, execution_date, try_number): (State.SUCCESS, None) } executor.get_event_buffer.return_value = event_buffer - scheduler.executor = executor - - processor_agent = MagicMock() - scheduler.processor_agent = processor_agent dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task = DummyOperator(dag=dag, task_id=task_id) @@ -1536,7 +950,7 @@ def test_process_executor_events_uses_inmemory_try_number(self): ti.state = State.SUCCESS session.merge(ti) - scheduler._process_executor_events(simple_dag_bag=MagicMock()) + scheduler._process_executor_events() # Assert that the even_buffer is empty so the task was popped using right # task instance key self.assertEqual(event_buffer, {}) @@ -1548,27 +962,33 @@ def test_execute_task_instances_is_paused_wont_execute(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) + dagmodel = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + dr1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED - dr1.state = State.RUNNING - dagmodel = DagModel() - dagmodel.dag_id = dag_id - dagmodel.is_paused = True session.merge(ti1) session.merge(dr1) session.add(dagmodel) - session.commit() + session.flush() - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.SCHEDULED, ti1.state) + session.rollback() def test_execute_task_instances_no_dagrun_task_will_execute(self): """ @@ -1581,22 +1001,32 @@ def test_execute_task_instances_no_dagrun_task_will_execute(self): 
task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, DEFAULT_DATE) ti1.state = State.SCHEDULED ti1.execution_date = ti1.execution_date + datetime.timedelta(days=1) session.merge(ti1) - session.commit() + session.flush() - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.QUEUED, ti1.state) + session.rollback() def test_execute_task_instances_backfill_tasks_wont_execute(self): """ @@ -1608,26 +1038,36 @@ def test_execute_task_instances_backfill_tasks_wont_execute(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr1.run_type = DagRunType.BACKFILL_JOB.value + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti1.refresh_from_db() ti1.state = State.SCHEDULED session.merge(ti1) session.merge(dr1) - session.commit() + session.flush() self.assertTrue(dr1.is_backfill) - scheduler._execute_task_instances(dagbag) + scheduler._critical_section_execute_task_instances(session) + session.flush() ti1.refresh_from_db() self.assertEqual(State.SCHEDULED, ti1.state) + session.rollback() def test_find_executable_task_instances_backfill_nodagrun(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_backfill_nodagrun' @@ -1635,15 +1075,27 @@ def test_find_executable_task_instances_backfill_nodagrun(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr2.run_type = DagRunType.BACKFILL_JOB.value + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.BACKFILL_JOB, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti_no_dagrun = TaskInstance(task1, DEFAULT_DATE - datetime.timedelta(days=1)) ti_backfill = TaskInstance(task1, 
dr2.execution_date) @@ -1657,16 +1109,15 @@ def test_find_executable_task_instances_backfill_nodagrun(self): session.merge(ti_no_dagrun) session.merge(ti_backfill) session.merge(ti_with_dagrun) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) res_keys = map(lambda x: x.key, res) self.assertIn(ti_no_dagrun.key, res_keys) self.assertIn(ti_with_dagrun.key, res_keys) + session.rollback() def test_find_executable_task_instances_pool(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_pool' @@ -1676,14 +1127,27 @@ def test_find_executable_task_instances_pool(self): task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a') task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) tis = ([ TaskInstance(task1, dr1.execution_date), @@ -1698,12 +1162,10 @@ def test_find_executable_task_instances_pool(self): pool2 = Pool(pool='b', slots=100, description='haha') session.add(pool) session.add(pool2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) - session.commit() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) + session.flush() self.assertEqual(3, len(res)) res_keys = [] for ti in res: @@ -1711,6 +1173,7 @@ def test_find_executable_task_instances_pool(self): self.assertIn(tis[0].key, res_keys) self.assertIn(tis[1].key, res_keys) self.assertIn(tis[3].key, res_keys) + session.rollback() def test_find_executable_task_instances_in_default_pool(self): set_default_pool_slots(1) @@ -1720,40 +1183,50 @@ def test_find_executable_task_instances_in_default_pool(self): op1 = DummyOperator(dag=dag, task_id='dummy1') op2 = DummyOperator(dag=dag, task_id='dummy2') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) executor = MockExecutor(do_update=True) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=executor) - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) + session = settings.Session() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti1 = TaskInstance(task=op1, execution_date=dr1.execution_date) ti2 = TaskInstance(task=op2, 
execution_date=dr2.execution_date) ti1.state = State.SCHEDULED ti2.state = State.SCHEDULED - session = settings.Session() session.merge(ti1) session.merge(ti2) - session.commit() + session.flush() # Two tasks w/o pool up for execution and our default pool size is 1 - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) ti2.state = State.RUNNING session.merge(ti2) - session.commit() + session.flush() # One task w/o pool up for execution and one task task running - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) + session.rollback() session.close() def test_nonexistent_pool(self): @@ -1762,24 +1235,32 @@ def test_nonexistent_pool(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task = DummyOperator(dag=dag, task_id=task_id, pool="this_pool_doesnt_exist") dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti = TaskInstance(task, dr.execution_date) ti.state = State.SCHEDULED session.merge(ti) session.commit() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) - session.commit() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) + session.flush() self.assertEqual(0, len(res)) + session.rollback() def test_find_executable_task_instances_none(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_none' @@ -1787,18 +1268,28 @@ def test_find_executable_task_instances_none(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_file_processor.create_dag_run(dag) - session.commit() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + session.flush() - self.assertEqual(0, len(scheduler._find_executable_task_instances( - dagbag, + self.assertEqual(0, len(scheduler._executable_task_instances_to_queued( + max_tis=32, session=session))) + session.rollback() def test_find_executable_task_instances_concurrency(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency' @@ -1806,15 +1297,32 @@ def test_find_executable_task_instances_concurrency(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = 
DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr2.execution_date), + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task1, dr2.execution_date) @@ -1826,11 +1334,9 @@ def test_find_executable_task_instances_concurrency(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) res_keys = map(lambda x: x.key, res) @@ -1838,13 +1344,12 @@ def test_find_executable_task_instances_concurrency(self): ti2.state = State.RUNNING session.merge(ti2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) + session.rollback() def test_find_executable_task_instances_concurrency_queued(self): dag_id = 'SchedulerJobTest.test_find_executable_task_instances_concurrency_queued' @@ -1853,12 +1358,21 @@ def test_find_executable_task_instances_concurrency_queued(self): task2 = DummyOperator(dag=dag, task_id='dummy2') task3 = DummyOperator(dag=dag, task_id='dummy3') dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dag_run = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dag_run = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dag_run.execution_date) ti2 = TaskInstance(task2, dag_run.execution_date) @@ -1871,15 +1385,15 @@ def test_find_executable_task_instances_concurrency_queued(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) self.assertEqual(res[0].key, ti3.key) + session.rollback() + # TODO: This is a hack, I think I need to just remove the setting and have it on always def test_find_executable_task_instances_task_concurrency(self): # pylint: disable=too-many-statements dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency' task_id_1 = 'dummy' @@ -1887,17 +1401,29 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, 
concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) executor = MockExecutor(do_update=True) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=executor) session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db(session=session) + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr2.execution_date), + state=State.RUNNING, + ) ti1_1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task2, dr1.execution_date) @@ -1906,11 +1432,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab ti2.state = State.SCHEDULED session.merge(ti1_1) session.merge(ti2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) @@ -1921,11 +1445,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti2) session.merge(ti1_2) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) @@ -1934,11 +1456,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab ti1_3.state = State.SCHEDULED session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(0, len(res)) @@ -1948,11 +1468,9 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(2, len(res)) @@ -1962,20 +1480,12 @@ def test_find_executable_task_instances_task_concurrency(self): # pylint: disab session.merge(ti1_1) session.merge(ti1_2) session.merge(ti1_3) - session.commit() + session.flush() - res = scheduler._find_executable_task_instances( - dagbag, - session=session) + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) self.assertEqual(1, len(res)) - - def test_change_state_for_executable_task_instances_no_tis(self): - scheduler = SchedulerJob() - session = settings.Session() - res = scheduler._change_state_for_executable_task_instances( - [], session) - self.assertEqual(0, len(res)) + session.rollback() def 
test_change_state_for_executable_task_instances_no_tis_with_state(self): dag_id = 'SchedulerJobTest.test_change_state_for__no_tis_with_state' @@ -1983,15 +1493,28 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - self._make_simple_dag_bag([dag]) scheduler = SchedulerJob() session = settings.Session() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dr1 = dag_file_processor.create_dag_run(dag) - dr2 = dag_file_processor.create_dag_run(dag) - dr3 = dag_file_processor.create_dag_run(dag) + date = DEFAULT_DATE + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) + dr3 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task1, dr2.execution_date) @@ -2003,57 +1526,47 @@ def test_change_state_for_executable_task_instances_no_tis_with_state(self): session.merge(ti2) session.merge(ti3) - session.commit() + session.flush() - res = scheduler._change_state_for_executable_task_instances( - [ti1, ti2, ti3], - session) + res = scheduler._executable_task_instances_to_queued(max_tis=100, session=session) self.assertEqual(0, len(res)) + session.rollback() + def test_enqueue_task_instances_with_queued_state(self): dag_id = 'SchedulerJobTest.test_enqueue_task_instances_with_queued_state' task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag) - + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) ti1 = TaskInstance(task1, dr1.execution_date) + ti1.dag_model = dag_model session.merge(ti1) - session.commit() + session.flush() with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: - scheduler._enqueue_task_instances_with_queued_state(dagbag, [ti1]) + scheduler._enqueue_task_instances_with_queued_state([ti1]) assert mock_queue_command.called + session.rollback() - def test_execute_task_instances_nothing(self): - dag_id = 'SchedulerJobTest.test_execute_task_instances_nothing' - task_id_1 = 'dummy' - dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) - task1 = DummyOperator(dag=dag, task_id=task_id_1) - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = SimpleDagBag([]) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - scheduler = SchedulerJob() - session = settings.Session() - - dr1 = dag_file_processor.create_dag_run(dag) - ti1 = TaskInstance(task1, dr1.execution_date) - ti1.state = State.SCHEDULED - session.merge(ti1) - session.commit() - - self.assertEqual(0, 
scheduler._execute_task_instances(dagbag)) - - def test_execute_task_instances(self): + def test_critical_section_execute_task_instances(self): dag_id = 'SchedulerJobTest.test_execute_task_instances' task_id_1 = 'dummy_task' task_id_2 = 'dummy_task_nonexistent_queue' @@ -2065,14 +1578,24 @@ def test_execute_task_instances(self): task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() # create first dag run with 1 running and 1 queued - dr1 = dag_file_processor.create_dag_run(dag) + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + ti1 = TaskInstance(task1, dr1.execution_date) ti2 = TaskInstance(task2, dr1.execution_date) ti1.refresh_from_db() @@ -2081,7 +1604,7 @@ def test_execute_task_instances(self): ti2.state = State.RUNNING session.merge(ti1) session.merge(ti2) - session.commit() + session.flush() self.assertEqual(State.RUNNING, dr1.state) self.assertEqual( @@ -2092,7 +1615,11 @@ def test_execute_task_instances(self): ) # create second dag run - dr2 = dag_file_processor.create_dag_run(dag) + dr2 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr1.execution_date), + state=State.RUNNING, + ) ti3 = TaskInstance(task1, dr2.execution_date) ti4 = TaskInstance(task2, dr2.execution_date) ti3.refresh_from_db() @@ -2102,11 +1629,11 @@ def test_execute_task_instances(self): ti4.state = State.SCHEDULED session.merge(ti3) session.merge(ti4) - session.commit() + session.flush() self.assertEqual(State.RUNNING, dr2.state) - res = scheduler._execute_task_instances(dagbag) + res = scheduler._critical_section_execute_task_instances(session) # check that concurrency is respected ti1.refresh_from_db() @@ -2136,16 +1663,26 @@ def test_execute_task_instances_limit(self): task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dagbag = self._make_simple_dag_bag([dag]) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() - scheduler.max_tis_per_query = 3 session = settings.Session() + dag_model = DagModel( + dag_id=dag_id, + is_paused=False, + concurrency=dag.concurrency, + has_task_concurrency_limits=False, + ) + session.add(dag_model) + date = dag.start_date tis = [] for _ in range(0, 4): - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + date = dag.following_schedule(date) ti1 = TaskInstance(task1, dr.execution_date) ti2 = TaskInstance(task2, dr.execution_date) tis.append(ti1) @@ -2156,10 +1693,22 @@ def test_execute_task_instances_limit(self): ti2.state = State.SCHEDULED session.merge(ti1) session.merge(ti2) - session.commit() - res = scheduler._execute_task_instances(dagbag) - - self.assertEqual(8, res) + session.flush() + scheduler.max_tis_per_query = 2 + res = scheduler._critical_section_execute_task_instances(session) + self.assertEqual(2, res) + + scheduler.max_tis_per_query = 8 + with 
mock.patch.object(type(scheduler.executor), + 'slots_available', + new_callable=mock.PropertyMock) as mock_slots: + mock_slots.return_value = 2 + # Check that we don't "overfill" the executor + self.assertEqual(2, res) + res = scheduler._critical_section_execute_task_instances(session) + + res = scheduler._critical_section_execute_task_instances(session) + self.assertEqual(4, res) for ti in tis: ti.refresh_from_db() self.assertEqual(State.QUEUED, ti.state) @@ -2171,20 +1720,15 @@ def test_change_state_for_tis_without_dagrun(self): DummyOperator(task_id='dummy', dag=dag1, owner='airflow') DummyOperator(task_id='dummy_b', dag=dag1, owner='airflow') - dag1 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag1)) dag2 = DAG(dag_id='test_change_state_for_tis_without_dagrun_dont_change', start_date=DEFAULT_DATE) DummyOperator(task_id='dummy', dag=dag2, owner='airflow') - dag2 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag2)) - dag3 = DAG(dag_id='test_change_state_for_tis_without_dagrun_no_dagrun', start_date=DEFAULT_DATE) DummyOperator(task_id='dummy', dag=dag3, owner='airflow') - dag3 = SerializedDAG.from_dict(SerializedDAG.to_dict(dag3)) - session = settings.Session() dr1 = dag1.create_dagrun(run_type=DagRunType.SCHEDULED, state=State.RUNNING, @@ -2213,10 +1757,17 @@ def test_change_state_for_tis_without_dagrun(self): session.merge(ti3) session.commit() - dagbag = self._make_simple_dag_bag([dag1, dag2, dag3]) + with mock.patch.object(settings, "STORE_SERIALIZED_DAGS", True): + dagbag = DagBag("/dev/null", include_examples=False) + dagbag.bag_dag(dag1, root_dag=dag1) + dagbag.bag_dag(dag2, root_dag=dag2) + dagbag.bag_dag(dag3, root_dag=dag3) + dagbag.sync_to_db(session) + scheduler = SchedulerJob(num_runs=0) + scheduler.dagbag.collect_dags_from_db() + scheduler._change_state_for_tis_without_dagrun( - simple_dag_bag=dagbag, old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session) @@ -2247,7 +1798,6 @@ def test_change_state_for_tis_without_dagrun(self): session.commit() scheduler._change_state_for_tis_without_dagrun( - simple_dag_bag=dagbag, old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session) @@ -2362,48 +1912,277 @@ def test_scheduler_loop_should_change_state_for_tis_without_dagrun(self, initial_task_state, expected_task_state): session = settings.Session() - dag = DAG( - 'test_execute_helper_should_change_state_for_tis_without_dagrun', - start_date=DEFAULT_DATE, - default_args={'owner': 'owner1'}) + dag_id = 'test_execute_helper_should_change_state_for_tis_without_dagrun' + dag = DAG(dag_id, start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}) with dag: op1 = DummyOperator(task_id='op1') + + # Write Dag to DB + with mock.patch.object(settings, "STORE_SERIALIZED_DAGS", True): + dagbag = DagBag(dag_folder="/dev/null", include_examples=False) + dagbag.bag_dag(dag, root_dag=dag) + dagbag.sync_to_db() + + dag = DagBag(read_dags_from_db=True, include_examples=False).get_dag(dag_id) + # Create DAG run with FAILED state + dag.clear() + dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.FAILED, + execution_date=DEFAULT_DATE + timedelta(days=1), + start_date=DEFAULT_DATE + timedelta(days=1), + session=session) + ti = dr.get_task_instance(task_id=op1.task_id, session=session) + ti.state = initial_task_state + session.commit() + + # Create scheduler and mock calls to processor. Run duration is set + # to a high value to ensure loop is entered. Poll interval is 0 to + # avoid sleep. 
Done flag is set to true to exit the loop immediately. + scheduler = SchedulerJob(num_runs=0, processor_poll_interval=0) + executor = MockExecutor(do_update=False) + executor.queued_tasks + scheduler.executor = executor + processor = mock.MagicMock() + processor.done = True + scheduler.processor_agent = processor + + scheduler._run_scheduler_loop() + + ti = dr.get_task_instance(task_id=op1.task_id, session=session) + self.assertEqual(ti.state, expected_task_state) + self.assertIsNotNone(ti.start_date) + if expected_task_state in State.finished(): + self.assertIsNotNone(ti.end_date) + self.assertEqual(ti.start_date, ti.end_date) + self.assertIsNotNone(ti.duration) + + def test_dagrun_timeout_verify_max_active_runs(self): + """ + Test that a dagrun will not be scheduled if max_active_runs + has been reached and dagrun_timeout is not reached + + Test that a dagrun would be scheduled if max_active_runs has + been reached but dagrun_timeout is also reached + """ + dag = DAG( + dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', + start_date=DEFAULT_DATE) + dag.max_active_runs = 1 + dag.dagrun_timeout = datetime.timedelta(seconds=60) + + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob() + scheduler._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Should not be able to create a new dag run, as we are at max active runs + assert orm_dag.next_dagrun_create_after is None + # But we should record the date of _what run_ it would be + assert isinstance(orm_dag.next_dagrun, datetime.datetime) + + # Should be scheduled as dagrun_timeout has passed + dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) + session.flush() + + # Mock that processor_agent is started + scheduler.processor_agent = mock.Mock() + scheduler.processor_agent.send_callback_to_execute = mock.Mock() + + scheduler._schedule_dag_run(dr, 0, session) + session.flush() + + session.refresh(dr) + assert dr.state == State.FAILED + session.refresh(orm_dag) + assert isinstance(orm_dag.next_dagrun, datetime.datetime) + assert isinstance(orm_dag.next_dagrun_create_after, datetime.datetime) + + expected_callback = DagCallbackRequest( + full_filepath=dr.dag.fileloc, + dag_id=dr.dag_id, + is_failure_callback=True, + execution_date=dr.execution_date, + msg="timed_out" + ) + + # Verify dag failure callback request is sent to file processor + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(expected_callback) + + session.rollback() + session.close() + + def test_dagrun_timeout_fails_run(self): + """ + Test that a dagrun will be set to failed on dagrun_timeout, even without max_active_runs + """ + dag = DAG( + dag_id='test_scheduler_fail_dagrun_timeout', + start_date=DEFAULT_DATE) + dag.dagrun_timeout = datetime.timedelta(seconds=60) + + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + + scheduler = SchedulerJob()
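+ # Have the scheduler create the DagRun for this DAG; the assertions below expect exactly one run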
+ scheduler._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + # Should be scheduled as dagrun_timeout has passed + dr.start_date = timezone.utcnow() - datetime.timedelta(days=1) + session.flush() + + # Mock that processor_agent is started + scheduler.processor_agent = mock.Mock() + scheduler.processor_agent.send_callback_to_execute = mock.Mock() + + scheduler._schedule_dag_run(dr, 0, session) + session.flush() + + session.refresh(dr) + assert dr.state == State.FAILED + + expected_callback = DagCallbackRequest( + full_filepath=dr.dag.fileloc, + dag_id=dr.dag_id, + is_failure_callback=True, + execution_date=dr.execution_date, + msg="timed_out" + ) + + # Verify dag failure callback request is sent to file processor + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(expected_callback) + + session.rollback() + session.close() + + @parameterized.expand([ + (State.SUCCESS, "success"), + (State.FAILED, "task_failure") + ]) + def test_dagrun_callbacks_are_called(self, state, expected_callback_msg): + """ + Test that when a DagRun reaches a success or failed state, the corresponding success/failure callback request is sent to DagFileProcessor. + Also test that the SLA callback function is called. + """ + dag = DAG( + dag_id='test_dagrun_callbacks_are_called', + start_date=DEFAULT_DATE, + on_success_callback=lambda x: print("success"), + on_failure_callback=lambda x: print("failed") + ) + + DummyOperator(task_id='dummy', dag=dag, owner='airflow') + + scheduler = SchedulerJob() + scheduler.processor_agent = mock.Mock() + scheduler.processor_agent.send_callback_to_execute = mock.Mock() + scheduler._send_sla_callbacks_to_processor = mock.Mock() + + # Sync DAG into DB + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + # Create DagRun + scheduler._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + ti = dr.get_task_instance('dummy') + ti.set_state(state, session) + + scheduler._schedule_dag_run(dr, 0, session) + + expected_callback = DagCallbackRequest( + full_filepath=dr.dag.fileloc, + dag_id=dr.dag_id, + is_failure_callback=bool(state == State.FAILED), + execution_date=dr.execution_date, + msg=expected_callback_msg + ) + + # Verify dag failure callback request is sent to file processor + scheduler.processor_agent.send_callback_to_execute.assert_called_once_with(expected_callback) + # This is already tested separately + # In this test we just want to verify that this function is called + scheduler._send_sla_callbacks_to_processor.assert_called_once_with(dag) + + session.rollback() + session.close() + + def test_do_not_schedule_removed_task(self): + dag = DAG( + dag_id='test_scheduler_do_not_schedule_removed_task', + start_date=DEFAULT_DATE) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + dag.sync_to_db(session=session) + session.flush() + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - # Create DAG run with FAILED state - dag.clear() - dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, - state=State.FAILED, - execution_date=DEFAULT_DATE, - start_date=DEFAULT_DATE, - session=session) - ti = dr.get_task_instance(task_id=op1.task_id, session=session) - ti.state = initial_task_state - session.commit() + dr = dag.create_dagrun( + 
execution_date=DEFAULT_DATE, + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + session=session, + ) + self.assertIsNotNone(dr) - # Create scheduler and mock calls to processor. Run duration is set - # to a high value to ensure loop is entered. Poll interval is 0 to - # avoid sleep. Done flag is set to true to exist the loop immediately. - scheduler = SchedulerJob(num_runs=0, processor_poll_interval=0) - executor = MockExecutor(do_update=False) - executor.queued_tasks - scheduler.executor = executor - processor = mock.MagicMock() - processor.harvest_serialized_dags.return_value = [ - SerializedDAG.from_dict(SerializedDAG.to_dict(dag))] - processor.done = True - scheduler.processor_agent = processor + # Re-create the DAG, but remove the task + dag = DAG( + dag_id='test_scheduler_do_not_schedule_removed_task', + start_date=DEFAULT_DATE) + dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - scheduler._run_scheduler_loop() + scheduler = SchedulerJob() + res = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) - ti = dr.get_task_instance(task_id=op1.task_id, session=session) - self.assertEqual(ti.state, expected_task_state) - self.assertIsNotNone(ti.start_date) - if expected_task_state in State.finished(): - self.assertIsNotNone(ti.end_date) - self.assertEqual(ti.start_date, ti.end_date) - self.assertIsNotNone(ti.duration) + self.assertEqual([], res) + session.rollback() + session.close() @provide_session def evaluate_dagrun( @@ -2423,13 +2202,20 @@ def evaluate_dagrun( if run_kwargs is None: run_kwargs = {} - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag = self.dagbag.get_dag(dag_id) - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.next_dagrun_after_date(None), + state=State.RUNNING, + ) if advance_execution_date: # run a second time to schedule a dagrun after the start_date - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr.execution_date), + state=State.RUNNING, + ) ex_date = dr.execution_date for tid, state in expected_task_states.items(): @@ -2438,6 +2224,8 @@ def evaluate_dagrun( self.null_exec.mock_task_fail(dag_id, tid, ex_date) try: + dag = DagBag().get_dag(dag.dag_id) + assert not isinstance(dag, SerializedDAG) # This needs a _REAL_ dag, not the serialized version dag.run(start_date=ex_date, end_date=ex_date, executor=self.null_exec, **run_kwargs) except AirflowException: @@ -2499,10 +2287,13 @@ def test_dagrun_root_fail_unfinished(self): """ # TODO: this should live in test_dagrun.py # Run both the failed and successful tasks - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag_id = 'test_dagrun_states_root_fail_unfinished' dag = self.dagbag.get_dag(dag_id) - dr = dag_file_processor.create_dag_run(dag) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) self.null_exec.mock_task_fail(dag_id, 'test_dagrun_fail', DEFAULT_DATE) with self.assertRaises(AirflowException): @@ -2523,10 +2314,11 @@ def test_dagrun_root_after_dagrun_unfinished(self): Noted: the DagRun state could be still in running state during CI. 
""" + clear_db_dags() dag_id = 'test_dagrun_states_root_future' dag = self.dagbag.get_dag(dag_id) + dag.sync_to_db() scheduler = SchedulerJob( - dag_id, num_runs=1, executor=self.null_exec, subdir=dag.fileloc) @@ -2582,8 +2374,12 @@ def test_scheduler_start_date(self): dag.clear() self.assertGreater(dag.start_date, datetime.datetime.now(timezone.utc)) - scheduler = SchedulerJob(dag_id, - executor=self.null_exec, + # Deactivate other dags in this file + other_dag = self.dagbag.get_dag('test_task_start_date_scheduling') + other_dag.is_paused_upon_creation = True + other_dag.sync_to_db() + + scheduler = SchedulerJob(executor=self.null_exec, subdir=dag.fileloc, num_runs=1) scheduler.run() @@ -2617,9 +2413,8 @@ def test_scheduler_start_date(self): ) session.commit() - scheduler = SchedulerJob(dag_id, + scheduler = SchedulerJob(dag.fileloc, executor=self.null_exec, - subdir=dag.fileloc, num_runs=1) scheduler.run() @@ -2634,12 +2429,21 @@ def test_scheduler_task_start_date(self): Test that the scheduler respects task start dates that are different from DAG start dates """ + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False) dag_id = 'test_task_start_date_scheduling' dag = self.dagbag.get_dag(dag_id) - dag.clear() - scheduler = SchedulerJob(dag_id, - executor=self.null_exec, - subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), + dag.is_paused_upon_creation = False + dagbag.bag_dag(dag=dag, root_dag=dag) + + # Deactivate other dags in this file so the scheduler doesn't waste time processing them + other_dag = self.dagbag.get_dag('test_start_date_scheduling') + other_dag.is_paused_upon_creation = True + dagbag.bag_dag(dag=other_dag, root_dag=other_dag) + + dagbag.sync_to_db() + + scheduler = SchedulerJob(executor=self.null_exec, + subdir=dag.fileloc, num_runs=2) scheduler.run() @@ -2661,8 +2465,7 @@ def test_scheduler_multiprocessing(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, + scheduler = SchedulerJob(executor=self.null_exec, subdir=os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py'), num_runs=1) scheduler.run() @@ -2684,8 +2487,7 @@ def test_scheduler_multiprocessing_with_spawn_method(self): dag = self.dagbag.get_dag(dag_id) dag.clear() - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, + scheduler = SchedulerJob(executor=self.null_exec, subdir=os.path.join( TEST_DAG_FOLDER, 'test_scheduler_dags.py'), num_runs=1) @@ -2706,53 +2508,47 @@ def test_scheduler_verify_pool_full(self): dag_id='test_scheduler_verify_pool_full', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', dag=dag, owner='airflow', - pool='test_scheduler_verify_pool_full') + pool='test_scheduler_verify_pool_full', + bash_command='echo hi', + ) + + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() session = settings.Session() pool = Pool(pool='test_scheduler_verify_pool_full', slots=1) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) - session.commit() + session.flush() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) + scheduler.processor_agent = mock.MagicMock() # Create 2 dagruns, which will create 2 task 
instances. - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - self.assertEqual(dr.execution_date, DEFAULT_DATE) - dr = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dr) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_pool_full") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 2) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. - session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, 0, session) + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=dag.following_schedule(dr.execution_date), + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, 0, session) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) - self.assertEqual(len(scheduler.executor.queued_tasks), 1) + self.assertEqual(len(task_instances_list), 1) def test_scheduler_verify_pool_full_2_slots_per_task(self): """ @@ -2764,52 +2560,46 @@ def test_scheduler_verify_pool_full_2_slots_per_task(self): dag_id='test_scheduler_verify_pool_full_2_slots_per_task', start_date=DEFAULT_DATE) - DummyOperator( + BashOperator( task_id='dummy', dag=dag, owner='airflow', pool='test_scheduler_verify_pool_full_2_slots_per_task', pool_slots=2, + bash_command='echo hi', ) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + session = settings.Session() pool = Pool(pool='test_scheduler_verify_pool_full_2_slots_per_task', slots=6) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) session.commit() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) + scheduler.processor_agent = mock.MagicMock() # Create 5 dagruns, which will create 5 task instances. + date = DEFAULT_DATE for _ in range(5): - dag_file_processor.create_dag_run(dag) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_pool_full_2_slots_per_task") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 5) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. 
- session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=date, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, 0, session) + date = dag.following_schedule(date) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) # As tasks require 2 slots, only 3 can fit into 6 available - self.assertEqual(len(scheduler.executor.queued_tasks), 3) + self.assertEqual(len(task_instances_list), 3) def test_scheduler_verify_priority_and_slots(self): """ @@ -2823,69 +2613,63 @@ def test_scheduler_verify_priority_and_slots(self): start_date=DEFAULT_DATE) # Medium priority, not enough slots - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t0', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=2, priority_weight=2, + bash_command='echo hi', ) # High priority, occupies first slot - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t1', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=1, priority_weight=3, + bash_command='echo hi', ) # Low priority, occupies second slot - DummyOperator( + BashOperator( task_id='test_scheduler_verify_priority_and_slots_t2', dag=dag, owner='airflow', pool='test_scheduler_verify_priority_and_slots', pool_slots=1, priority_weight=1, + bash_command='echo hi', ) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), + include_examples=False, + read_dags_from_db=True) + dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() + session = settings.Session() pool = Pool(pool='test_scheduler_verify_priority_and_slots', slots=2) session.add(pool) - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) session.commit() dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob(executor=self.null_exec) + scheduler.processor_agent = mock.MagicMock() - dag_file_processor.create_dag_run(dag) - dag_runs = DagRun.find(dag_id="test_scheduler_verify_priority_and_slots") - task_instances_list = dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) - self.assertEqual(len(task_instances_list), 3) - dagbag = self._make_simple_dag_bag([dag]) - - # Recreated part of the scheduler here, to kick off tasks -> executor - for ti_key in task_instances_list: - task = dag.get_task(ti_key[1]) - ti = TaskInstance(task, ti_key[2]) - # Task starts out in the scheduled state. All tasks in the - # scheduled state will be sent to the executor - ti.state = State.SCHEDULED - - # Also save this task instance to the DB. 
- session.merge(ti) - session.commit() + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + state=State.RUNNING, + ) + scheduler._schedule_dag_run(dr, 0, session) - self.assertEqual(len(scheduler.executor.queued_tasks), 0, "Check test pre-condition") - scheduler._execute_task_instances(dagbag, session=session) + task_instances_list = scheduler._executable_task_instances_to_queued(max_tis=32, session=session) # Only second and third - self.assertEqual(len(scheduler.executor.queued_tasks), 2) + self.assertEqual(len(task_instances_list), 2) ti0 = session.query(TaskInstance)\ .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t0').first() @@ -2899,61 +2683,118 @@ def test_scheduler_verify_priority_and_slots(self): .filter(TaskInstance.task_id == 'test_scheduler_verify_priority_and_slots_t2').first() self.assertEqual(ti2.state, State.QUEUED) - def test_scheduler_reschedule(self): - """ - Checks if tasks that are not taken up by the executor - get rescheduled - """ - executor = MockExecutor(do_update=False) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py")) - dagbag.dags.clear() + def test_verify_integrity_if_dag_not_changed(self): + # CleanUp + with create_session() as session: + session.query(SerializedDagModel).filter( + SerializedDagModel.dag_id == 'test_verify_integrity_if_dag_not_changed' + ).delete(synchronize_session=False) - dag = DAG( - dag_id='test_scheduler_reschedule', - start_date=DEFAULT_DATE) - dummy_task = BashOperator( - task_id='dummy', - dag=dag, - owner='airflow', - bash_command='echo 1', - ) + dag = DAG(dag_id='test_verify_integrity_if_dag_not_changed', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') - dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) - dag.clear() - dag.is_subdag = False + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() - with create_session() as session: - orm_dag = DagModel(dag_id=dag.dag_id) - orm_dag.is_paused = False - session.merge(orm_dag) + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None - dagbag.bag_dag(dag=dag, root_dag=dag) + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_not_changed', session=session) + scheduler._create_dag_runs([orm_dag], session) - @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) - def do_schedule(mock_dagbag): - # Use a empty file since the above mock will return the - # expected DAGs. Also specify only a single file so that it doesn't - # try to schedule the above DAG repeatedly. 
- with conf_vars({('core', 'mp_start_method'): 'fork'}): - scheduler = SchedulerJob(num_runs=1, - executor=executor, - subdir=os.path.join(settings.DAGS_FOLDER, - "no_dags.py")) - scheduler.heartrate = 0 - scheduler.run() + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] - do_schedule() # pylint: disable=no-value-for-parameter + # Verify that DagRun.verify_integrity is not called + with mock.patch('airflow.jobs.scheduler_job.DagRun.verify_integrity') as mock_verify_integrity: + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) + mock_verify_integrity.assert_not_called() + session.flush() + + assert scheduled_tis == 1 + + tis_count = session.query(func.count(TaskInstance.task_id)).filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.task_id == dr.dag.tasks[0].task_id, + TaskInstance.state == State.SCHEDULED + ).scalar() + assert tis_count == 1 + + latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == latest_dag_version + + session.rollback() + session.close() + + def test_verify_integrity_if_dag_changed(self): + # CleanUp with create_session() as session: - ti = session.query(TaskInstance).filter(TaskInstance.dag_id == dag.dag_id, - TaskInstance.task_id == dummy_task.task_id).first() - self.assertEqual(0, len(executor.queued_tasks)) - self.assertEqual(State.SCHEDULED, ti.state) + session.query(SerializedDagModel).filter( + SerializedDagModel.dag_id == 'test_verify_integrity_if_dag_changed' + ).delete(synchronize_session=False) - executor.do_update = True - do_schedule() # pylint: disable=no-value-for-parameter - self.assertEqual(0, len(executor.queued_tasks)) - ti.refresh_from_db() - self.assertEqual(State.SUCCESS, ti.state) + dag = DAG(dag_id='test_verify_integrity_if_dag_changed', start_date=DEFAULT_DATE) + BashOperator(task_id='dummy', dag=dag, owner='airflow', bash_command='echo hi') + + scheduler = SchedulerJob() + scheduler.dagbag.bag_dag(dag, root_dag=dag) + scheduler.dagbag.sync_to_db() + + session = settings.Session() + orm_dag = session.query(DagModel).get(dag.dag_id) + assert orm_dag is not None + + scheduler = SchedulerJob() + scheduler.processor_agent = mock.MagicMock() + dag = scheduler.dagbag.get_dag('test_verify_integrity_if_dag_changed', session=session) + scheduler._create_dag_runs([orm_dag], session) + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + + dag_version_1 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == dag_version_1 + assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag} + assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 1 + + # Now let's say the DAG got updated (new task got added) + BashOperator(task_id='bash_task_1', dag=dag, bash_command='echo hi') + SerializedDagModel.write_dag(dag=dag) + + dag_version_2 = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dag_version_2 != dag_version_1 + + scheduled_tis = scheduler._schedule_dag_run(dr, 0, session) + session.flush() + + assert scheduled_tis == 2 + + drs = DagRun.find(dag_id=dag.dag_id, session=session) + assert len(drs) == 1 + dr = drs[0] + assert dr.dag_hash == dag_version_2 + assert scheduler.dagbag.dags == {'test_verify_integrity_if_dag_changed': dag} + assert len(scheduler.dagbag.dags.get("test_verify_integrity_if_dag_changed").tasks) == 2 + + 
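The two `test_verify_integrity_if_dag_*` tests pin down what the new `dag_hash` column is for: `DagRun.verify_integrity` should only be re-run when the serialized DAG stored for the run no longer matches the latest `SerializedDagModel` version. A minimal sketch of that decision, with `refresh_tasks_from` as a hypothetical stand-in for the real refresh step:

```python
from types import SimpleNamespace


# Sketch only -- not the scheduler's actual code path.
def maybe_verify_integrity(dag_run, latest_hash, refresh_tasks_from):
    """Re-create missing task instances only if the DAG structure changed."""
    if dag_run.dag_hash == latest_hash:
        return False  # same serialized DAG, skip the expensive verify step
    dag_run.dag_hash = latest_hash  # remember which version we scheduled against
    refresh_tasks_from(dag_run)     # e.g. DagRun.verify_integrity() in Airflow
    return True


run = SimpleNamespace(dag_hash="abc")
assert maybe_verify_integrity(run, "abc", refresh_tasks_from=lambda dr: None) is False
assert maybe_verify_integrity(run, "xyz", refresh_tasks_from=lambda dr: None) is True
assert run.dag_hash == "xyz"
```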
tis_count = session.query(func.count(TaskInstance.task_id)).filter( + TaskInstance.dag_id == dr.dag_id, + TaskInstance.execution_date == dr.execution_date, + TaskInstance.state == State.SCHEDULED + ).scalar() + assert tis_count == 2 + + latest_dag_version = SerializedDagModel.get_latest_version_hash(dr.dag_id, session=session) + assert dr.dag_hash == latest_dag_version + + session.rollback() + session.close() def test_retry_still_in_executor(self): """ @@ -2961,7 +2802,7 @@ def test_retry_still_in_executor(self): but is still present in the executor. """ executor = MockExecutor(do_update=False) - dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py")) + dagbag = DagBag(dag_folder=os.path.join(settings.DAGS_FOLDER, "no_dags.py"), include_examples=False) dagbag.dags.clear() dag = DAG( @@ -2984,6 +2825,7 @@ def test_retry_still_in_executor(self): session.merge(orm_dag) dagbag.bag_dag(dag=dag, root_dag=dag) + dagbag.sync_to_db() @mock.patch('airflow.jobs.scheduler_job.DagBag', return_value=dagbag) def do_schedule(mock_dagbag): @@ -3003,10 +2845,6 @@ def do_schedule(mock_dagbag): TaskInstance.task_id == 'test_retry_handling_op').first() ti.task = dag_task1 - # Nothing should be left in the queued_tasks as we don't do update in MockExecutor yet, - # and the queued_tasks will be cleared by scheduler job. - self.assertEqual(0, len(executor.queued_tasks)) - def run_with_error(ti, ignore_ti_state=False): try: ti.run(ignore_ti_state=ignore_ti_state) @@ -3028,13 +2866,6 @@ def run_with_error(ti, ignore_ti_state=False): ti.state = State.SCHEDULED session.merge(ti) - # do schedule - do_schedule() # pylint: disable=no-value-for-parameter - # MockExecutor is not aware of the TaskInstance since we don't do update yet - # and no trace of this TaskInstance will be left in the executor. - self.assertFalse(executor.has_task(ti)) - self.assertEqual(ti.state, State.SCHEDULED) - # To verify that task does get re-queued. executor.do_update = True do_schedule() # pylint: disable=no-value-for-parameter @@ -3064,34 +2895,6 @@ def test_retry_handling_job(self): self.assertEqual(ti.try_number, 2) self.assertEqual(ti.state, State.UP_FOR_RETRY) - def test_dag_with_system_exit(self): - """ - Test to check that a DAG with a system.exit() doesn't break the scheduler. - """ - - dag_id = 'exit_test_dag' - dag_ids = [dag_id] - dag_directory = os.path.join(settings.DAGS_FOLDER, "..", "dags_with_system_exit") - dag_file = os.path.join(dag_directory, 'b_test_scheduler_dags.py') - - dagbag = DagBag(dag_folder=dag_file) - for dag_id in dag_ids: - dag = dagbag.get_dag(dag_id) - dag.clear() - - scheduler = SchedulerJob(dag_ids=dag_ids, - executor=self.null_exec, - subdir=dag_directory, - num_runs=1) - scheduler.run() - with create_session() as session: - tis = session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).all() - # Since this dag has no end date, and there's a chance that we'll - # start a and finish two dag parsing processes twice in one loop! 
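Several hunks here (for example `test_retry_still_in_executor` above) add a `dagbag.sync_to_db()` call before the scheduler runs, since the scheduler now reads DAGs back from the serialized store rather than from the in-memory bag. A hedged sketch of that round-trip as the tests use it; the folder path and dag id below are placeholders:

```python
from airflow.models import DagBag

# Parse DAGs from files and persist them (DagModel + SerializedDagModel rows).
dagbag = DagBag(dag_folder="/path/to/dags", include_examples=False)  # hypothetical folder
dagbag.sync_to_db()

# A scheduler-style consumer reads the serialized copies back from the DB.
db_bag = DagBag(read_dags_from_db=True)
dag = db_bag.get_dag("some_dag_id")  # deserialized DAG, or None if it was never synced
```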
- self.assertGreaterEqual( - len(tis), 1, - repr(tis)) - def test_dag_get_active_runs(self): """ Test to check that a DAG returns its active runs @@ -3128,10 +2931,13 @@ def test_dag_get_active_runs(self): session.commit() session.close() - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag1.clear() - dr = dag_file_processor.create_dag_run(dag1) + dr = dag1.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=start_date, + state=State.RUNNING, + ) # We had better get a dag run self.assertIsNotNone(dr) @@ -3171,6 +2977,12 @@ def test_add_unparseable_file_before_sched_start_creates_import_error(self): @conf_vars({("core", "dagbag_import_error_tracebacks"): "False"}) def test_add_unparseable_file_after_sched_start_creates_import_error(self): + """ + Check that new DAG files are picked up, and import errors recorded. + + This is more of an "integration" test as it checks SchedulerJob, DagFileProcessorManager and + DagFileProcessor + """ dags_folder = mkdtemp() try: unparseable_filename = os.path.join(dags_folder, TEMP_DAG_FILENAME) @@ -3178,6 +2990,7 @@ def test_add_unparseable_file_after_sched_start_creates_import_error(self): with open(unparseable_filename, 'w') as unparseable_file: unparseable_file.writelines(UNPARSEABLE_DAG_FILE_CONTENTS) + print("Second run") self.run_single_scheduler_loop_with_no_dags(dags_folder) finally: shutil.rmtree(dags_folder) @@ -3459,15 +3272,17 @@ def test_adopt_or_reset_orphaned_tasks_external_triggered_dag(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + external_trigger=True, + session=session) ti = dr1.get_task_instances(session=session)[0] - dr1.state = State.RUNNING ti.state = State.SCHEDULED - dr1.external_trigger = True session.merge(ti) session.merge(dr1) session.commit() @@ -3481,17 +3296,18 @@ def test_adopt_or_reset_orphaned_tasks_backfill_dag(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun(run_type=DagRunType.BACKFILL_JOB, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) ti = dr1.get_task_instances(session=session)[0] ti.state = State.SCHEDULED - dr1.state = State.RUNNING - dr1.run_type = DagRunType.BACKFILL_JOB.value session.merge(ti) session.merge(dr1) session.flush() @@ -3528,14 +3344,16 @@ def test_reset_orphaned_tasks_no_orphans(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) - dr1.state = State.RUNNING + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) tis = dr1.get_task_instances(session=session) tis[0].state = State.RUNNING tis[0].queued_by_job_id = scheduler.id @@ 
-3554,14 +3372,16 @@ def test_reset_orphaned_tasks_non_running_dagruns(self): task_id = dag_id + '_task' DummyOperator(task_id=task_id, dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler = SchedulerJob() session = settings.Session() session.add(scheduler) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) - dr1.state = State.SUCCESS + dr1 = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.SUCCESS, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) tis = dr1.get_task_instances(session=session) self.assertEqual(1, len(tis)) tis[0].state = State.SCHEDULED @@ -3579,7 +3399,6 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): DummyOperator(task_id='task1', dag=dag) DummyOperator(task_id='task2', dag=dag) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) scheduler_job = SchedulerJob() session = settings.Session() scheduler_job.state = State.RUNNING @@ -3592,7 +3411,14 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): session.add(old_job) session.flush() - dr1 = dag_file_processor.create_dag_run(dag, session=session) + dr1 = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=DEFAULT_DATE, + start_date=timezone.utcnow(), + state=State.RUNNING, + session=session + ) + ti1, ti2 = dr1.get_task_instances(session=session) dr1.state = State.RUNNING ti1.state = State.SCHEDULED @@ -3606,7 +3432,7 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): session.flush() num_reset_tis = scheduler_job.adopt_or_reset_orphaned_tasks(session=session) - session.flush() + self.assertEqual(1, num_reset_tis) session.refresh(ti1) @@ -3615,7 +3441,55 @@ def test_adopt_or_reset_orphaned_tasks_stale_scheduler_jobs(self): self.assertEqual(State.SCHEDULED, ti2.state) session.rollback() + def test_send_sla_callbacks_to_processor_sla_disabled(self): + """Test SLA Callbacks are not sent when check_slas is False""" + dag_id = 'test_send_sla_callbacks_to_processor_sla_disabled' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') + DummyOperator(task_id='task1', dag=dag) + + with patch.object(settings, "CHECK_SLAS", False): + scheduler_job = SchedulerJob() + mock_agent = mock.MagicMock() + + scheduler_job.processor_agent = mock_agent + + scheduler_job._send_sla_callbacks_to_processor(dag) + scheduler_job.processor_agent.send_sla_callback_request_to_execute.assert_not_called() + + def test_send_sla_callbacks_to_processor_sla_no_task_slas(self): + """Test SLA Callbacks are not sent when no task SLAs are defined""" + dag_id = 'test_send_sla_callbacks_to_processor_sla_no_task_slas' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') + DummyOperator(task_id='task1', dag=dag) + + with patch.object(settings, "CHECK_SLAS", True): + scheduler_job = SchedulerJob() + mock_agent = mock.MagicMock() + + scheduler_job.processor_agent = mock_agent + + scheduler_job._send_sla_callbacks_to_processor(dag) + scheduler_job.processor_agent.send_sla_callback_request_to_execute.assert_not_called() + + def test_send_sla_callbacks_to_processor_sla_with_task_slas(self): + """Test SLA Callbacks are sent to the DAG Processor when SLAs are defined on tasks""" + dag_id = 'test_send_sla_callbacks_to_processor_sla_with_task_slas' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, schedule_interval='@daily') + DummyOperator(task_id='task1', dag=dag, sla=timedelta(seconds=60)) + + with 
patch.object(settings, "CHECK_SLAS", True): + scheduler_job = SchedulerJob() + mock_agent = mock.MagicMock() + + scheduler_job.processor_agent = mock_agent + + scheduler_job._send_sla_callbacks_to_processor(dag) + scheduler_job.processor_agent.send_sla_callback_request_to_execute.assert_called_once_with( + full_filepath=dag.fileloc, dag_id=dag_id + ) + +@pytest.mark.xfail(reason="Work out where this goes") def test_task_with_upstream_skip_process_task_instances(): """ Test if _process_task_instances puts a task instance into SKIPPED state if any of its @@ -3632,7 +3506,7 @@ def test_task_with_upstream_skip_process_task_instances(): dummy3 = DummyOperator(task_id="dummy3") [dummy1, dummy2] >> dummy3 - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) + # dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag.clear() dr = dag.create_dagrun(run_type=DagRunType.MANUAL, state=State.RUNNING, @@ -3646,8 +3520,8 @@ def test_task_with_upstream_skip_process_task_instances(): tis[dummy2.task_id].state = State.SUCCESS assert tis[dummy3.task_id].state == State.NONE - dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') - dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) + # dag_runs = DagRun.find(dag_id='test_task_with_upstream_skip_dag') + # dag_file_processor._process_task_instances(dag, dag_runs=dag_runs) with create_session() as session: tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)} @@ -3670,17 +3544,19 @@ def setUp(self) -> None: clear_db_dags() clear_db_sla_miss() clear_db_errors() + clear_db_serialized_dags() + clear_db_dags() @parameterized.expand( [ # pylint: disable=bad-whitespace # expected, dag_count, task_count # One DAG with one task per DAG file - (13, 1, 1), # noqa + (23, 1, 1), # noqa # One DAG with five tasks per DAG file - (17, 1, 5), # noqa + (23, 1, 5), # noqa # 10 DAGs with 10 tasks per DAG file - (46, 10, 10), # noqa + (95, 10, 10), # noqa ] ) def test_execute_queries_count_with_harvested_dags(self, expected_query_count, dag_count, task_count): @@ -3693,63 +3569,99 @@ def test_execute_queries_count_with_harvested_dags(self, expected_query_count, d }), conf_vars({ ('scheduler', 'use_job_schedule'): 'True', ('core', 'load_examples'): 'False', - }): - + ('core', 'store_serialized_dags'): 'True', + }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): + dagruns = [] dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - for i, dag in enumerate(dagbag.dags.values()): - dr = dag.create_dagrun(state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}") + dagbag.sync_to_db() + + dag_ids = dagbag.dag_ids + dagbag = DagBag(read_dags_from_db=True) + for i, dag_id in enumerate(dag_ids): + dag = dagbag.get_dag(dag_id) + dr = dag.create_dagrun( + state=State.RUNNING, + run_id=f"{DagRunType.MANUAL.value}__{i}", + dag_hash=dagbag.dags_hash[dag.dag_id], + ) + dagruns.append(dr) for ti in dr.get_task_instances(): ti.set_state(state=State.SCHEDULED) mock_agent = mock.MagicMock() - mock_agent.harvest_serialized_dags.return_value = [ - SerializedDAG.from_dict(SerializedDAG.to_dict(d)) for d in dagbag.dags.values()] - job = SchedulerJob(subdir=PERF_DAGS_FOLDER) - job.executor = MockExecutor() + job = SchedulerJob(subdir=PERF_DAGS_FOLDER, num_runs=1) + job.executor = MockExecutor(do_update=False) job.heartbeat = mock.MagicMock() job.processor_agent = mock_agent with assert_queries_count(expected_query_count): - job._run_scheduler_loop() + with mock.patch.object(DagRun, 
'next_dagruns_to_examine') as mock_dagruns: + mock_dagruns.return_value = dagruns + + job._run_scheduler_loop() @parameterized.expand( [ # pylint: disable=bad-whitespace - # expected, dag_count, task_count + # expected, dag_count, task_count, start_ago, schedule_interval, shape # One DAG with one task per DAG file - (2, 1, 1), # noqa + ([10, 10, 10, 10], 1, 1, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 1, 1, "1d", "None", "linear"), # noqa + ([22, 14, 14, 14], 1, 1, "1d", "@once", "no_structure"), # noqa + ([22, 14, 14, 14], 1, 1, "1d", "@once", "linear"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "no_structure"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "linear"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "binary_tree"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "star"), # noqa + ([22, 24, 27, 30], 1, 1, "1d", "30m", "grid"), # noqa # One DAG with five tasks per DAG file - (2, 1, 5), # noqa + ([10, 10, 10, 10], 1, 5, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 1, 5, "1d", "None", "linear"), # noqa + ([22, 14, 14, 14], 1, 5, "1d", "@once", "no_structure"), # noqa + ([23, 15, 15, 15], 1, 5, "1d", "@once", "linear"), # noqa + ([22, 24, 27, 30], 1, 5, "1d", "30m", "no_structure"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "linear"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "binary_tree"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "star"), # noqa + ([23, 26, 30, 34], 1, 5, "1d", "30m", "grid"), # noqa # 10 DAGs with 10 tasks per DAG file - (2, 10, 10), # noqa + ([10, 10, 10, 10], 10, 10, "1d", "None", "no_structure"), # noqa + ([10, 10, 10, 10], 10, 10, "1d", "None", "linear"), # noqa + ([85, 38, 38, 38], 10, 10, "1d", "@once", "no_structure"), # noqa + ([95, 51, 51, 51], 10, 10, "1d", "@once", "linear"), # noqa + ([85, 99, 99, 99], 10, 10, "1d", "30m", "no_structure"), # noqa + ([95, 125, 125, 125], 10, 10, "1d", "30m", "linear"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "binary_tree"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "star"), # noqa + ([95, 119, 119, 119], 10, 10, "1d", "30m", "grid"), # noqa + # pylint: enable=bad-whitespace ] ) - def test_execute_queries_count_no_harvested_dags(self, expected_query_count, dag_count, task_count): + def test_process_dags_queries_count( + self, expected_query_counts, dag_count, task_count, start_ago, schedule_interval, shape + ): with mock.patch.dict("os.environ", { "PERF_DAGS_COUNT": str(dag_count), "PERF_TASKS_COUNT": str(task_count), - "PERF_START_AGO": "1d", - "PERF_SCHEDULE_INTERVAL": "30m", - "PERF_SHAPE": "no_structure", + "PERF_START_AGO": start_ago, + "PERF_SCHEDULE_INTERVAL": schedule_interval, + "PERF_SHAPE": shape, }), conf_vars({ ('scheduler', 'use_job_schedule'): 'True', - ('core', 'load_examples'): 'False', - }): + ('core', 'store_serialized_dags'): 'True', + }), mock.patch.object(settings, 'STORE_SERIALIZED_DAGS', True): dagbag = DagBag(dag_folder=ELASTIC_DAG_FILE, include_examples=False) - for i, dag in enumerate(dagbag.dags.values()): - dr = dag.create_dagrun(state=State.RUNNING, run_id=f"{DagRunType.MANUAL.value}__{i}") - for ti in dr.get_task_instances(): - ti.set_state(state=State.SCHEDULED) + dagbag.sync_to_db() mock_agent = mock.MagicMock() - mock_agent.harvest_serialized_dags.return_value = [] - job = SchedulerJob(subdir=PERF_DAGS_FOLDER) - job.executor = MockExecutor() + job = SchedulerJob(subdir=PERF_DAGS_FOLDER, num_runs=1) + job.executor = MockExecutor(do_update=False) job.heartbeat = mock.MagicMock() job.processor_agent = mock_agent 
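The `expected_query_counts` tables above act as query-count regression guards via the repo's `assert_queries_count` helper, whose implementation is not part of this diff. As a rough illustration only (an assumption, not the project's helper), such a counter can be built on SQLAlchemy's `after_cursor_execute` event:

```python
from contextlib import contextmanager

from sqlalchemy import create_engine, event, text


@contextmanager
def count_queries(engine):
    """Count statements executed on `engine` while the block runs (illustrative only)."""
    counter = {"n": 0}

    def _after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
        counter["n"] += 1

    event.listen(engine, "after_cursor_execute", _after_cursor_execute)
    try:
        yield counter
    finally:
        event.remove(engine, "after_cursor_execute", _after_cursor_execute)


engine = create_engine("sqlite://")
with count_queries(engine) as counted:
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))
assert counted["n"] >= 1
```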
- - with assert_queries_count(expected_query_count): - job._run_scheduler_loop() + for expected_query_count in expected_query_counts: + with create_session() as session: + with assert_queries_count(expected_query_count): + job._do_scheduling(session) diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index eac4f1ea5c688..8a11cf7b87263 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -24,6 +24,7 @@ import re import unittest from contextlib import redirect_stdout +from datetime import timedelta from tempfile import NamedTemporaryFile from typing import Optional from unittest import mock @@ -31,12 +32,12 @@ import pendulum from dateutil.relativedelta import relativedelta +from freezegun import freeze_time from parameterized import parameterized from airflow import models, settings from airflow.configuration import conf from airflow.exceptions import AirflowException, DuplicateTaskIdFound -from airflow.jobs.scheduler_job import DagFileProcessor from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI from airflow.models.baseoperator import BaseOperator from airflow.operators.bash import BashOperator @@ -53,6 +54,8 @@ from tests.test_utils.asserts import assert_queries_count from tests.test_utils.db import clear_db_dags, clear_db_runs +TEST_DATE = datetime_tz(2015, 1, 2, 0, 0) + class TestDag(unittest.TestCase): @@ -652,6 +655,25 @@ def test_following_previous_schedule_daily_dag_cet_to_cest(self): self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00") self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00") + def test_following_schedule_relativedelta(self): + """ + Tests following_schedule a dag with a relativedelta schedule_interval + """ + dag_id = "test_schedule_dag_relativedelta" + delta = relativedelta(hours=+1) + dag = DAG(dag_id=dag_id, + schedule_interval=delta) + dag.add_task(BaseOperator( + task_id="faketastic", + owner='Also fake', + start_date=TEST_DATE)) + + _next = dag.following_schedule(TEST_DATE) + self.assertEqual(_next.isoformat(), "2015-01-02T01:00:00+00:00") + + _next = dag.following_schedule(_next) + self.assertEqual(_next.isoformat(), "2015-01-02T02:00:00+00:00") + def test_dagtag_repr(self): clear_db_dags() dag = DAG('dag-test-dagtag', start_date=DEFAULT_DATE, tags=['tag-1', 'tag-2']) @@ -661,14 +683,14 @@ def test_dagtag_repr(self): {repr(t) for t in session.query(DagTag).filter( DagTag.dag_id == 'dag-test-dagtag').all()}) - def test_bulk_sync_to_db(self): + def test_bulk_write_to_db(self): clear_db_dags() dags = [ DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4) ] - with assert_queries_count(3): - DAG.bulk_sync_to_db(dags) + with assert_queries_count(5): + DAG.bulk_write_to_db(dags) with create_session() as session: self.assertEqual( {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'}, @@ -684,15 +706,15 @@ def test_bulk_sync_to_db(self): set(session.query(DagTag.dag_id, DagTag.name).all()) ) # Re-sync should do fewer queries - with assert_queries_count(2): - DAG.bulk_sync_to_db(dags) - with assert_queries_count(2): - DAG.bulk_sync_to_db(dags) + with assert_queries_count(3): + DAG.bulk_write_to_db(dags) + with assert_queries_count(3): + DAG.bulk_write_to_db(dags) # Adding tags for dag in dags: dag.tags.append("test-dag2") - with assert_queries_count(3): - DAG.bulk_sync_to_db(dags) + with assert_queries_count(4): + DAG.bulk_write_to_db(dags) with create_session() as session: self.assertEqual( {'dag-bulk-sync-0', 
'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
@@ -714,8 +736,8 @@ def test_bulk_sync_to_db(self):
         # Removing tags
         for dag in dags:
             dag.tags.remove("test-dag")
-        with assert_queries_count(3):
-            DAG.bulk_sync_to_db(dags)
+        with assert_queries_count(4):
+            DAG.bulk_write_to_db(dags)
         with create_session() as session:
             self.assertEqual(
                 {'dag-bulk-sync-0', 'dag-bulk-sync-1', 'dag-bulk-sync-2', 'dag-bulk-sync-3'},
@@ -731,8 +753,46 @@ def test_bulk_sync_to_db(self):
                 set(session.query(DagTag.dag_id, DagTag.name).all())
             )

-    @patch('airflow.models.dag.timezone.utcnow')
-    def test_sync_to_db(self, mock_now):
+    def test_bulk_write_to_db_max_active_runs(self):
+        """
+        Test that DagModel.next_dagrun_create_after is set to NULL when the next dagrun cannot be
+        created due to max active runs being hit.
+        """
+        dag = DAG(
+            dag_id='test_scheduler_verify_max_active_runs',
+            start_date=DEFAULT_DATE)
+        dag.max_active_runs = 1
+
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+            owner='airflow')
+
+        session = settings.Session()
+        dag.clear()
+        DAG.bulk_write_to_db([dag], session)
+
+        model = session.query(DagModel).get((dag.dag_id,))
+
+        period_end = dag.following_schedule(DEFAULT_DATE)
+        assert model.next_dagrun == DEFAULT_DATE
+        assert model.next_dagrun_create_after == period_end
+
+        dr = dag.create_dagrun(
+            state=State.RUNNING,
+            execution_date=model.next_dagrun,
+            run_type=DagRunType.SCHEDULED,
+            session=session,
+        )
+        assert dr is not None
+        DAG.bulk_write_to_db([dag])
+
+        model = session.query(DagModel).get((dag.dag_id,))
+        assert model.next_dagrun == period_end
+        # We signal "at max active runs" by saying this run is never eligible to be created
+        assert model.next_dagrun_create_after is None
+
+    def test_sync_to_db(self):
         dag = DAG(
             'dag',
             start_date=DEFAULT_DATE,
@@ -748,31 +808,25 @@ def test_sync_to_db(self, mock_now):
                 owner='owner2',
                 subdag=subdag
             )
-        now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
-        mock_now.return_value = now
         session = settings.Session()
         dag.sync_to_db(session=session)

         orm_dag = session.query(DagModel).filter(DagModel.dag_id == 'dag').one()
         self.assertEqual(set(orm_dag.owners.split(', ')), {'owner1', 'owner2'})
-        self.assertEqual(orm_dag.last_scheduler_run, now)
         self.assertTrue(orm_dag.is_active)
         self.assertIsNotNone(orm_dag.default_view)
         self.assertEqual(orm_dag.default_view, conf.get('webserver', 'dag_default_view').lower())
         self.assertEqual(orm_dag.safe_dag_id, 'dag')

-        orm_subdag = session.query(DagModel).filter(
-            DagModel.dag_id == 'dag.subtask').one()
+        orm_subdag = session.query(DagModel).filter(DagModel.dag_id == 'dag.subtask').one()
         self.assertEqual(set(orm_subdag.owners.split(', ')), {'owner1', 'owner2'})
-        self.assertEqual(orm_subdag.last_scheduler_run, now)
         self.assertTrue(orm_subdag.is_active)
         self.assertEqual(orm_subdag.safe_dag_id, 'dag__dot__subtask')
         self.assertEqual(orm_subdag.fileloc, orm_dag.fileloc)
         session.close()

-    @patch('airflow.models.dag.timezone.utcnow')
-    def test_sync_to_db_default_view(self, mock_now):
+    def test_sync_to_db_default_view(self):
         dag = DAG(
             'dag',
             start_date=DEFAULT_DATE,
@@ -788,8 +842,6 @@ def test_sync_to_db_default_view(self, mock_now):
                 start_date=DEFAULT_DATE,
             )
         )
-        now = datetime.datetime.utcnow().replace(tzinfo=pendulum.timezone('UTC'))
-        mock_now.return_value = now

         session = settings.Session()
         dag.sync_to_db(session=session)
@@ -1038,65 +1090,25 @@ def test_schedule_dag_no_previous_runs(self):
         dag.add_task(BaseOperator(
             task_id="faketastic",
             owner='Also fake',
-            start_date=datetime_tz(2015, 1, 2, 0,
0))) + start_date=TEST_DATE)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run) - self.assertEqual(dag.dag_id, dag_run.dag_id) - self.assertIsNotNone(dag_run.run_id) - self.assertNotEqual('', dag_run.run_id) - self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0), - dag_run.execution_date, - msg='dag_run.execution_date did not match expectation: {0}' - .format(dag_run.execution_date) + dag_run = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + execution_date=TEST_DATE, + state=State.RUNNING, ) - self.assertEqual(State.RUNNING, dag_run.state) - self.assertFalse(dag_run.external_trigger) - dag.clear() - self._clean_up(dag_id) - - def test_schedule_dag_relativedelta(self): - """ - Tests scheduling a dag with a relativedelta schedule_interval - """ - dag_id = "test_schedule_dag_relativedelta" - delta = relativedelta(hours=+1) - dag = DAG(dag_id=dag_id, - schedule_interval=delta) - dag.add_task(BaseOperator( - task_id="faketastic", - owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) self.assertIsNotNone(dag_run) self.assertEqual(dag.dag_id, dag_run.dag_id) self.assertIsNotNone(dag_run.run_id) self.assertNotEqual('', dag_run.run_id) self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0), + TEST_DATE, dag_run.execution_date, msg='dag_run.execution_date did not match expectation: {0}' .format(dag_run.execution_date) ) self.assertEqual(State.RUNNING, dag_run.state) self.assertFalse(dag_run.external_trigger) - dag_run2 = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run2) - self.assertEqual(dag.dag_id, dag_run2.dag_id) - self.assertIsNotNone(dag_run2.run_id) - self.assertNotEqual('', dag_run2.run_id) - self.assertEqual( - datetime_tz(2015, 1, 2, 0, 0) + delta, - dag_run2.execution_date, - msg='dag_run2.execution_date did not match expectation: {0}' - .format(dag_run2.execution_date) - ) - self.assertEqual(State.RUNNING, dag_run2.state) - self.assertFalse(dag_run2.external_trigger) dag.clear() self._clean_up(dag_id) @@ -1113,13 +1125,13 @@ def test_dag_handle_callback_crash(self, mock_stats): # callback with invalid signature should not cause crashes on_success_callback=lambda: 1, on_failure_callback=mock_callback_with_exception) + when = TEST_DATE dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) + start_date=when)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_run = dag_file_processor.create_dag_run(dag) + dag_run = dag.create_dagrun(State.RUNNING, when, run_type=DagRunType.MANUAL) # should not rause any exception dag.handle_callback(dag_run, success=False) dag.handle_callback(dag_run, success=True) @@ -1129,7 +1141,7 @@ def test_dag_handle_callback_crash(self, mock_stats): dag.clear() self._clean_up(dag_id) - def test_schedule_dag_fake_scheduled_previous(self): + def test_next_dagrun_after_fake_scheduled_previous(self): """ Test scheduling a dag where there is a prior DagRun which has the same run_id as the next run should have @@ -1144,24 +1156,19 @@ def test_schedule_dag_fake_scheduled_previous(self): owner='Also fake', start_date=DEFAULT_DATE)) - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) dag.create_dagrun(run_type=DagRunType.SCHEDULED, execution_date=DEFAULT_DATE, state=State.SUCCESS, external_trigger=True) - 
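`test_dag_handle_callback_crash` above registers an `on_success_callback` with the wrong signature and expects `handle_callback` to survive it. The behaviour being protected is easy to show in isolation; the sketch below is illustrative and not Airflow's implementation:

```python
import logging

log = logging.getLogger(__name__)


def run_callback_safely(callback, context):
    """Invoke a user callback without letting its bugs propagate (illustrative only)."""
    try:
        callback(context)
    except Exception:  # a TypeError from a zero-arg callback lands here too
        log.exception("Callback failed; continuing anyway")


run_callback_safely(lambda: 1, context={})          # wrong signature -> logged, not raised
run_callback_safely(lambda ctx: None, context={})   # well-behaved callback runs normally
```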
dag_run = dag_file_processor.create_dag_run(dag) - self.assertIsNotNone(dag_run) - self.assertEqual(dag.dag_id, dag_run.dag_id) - self.assertIsNotNone(dag_run.run_id) - self.assertNotEqual('', dag_run.run_id) - self.assertEqual( - DEFAULT_DATE + delta, - dag_run.execution_date, - msg='dag_run.execution_date did not match expectation: {0}' - .format(dag_run.execution_date) - ) - self.assertEqual(State.RUNNING, dag_run.state) - self.assertFalse(dag_run.external_trigger) + dag.sync_to_db() + with create_session() as session: + model = session.query(DagModel).get((dag.dag_id,)) + + # Even though there is a run for this date already, it is marked as manual/external, so we should + # create a scheduled one anyway! + assert model.next_dagrun == DEFAULT_DATE + assert model.next_dagrun_create_after == dag.following_schedule(DEFAULT_DATE) + self._clean_up(dag_id) def test_schedule_dag_once(self): @@ -1176,13 +1183,22 @@ def test_schedule_dag_once(self): dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) - dag_run = DagFileProcessor(dag_ids=[], log=mock.MagicMock()).create_dag_run(dag) - dag_run2 = DagFileProcessor(dag_ids=[], log=mock.MagicMock()).create_dag_run(dag) + start_date=TEST_DATE)) - self.assertIsNotNone(dag_run) - self.assertIsNone(dag_run2) - dag.clear() + # Sync once to create the DagModel + dag.sync_to_db() + + dag.create_dagrun(run_type=DagRunType.SCHEDULED, + execution_date=TEST_DATE, + state=State.SUCCESS) + + # Then sync again after creating the dag run -- this should update next_dagrun + dag.sync_to_db() + with create_session() as session: + model = session.query(DagModel).get((dag.dag_id,)) + + assert model.next_dagrun is None + assert model.next_dagrun_create_after is None self._clean_up(dag_id) def test_fractional_seconds(self): @@ -1195,7 +1211,7 @@ def test_fractional_seconds(self): dag.add_task(BaseOperator( task_id="faketastic", owner='Also fake', - start_date=datetime_tz(2015, 1, 2, 0, 0))) + start_date=TEST_DATE)) start_date = timezone.utcnow() @@ -1215,77 +1231,6 @@ def test_fractional_seconds(self): "dag run start_date loses precision ") self._clean_up(dag_id) - def test_schedule_dag_start_end_dates(self): - """ - Tests that an attempt to schedule a task after the Dag's end_date - does not succeed. - """ - delta = datetime.timedelta(hours=1) - runs = 3 - start_date = DEFAULT_DATE - end_date = start_date + (runs - 1) * delta - dag_id = "test_schedule_dag_start_end_dates" - dag = DAG(dag_id=dag_id, - start_date=start_date, - end_date=end_date, - schedule_interval=delta) - dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - # Create and schedule the dag runs - dag_runs = [] - for _ in range(runs): - dag_runs.append(dag_file_processor.create_dag_run(dag)) - - additional_dag_run = dag_file_processor.create_dag_run(dag) - - for dag_run in dag_runs: - self.assertIsNotNone(dag_run) - - self.assertIsNone(additional_dag_run) - self._clean_up(dag_id) - - def test_schedule_dag_no_end_date_up_to_today_only(self): - """ - Tests that a Dag created without an end_date can only be scheduled up - to and including the current datetime. - - For example, if today is 2016-01-01 and we are scheduling from a - start_date of 2015-01-01, only jobs up to, but not including - 2016-01-01 should be scheduled. 
- """ - session = settings.Session() - delta = datetime.timedelta(days=1) - now = pendulum.now('UTC') - start_date = now.subtract(weeks=1) - - runs = (now - start_date).days - dag_id = "test_schedule_dag_no_end_date_up_to_today_only" - dag = DAG(dag_id=dag_id, - start_date=start_date, - schedule_interval=delta) - dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) - - dag_file_processor = DagFileProcessor(dag_ids=[], log=mock.MagicMock()) - dag_runs = [] - for _ in range(runs): - dag_run = dag_file_processor.create_dag_run(dag) - dag_runs.append(dag_run) - - # Mark the DagRun as complete - dag_run.state = State.SUCCESS - session.merge(dag_run) - session.commit() - - # Attempt to schedule an additional dag run (for 2016-01-01) - additional_dag_run = dag_file_processor.create_dag_run(dag) - - for dag_run in dag_runs: - self.assertIsNotNone(dag_run) - - self.assertIsNone(additional_dag_run) - self._clean_up(dag_id) - def test_pickling(self): test_dag_id = 'test_pickling' args = {'owner': 'airflow', 'start_date': DEFAULT_DATE} @@ -1489,6 +1434,290 @@ def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]): self.assertEqual(task_instance.state, ti_state_end) self._clean_up(dag_id) + def test_next_dagrun_after_date_once(self): + dag = DAG( + 'test_scheduler_dagrun_once', + start_date=timezone.datetime(2015, 1, 1), + schedule_interval="@once") + + next_date = dag.next_dagrun_after_date(None) + + assert next_date == timezone.datetime(2015, 1, 1) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date is None + + def test_next_dagrun_after_date_start_end_dates(self): + """ + Tests that an attempt to schedule a task after the Dag's end_date + does not succeed. + """ + delta = datetime.timedelta(hours=1) + runs = 3 + start_date = DEFAULT_DATE + end_date = start_date + (runs - 1) * delta + dag_id = "test_schedule_dag_start_end_dates" + dag = DAG(dag_id=dag_id, + start_date=start_date, + end_date=end_date, + schedule_interval=delta) + dag.add_task(BaseOperator(task_id='faketastic', owner='Also fake')) + + # Create and schedule the dag runs + dates = [] + date = None + for _ in range(runs): + date = dag.next_dagrun_after_date(date) + dates.append(date) + + for date in dates: + assert date is not None + + assert dates[-1] == end_date + + assert dag.next_dagrun_after_date(date) is None + + def test_next_dagrun_after_date_catcup(self): + """ + Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date + """ + + def make_dag(dag_id, schedule_interval, start_date, catchup): + default_args = { + 'owner': 'airflow', + 'depends_on_past': False, + } + dag = DAG(dag_id, + schedule_interval=schedule_interval, + start_date=start_date, + catchup=catchup, + default_args=default_args) + + op1 = DummyOperator(task_id='t1', dag=dag) + op2 = DummyOperator(task_id='t2', dag=dag) + op3 = DummyOperator(task_id='t3', dag=dag) + op1 >> op2 >> op3 + + return dag + + now = timezone.utcnow() + six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace( + minute=0, second=0, microsecond=0) + half_an_hour_ago = now - datetime.timedelta(minutes=30) + two_hours_ago = now - datetime.timedelta(hours=2) + + dag1 = make_dag(dag_id='dag_without_catchup_ten_minute', + schedule_interval='*/10 * * * *', + start_date=six_hours_ago_to_the_hour, + catchup=False) + next_date = dag1.next_dagrun_after_date(None) + # The DR should be scheduled in the last half an hour, not 6 hours ago + assert next_date > half_an_hour_ago + assert next_date < 
timezone.utcnow() + + dag2 = make_dag(dag_id='dag_without_catchup_hourly', + schedule_interval='@hourly', + start_date=six_hours_ago_to_the_hour, + catchup=False) + + next_date = dag2.next_dagrun_after_date(None) + # The DR should be scheduled in the last 2 hours, not 6 hours ago + assert next_date > two_hours_ago + # The DR should be scheduled BEFORE now + assert next_date < timezone.utcnow() + + dag3 = make_dag(dag_id='dag_without_catchup_once', + schedule_interval='@once', + start_date=six_hours_ago_to_the_hour, + catchup=False) + + next_date = dag3.next_dagrun_after_date(None) + # The DR should be scheduled in the last 2 hours, not 6 hours ago + assert next_date == six_hours_ago_to_the_hour + + @freeze_time(timezone.datetime(2020, 1, 5)) + def test_next_dagrun_after_date_timedelta_schedule_and_catchup_false(self): + """ + Test that the dag file processor does not create multiple dagruns + if a dag is scheduled with 'timedelta' and catchup=False + """ + dag = DAG( + 'test_scheduler_dagrun_once_with_timedelta_and_catchup_false', + start_date=timezone.datetime(2015, 1, 1), + schedule_interval=timedelta(days=1), + catchup=False) + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2020, 1, 4) + + # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 1, 5) + + @freeze_time(timezone.datetime(2020, 5, 4)) + def test_next_dagrun_after_date_timedelta_schedule_and_catchup_true(self): + """ + Test that the dag file processor creates multiple dagruns + if a dag is scheduled with 'timedelta' and catchup=True + """ + dag = DAG( + 'test_scheduler_dagrun_once_with_timedelta_and_catchup_true', + start_date=timezone.datetime(2020, 5, 1), + schedule_interval=timedelta(days=1), + catchup=True) + + next_date = dag.next_dagrun_after_date(None) + assert next_date == timezone.datetime(2020, 5, 1) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 2) + + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 3) + + # The date to create is in the future, this is handled by "DagModel.dags_needing_dagruns" + next_date = dag.next_dagrun_after_date(next_date) + assert next_date == timezone.datetime(2020, 5, 4) + + def test_next_dagrun_after_auto_align(self): + """ + Test if the schedule_interval will be auto aligned with the start_date + such that if the start_date coincides with the schedule the first + execution_date will be start_date, otherwise it will be start_date + + interval. 
+        """
+        dag = DAG(
+            dag_id='test_scheduler_auto_align_1',
+            start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
+            schedule_interval="4 5 * * *"
+        )
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+            owner='airflow')
+
+        next_date = dag.next_dagrun_after_date(None)
+        assert next_date == timezone.datetime(2016, 1, 2, 5, 4)
+
+        dag = DAG(
+            dag_id='test_scheduler_auto_align_2',
+            start_date=timezone.datetime(2016, 1, 1, 10, 10, 0),
+            schedule_interval="10 10 * * *"
+        )
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+            owner='airflow')
+
+        next_date = dag.next_dagrun_after_date(None)
+        assert next_date == timezone.datetime(2016, 1, 1, 10, 10)
+
+    def test_next_dagrun_after_not_for_subdags(self):
+        """
+        Test the subdags are never marked to have dagruns created, as they are
+        handled by the SubDagOperator, not the scheduler
+        """
+
+        def subdag(parent_dag_name, child_dag_name, args):
+            """
+            Create a subdag.
+            """
+            dag_subdag = DAG(dag_id='%s.%s' % (parent_dag_name, child_dag_name),
+                             schedule_interval="@daily",
+                             default_args=args)
+
+            for i in range(2):
+                DummyOperator(task_id='%s-task-%s' % (child_dag_name, i + 1), dag=dag_subdag)
+
+            return dag_subdag
+
+        with DAG(
+            dag_id='test_subdag_operator',
+            start_date=datetime.datetime(2019, 1, 1),
+            max_active_runs=1,
+            schedule_interval=timedelta(minutes=1),
+        ) as dag:
+            section_1 = SubDagOperator(
+                task_id='section-1',
+                subdag=subdag(dag.dag_id, 'section-1', {'start_date': dag.start_date}),
+            )
+
+        subdag = section_1.subdag
+        # parent_dag and is_subdag are normally set by DagBag. We don't use DagBag, so this value is not set.
+        subdag.parent_dag = dag
+        subdag.is_subdag = True
+
+        next_date = dag.next_dagrun_after_date(None)
+        assert next_date == timezone.datetime(2019, 1, 1, 0, 0)
+
+        next_subdag_date = subdag.next_dagrun_after_date(None)
+        assert next_subdag_date is None, "SubDags should never have DagRuns created by the scheduler"
+
+
+class TestDagModel:
+
+    def test_dags_needing_dagruns_not_too_early(self):
+        dag = DAG(
+            dag_id='far_future_dag',
+            start_date=timezone.datetime(2038, 1, 1))
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+            owner='airflow')
+
+        session = settings.Session()
+        orm_dag = DagModel(
+            dag_id=dag.dag_id,
+            concurrency=1,
+            has_task_concurrency_limits=False,
+            next_dagrun=dag.start_date,
+            next_dagrun_create_after=timezone.datetime(2038, 1, 2),
+            is_active=True,
+        )
+        session.add(orm_dag)
+        session.flush()
+
+        dag_models = DagModel.dags_needing_dagruns(session).all()
+        assert dag_models == []
+
+        session.rollback()
+        session.close()
+
+    def test_dags_needing_dagruns_only_unpaused(self):
+        """
+        We should never create dagruns for paused DAGs
+        """
+        dag = DAG(
+            dag_id='test_dags',
+            start_date=DEFAULT_DATE)
+        DummyOperator(
+            task_id='dummy',
+            dag=dag,
+            owner='airflow')
+
+        session = settings.Session()
+        orm_dag = DagModel(
+            dag_id=dag.dag_id,
+            has_task_concurrency_limits=False,
+            next_dagrun=dag.start_date,
+            next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
+            is_active=True,
+        )
+        session.add(orm_dag)
+        session.flush()
+
+        needed = DagModel.dags_needing_dagruns(session).all()
+        assert needed == [orm_dag]
+
+        orm_dag.is_paused = True
+        session.flush()
+
+        dag_models = DagModel.dags_needing_dagruns(session).all()
+        assert dag_models == []
+
+        session.rollback()
+        session.close()
+
+
 class TestQueries(unittest.TestCase):
@@ -1506,8 +1735,9 @@ def test_count_number_queries(self, tasks_count):
         dag = DAG('test_dagrun_query_count', start_date=DEFAULT_DATE)
         for i in range(tasks_count):
DummyOperator(task_id=f'dummy_task_{i}', owner='test', dag=dag) - with assert_queries_count(3): + with assert_queries_count(2): dag.create_dagrun( run_id="test_dagrun_query_count", - state=State.RUNNING + state=State.RUNNING, + execution_date=TEST_DATE, ) diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 78a02de4f4c58..4d1dce875ff37 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -23,11 +23,12 @@ from parameterized import parameterized from airflow import models, settings -from airflow.models import DAG, DagBag, TaskInstance as TI, clear_task_instances +from airflow.models import DAG, DagBag, DagModel, TaskInstance as TI, clear_task_instances from airflow.models.dagrun import DagRun from airflow.operators.dummy_operator import DummyOperator from airflow.operators.python import ShortCircuitOperator from airflow.utils import timezone +from airflow.utils.callback_requests import DagCallbackRequest from airflow.utils.state import State from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType @@ -327,8 +328,10 @@ def on_success_callable(context): dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) - dag_run.update_state() + _, callback = dag_run.update_state() self.assertEqual(State.SUCCESS, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + self.assertIsNone(callback) def test_dagrun_failure_callback(self): def on_failure_callable(context): @@ -358,8 +361,88 @@ def on_failure_callable(context): dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) - dag_run.update_state() + _, callback = dag_run.update_state() self.assertEqual(State.FAILED, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + self.assertIsNone(callback) + + def test_dagrun_update_state_with_handle_callback_success(self): + def on_success_callable(context): + self.assertEqual( + context['dag_run'].dag_id, + 'test_dagrun_update_state_with_handle_callback_success' + ) + + dag = DAG( + dag_id='test_dagrun_update_state_with_handle_callback_success', + start_date=datetime.datetime(2017, 1, 1), + on_success_callback=on_success_callable, + ) + dag_task1 = DummyOperator( + task_id='test_state_succeeded1', + dag=dag) + dag_task2 = DummyOperator( + task_id='test_state_succeeded2', + dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + 'test_state_succeeded1': State.SUCCESS, + 'test_state_succeeded2': State.SUCCESS, + } + + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) + + _, callback = dag_run.update_state(execute_callbacks=False) + self.assertEqual(State.SUCCESS, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + + assert callback == DagCallbackRequest( + full_filepath=dag_run.dag.fileloc, + dag_id="test_dagrun_update_state_with_handle_callback_success", + execution_date=dag_run.execution_date, + is_failure_callback=False, + msg="success" + ) + + def test_dagrun_update_state_with_handle_callback_failure(self): + def on_failure_callable(context): + self.assertEqual( + context['dag_run'].dag_id, + 'test_dagrun_update_state_with_handle_callback_failure' + ) + + dag = DAG( + dag_id='test_dagrun_update_state_with_handle_callback_failure', + start_date=datetime.datetime(2017, 1, 1), + on_failure_callback=on_failure_callable, + 
) + dag_task1 = DummyOperator( + task_id='test_state_succeeded1', + dag=dag) + dag_task2 = DummyOperator( + task_id='test_state_failed2', + dag=dag) + dag_task1.set_downstream(dag_task2) + + initial_task_states = { + 'test_state_succeeded1': State.SUCCESS, + 'test_state_failed2': State.FAILED, + } + + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) + + _, callback = dag_run.update_state(execute_callbacks=False) + self.assertEqual(State.FAILED, dag_run.state) + # Callbacks are not added until handle_callback = False is passed to dag_run.update_state() + + assert callback == DagCallbackRequest( + full_filepath=dag_run.dag.fileloc, + dag_id="test_dagrun_update_state_with_handle_callback_failure", + execution_date=dag_run.execution_date, + is_failure_callback=True, + msg="task_failure" + ) def test_dagrun_set_state_end_date(self): session = settings.Session() @@ -662,3 +745,45 @@ def test_wait_for_downstream(self, prev_ti_state, is_ti_success): ti.set_state(State.QUEUED) ti.run() self.assertEqual(ti.state == State.SUCCESS, is_ti_success) + + def test_next_dagruns_to_examine_only_unpaused(self): + """ + Check that "next_dagruns_to_examine" ignores runs from paused/inactive DAGs + """ + + dag = DAG( + dag_id='test_dags', + start_date=DEFAULT_DATE) + DummyOperator( + task_id='dummy', + dag=dag, + owner='airflow') + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + is_active=True, + ) + session.add(orm_dag) + session.flush() + dr = dag.create_dagrun(run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session) + + runs = DagRun.next_dagruns_to_examine(session).all() + + assert runs == [dr] + + orm_dag.is_paused = True + session.flush() + + runs = DagRun.next_dagruns_to_examine(session).all() + assert runs == [] + + session.rollback() + session.close() diff --git a/tests/task/task_runner/test_standard_task_runner.py b/tests/task/task_runner/test_standard_task_runner.py index 921ae5ff412f3..f6b777b290de2 100644 --- a/tests/task/task_runner/test_standard_task_runner.py +++ b/tests/task/task_runner/test_standard_task_runner.py @@ -149,6 +149,7 @@ def test_on_kill(self): session=session) ti = TI(task=task, execution_date=DEFAULT_DATE) job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True) + session.commit() runner = StandardTaskRunner(job1) runner.start() diff --git a/tests/test_utils/mock_executor.py b/tests/test_utils/mock_executor.py index 0143c95fae853..746caf4dfcb6a 100644 --- a/tests/test_utils/mock_executor.py +++ b/tests/test_utils/mock_executor.py @@ -67,10 +67,9 @@ def sort_by(item): open_slots = self.parallelism - len(self.running) sorted_queue = sorted(self.queued_tasks.items(), key=sort_by) for index in range(min((open_slots, len(sorted_queue)))): - (key, (_, _, _, simple_ti)) = sorted_queue[index] + (key, (_, _, _, ti)) = sorted_queue[index] self.queued_tasks.pop(key) state = self.mock_task_results[key] - ti = simple_ti.construct_task_instance(session=session, lock_for_update=True) ti.set_state(state, session=session) self.change_state(key, state) diff --git a/tests/test_utils/perf/perf_kit/__init__.py b/tests/test_utils/perf/perf_kit/__init__.py index d17ca8e629090..0b4e344b343d1 100644 --- a/tests/test_utils/perf/perf_kit/__init__.py +++ b/tests/test_utils/perf/perf_kit/__init__.py @@ -72,14 +72,14 @@ 
self.assertEqual(prev_local.isoformat(), "2018-03-24T03:00:00+01:00") self.assertEqual(prev.isoformat(), "2018-03-24T02:00:00+00:00") - def test_bulk_sync_to_db(self): + def test_bulk_write_to_db(self): clear_db_dags() dags = [ DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4) ] with assert_queries_count(3): - DAG.bulk_sync_to_db(dags) + DAG.bulk_write_to_db(dags) You can add a code snippet before the method definition, and then perform only one test and count the queries in it. @@ -96,20 +96,20 @@ def test_bulk_sync_to_db(self): from tests.utils.perf.perf_kit.sqlalchemy import trace_queries @trace_queries - def test_bulk_sync_to_db(self): + def test_bulk_write_to_db(self): clear_db_dags() dags = [ DAG(f'dag-bulk-sync-{i}', start_date=DEFAULT_DATE, tags=["test-dag"]) for i in range(0, 4) ] with assert_queries_count(3): - DAG.bulk_sync_to_db(dags) + DAG.bulk_write_to_db(dags) To run the test, execute the command .. code-block:: bash - pytest tests.models.dag -k test_bulk_sync_to_db -s + pytest tests.models.dag -k test_bulk_write_to_db -s This is not a beautiful solution, but it allows you to easily check a random piece of code. diff --git a/tests/test_utils/perf/perf_kit/python.py b/tests/test_utils/perf/perf_kit/python.py index 3169e9c43ea93..7d92a497fe9d7 100644 --- a/tests/test_utils/perf/perf_kit/python.py +++ b/tests/test_utils/perf/perf_kit/python.py @@ -96,7 +96,7 @@ def case(): log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) dag_file = os.path.join(os.path.dirname(airflow.__file__), "example_dags", "example_complex.py") - processor.process_file(file_path=dag_file, failure_callback_requests=[]) + processor.process_file(file_path=dag_file, callback_requests=[]) # Load modules case() diff --git a/tests/test_utils/perf/perf_kit/sqlalchemy.py b/tests/test_utils/perf/perf_kit/sqlalchemy.py index 06a305c709646..e5c7c359cec90 100644 --- a/tests/test_utils/perf/perf_kit/sqlalchemy.py +++ b/tests/test_utils/perf/perf_kit/sqlalchemy.py @@ -222,7 +222,7 @@ def case(): log = logging.getLogger(__name__) processor = DagFileProcessor(dag_ids=[], log=log) dag_file = os.path.join(os.path.dirname(__file__), os.path.pardir, "dags", "elastic_dag.py") - processor.process_file(file_path=dag_file, failure_callback_requests=[]) + processor.process_file(file_path=dag_file, callback_requests=[]) with trace_queries(), count_queries(): case() diff --git a/tests/ti_deps/deps/test_runnable_exec_date_dep.py b/tests/ti_deps/deps/test_runnable_exec_date_dep.py index ba20658e1eb12..4c08f577d163e 100644 --- a/tests/ti_deps/deps/test_runnable_exec_date_dep.py +++ b/tests/ti_deps/deps/test_runnable_exec_date_dep.py @@ -17,33 +17,34 @@ # under the License. 
import unittest -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest from freezegun import freeze_time +from airflow import settings from airflow.models import DAG, TaskInstance from airflow.operators.dummy_operator import DummyOperator from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep from airflow.utils.timezone import datetime -from tests.test_utils.config import conf_vars @freeze_time('2016-11-01') @pytest.mark.parametrize("allow_trigger_in_future,schedule_interval,execution_date,is_met", [ - ('True', None, datetime(2016, 11, 3), True), - ('True', "@daily", datetime(2016, 11, 3), False), - ('False', None, datetime(2016, 11, 3), False), - ('False', "@daily", datetime(2016, 11, 3), False), - ('False', "@daily", datetime(2016, 11, 1), True), - ('False', None, datetime(2016, 11, 1), True)] + (True, None, datetime(2016, 11, 3), True), + (True, "@daily", datetime(2016, 11, 3), False), + (False, None, datetime(2016, 11, 3), False), + (False, "@daily", datetime(2016, 11, 3), False), + (False, "@daily", datetime(2016, 11, 1), True), + (False, None, datetime(2016, 11, 1), True)] ) def test_exec_date_dep(allow_trigger_in_future, schedule_interval, execution_date, is_met): """ If the dag's execution date is in the future but (allow_trigger_in_future=False or not schedule_interval) this dep should fail """ - with conf_vars({('scheduler', 'allow_trigger_in_future'): allow_trigger_in_future}): + + with patch.object(settings, 'ALLOW_FUTURE_EXEC_DATES', allow_trigger_in_future): dag = DAG( 'test_localtaskjob_heartbeat', start_date=datetime(2015, 1, 1), diff --git a/tests/utils/test_dag_processing.py b/tests/utils/test_dag_processing.py index 9a51c761a8655..077c2fad7479a 100644 --- a/tests/utils/test_dag_processing.py +++ b/tests/utils/test_dag_processing.py @@ -25,22 +25,25 @@ from unittest import mock from unittest.mock import MagicMock, PropertyMock +import pytest + from airflow.configuration import conf from airflow.jobs.local_task_job import LocalTaskJob as LJ from airflow.jobs.scheduler_job import DagFileProcessorProcess -from airflow.models import DagBag, TaskInstance as TI +from airflow.models import DagBag, DagModel, TaskInstance as TI +from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskinstance import SimpleTaskInstance from airflow.utils import timezone +from airflow.utils.callback_requests import TaskCallbackRequest from airflow.utils.dag_processing import ( DagFileProcessorAgent, DagFileProcessorManager, DagFileStat, DagParsingSignal, DagParsingStat, - FailureCallbackRequest, ) from airflow.utils.file import correct_maybe_zipped, open_maybe_zipped from airflow.utils.session import create_session from airflow.utils.state import State from tests.test_logging_config import SETTINGS_FILE_VALID, settings_context from tests.test_utils.config import conf_vars -from tests.test_utils.db import clear_db_runs +from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_db_serialized_dags TEST_DAG_FOLDER = os.path.join( os.path.dirname(os.path.realpath(__file__)), os.pardir, 'dags') @@ -51,14 +54,14 @@ class FakeDagFileProcessorRunner(DagFileProcessorProcess): # This fake processor will return the zombies it received in constructor # as its processing result w/o actually parsing anything. 
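The execution-date dep test above swaps `conf_vars` for `patch.object(settings, 'ALLOW_FUTURE_EXEC_DATES', ...)`, presumably because the code under test reads the derived settings attribute rather than the raw config value. The same patching pattern in isolation, against a throwaway namespace standing in for `airflow.settings`:

```python
from types import SimpleNamespace
from unittest.mock import patch

fake_settings = SimpleNamespace(ALLOW_FUTURE_EXEC_DATES=False)  # stand-in for airflow.settings


def dep_is_met(execution_date_in_future: bool) -> bool:
    # Reads the flag at call time, which is why the test patches the attribute directly.
    return not execution_date_in_future or fake_settings.ALLOW_FUTURE_EXEC_DATES


assert dep_is_met(execution_date_in_future=True) is False
with patch.object(fake_settings, "ALLOW_FUTURE_EXEC_DATES", True):
    assert dep_is_met(execution_date_in_future=True) is True
assert dep_is_met(execution_date_in_future=True) is False  # patch reverted on exit
```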
-    def __init__(self, file_path, pickle_dags, dag_ids, zombies):
-        super().__init__(file_path, pickle_dags, dag_ids, zombies)
+    def __init__(self, file_path, pickle_dags, dag_ids, callbacks):
+        super().__init__(file_path, pickle_dags, dag_ids, callbacks)
         # We need a "real" selectable handle for waitable_handle to work
         readable, writable = multiprocessing.Pipe(duplex=False)
         writable.send('abc')
         writable.close()
         self._waitable_handle = readable
-        self._result = zombies, 0
+        self._result = 0, 0
 
     def start(self):
         pass
@@ -80,12 +83,12 @@ def result(self):
         return self._result
 
     @staticmethod
-    def _fake_dag_processor_factory(file_path, zombies, dag_ids, pickle_dags):
+    def _fake_dag_processor_factory(file_path, callbacks, dag_ids, pickle_dags):
         return FakeDagFileProcessorRunner(
             file_path,
             pickle_dags,
             dag_ids,
-            zombies
+            callbacks,
         )
 
     @property
@@ -214,6 +217,7 @@ def test_find_zombies(self):
         self.assertEqual(1, len(requests))
         self.assertEqual(requests[0].full_filepath, dag.full_filepath)
         self.assertEqual(requests[0].msg, "Detected as zombie")
+        self.assertEqual(requests[0].is_failure_callback, True)
         self.assertIsInstance(requests[0].simple_task_instance, SimpleTaskInstance)
         self.assertEqual(ti.dag_id, requests[0].simple_task_instance.dag_id)
         self.assertEqual(ti.task_id, requests[0].simple_task_instance.task_id)
@@ -249,8 +253,8 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p
             ti.job_id = local_job.id
             session.commit()
 
-        fake_failure_callback_requests = [
-            FailureCallbackRequest(
+        expected_failure_callback_requests = [
+            TaskCallbackRequest(
                 full_filepath=dag.full_filepath,
                 simple_task_instance=SimpleTaskInstance(ti),
                 msg="Message"
@@ -262,23 +266,41 @@ def test_handle_failure_callback_with_zombies_are_correctly_passed_to_dag_file_p
         child_pipe, parent_pipe = multiprocessing.Pipe()
         async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
 
+        fake_processors = []
+
+        def fake_processor_factory(*args, **kwargs):
+            nonlocal fake_processors
+            processor = FakeDagFileProcessorRunner._fake_dag_processor_factory(*args, **kwargs)
+            fake_processors.append(processor)
+            return processor
+
         manager = DagFileProcessorManager(
             dag_directory=test_dag_path,
             max_runs=1,
-            processor_factory=FakeDagFileProcessorRunner._fake_dag_processor_factory,
+            processor_factory=fake_processor_factory,
             processor_timeout=timedelta.max,
             signal_conn=child_pipe,
             dag_ids=[],
             pickle_dags=False,
            async_mode=async_mode)
 
-        parsing_result = self.run_processor_manager_one_loop(manager, parent_pipe)
+        self.run_processor_manager_one_loop(manager, parent_pipe)
 
-        self.assertEqual(len(fake_failure_callback_requests), len(parsing_result))
-        self.assertEqual(
-            set(zombie.simple_task_instance.key for zombie in fake_failure_callback_requests),
-            set(result.simple_task_instance.key for result in parsing_result)
+        if async_mode:
+            # Once for initial parse, and then again for the add_callback_to_queue
+            assert len(fake_processors) == 2
+            assert fake_processors[0]._file_path == test_dag_path
+            assert fake_processors[0]._callback_requests == []
+        else:
+            assert len(fake_processors) == 1
+
+        assert fake_processors[-1]._file_path == test_dag_path
+        callback_requests = fake_processors[-1]._callback_requests
+        assert (
+            set(zombie.simple_task_instance.key for zombie in expected_failure_callback_requests) ==
+            set(result.simple_task_instance.key for result in callback_requests)
         )
+
         child_pipe.close()
         parent_pipe.close()
 
@@ -322,6 +344,50 @@ def test_kill_timed_out_processors_no_kill(self, mock_dag_file_processor, mock_p
         manager._kill_timed_out_processors()
         mock_dag_file_processor.kill.assert_not_called()
 
+    @conf_vars({('core', 'load_examples'): 'False'})
+    @pytest.mark.execution_timeout(10)
+    def test_dag_with_system_exit(self):
+        """
+        Test to check that a DAG with a system.exit() doesn't break the scheduler.
+        """
+
+        # We need to _actually_ parse the files here to test the behaviour.
+        # Right now the parsing code lives in SchedulerJob, even though it's
+        # called via utils.dag_processing.
+        from airflow.jobs.scheduler_job import SchedulerJob
+
+        dag_id = 'exit_test_dag'
+        dag_directory = os.path.normpath(os.path.join(TEST_DAG_FOLDER, os.pardir, "dags_with_system_exit"))
+
+        # Delete the one valid DAG/SerializedDAG, and check that it gets re-created
+        clear_db_dags()
+        clear_db_serialized_dags()
+
+        child_pipe, parent_pipe = multiprocessing.Pipe()
+
+        manager = DagFileProcessorManager(
+            dag_directory=dag_directory,
+            dag_ids=[],
+            max_runs=1,
+            processor_factory=SchedulerJob._create_dag_file_processor,
+            processor_timeout=timedelta(seconds=5),
+            signal_conn=child_pipe,
+            pickle_dags=False,
+            async_mode=True)
+
+        manager._run_parsing_loop()
+
+        while parent_pipe.poll(timeout=None):
+            result = parent_pipe.recv()
+            if isinstance(result, DagParsingStat) and result.done:
+                break
+
+        # Three files in folder should be processed
+        assert len(result.file_paths) == 3
+
+        with create_session() as session:
+            assert session.query(DagModel).get(dag_id) is not None
+
 
 class TestDagFileProcessorAgent(unittest.TestCase):
     def setUp(self):
@@ -384,6 +450,9 @@ class path, thus when reloading logging module the airflow.processor_manager
 
     @conf_vars({('core', 'load_examples'): 'False'})
     def test_parse_once(self):
+        clear_db_serialized_dags()
+        clear_db_dags()
+
         test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
         async_mode = 'sqlite' not in conf.get('core', 'sql_alchemy_conn')
         processor_agent = DagFileProcessorAgent(test_dag_path,
@@ -394,16 +463,22 @@ def test_parse_once(self):
                                                 False,
                                                 async_mode)
         processor_agent.start()
-        parsing_result = []
         if not async_mode:
             processor_agent.run_single_parsing_loop()
         while not processor_agent.done:
             if not async_mode:
                 processor_agent.wait_until_finished()
-            parsing_result.extend(processor_agent.harvest_serialized_dags())
+            processor_agent.heartbeat()
+
+        assert processor_agent.all_files_processed
+        assert processor_agent.done
+
+        with create_session() as session:
+            dag_ids = session.query(DagModel.dag_id).order_by("dag_id").all()
+            assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)]
 
-        dag_ids = [result.dag_id for result in parsing_result]
-        self.assertEqual(dag_ids.count('test_start_date_scheduling'), 1)
+            dag_ids = session.query(SerializedDagModel.dag_id).order_by("dag_id").all()
+            assert dag_ids == [('test_start_date_scheduling',), ('test_task_start_date_scheduling',)]
 
     def test_launch_process(self):
         test_dag_path = os.path.join(TEST_DAG_FOLDER, 'test_scheduler_dags.py')
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index d59bbc96d1343..ac8a52a776ef9 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -26,7 +26,7 @@
 from airflow import settings
 from airflow.models import DAG
 from airflow.settings import Session
-from airflow.utils.sqlalchemy import skip_locked
+from airflow.utils.sqlalchemy import nowait, skip_locked
 from airflow.utils.state import State
 from airflow.utils.timezone import utcnow
 
@@ -110,6 +110,18 @@ def test_skip_locked(self, dialect, supports_for_update_of, expected_return_valu
         session.bind.dialect.supports_for_update_of = supports_for_update_of
         self.assertEqual(skip_locked(session=session), expected_return_value)
 
+    @parameterized.expand([
+        ("postgresql", True, {'nowait': True}, ),
+        ("mysql", False, {}, ),
+        ("mysql", True, {'nowait': True}, ),
+        ("sqlite", False, {'nowait': True, }, ),
+    ])
+    def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
+        session = mock.Mock()
+        session.bind.dialect.name = dialect
+        session.bind.dialect.supports_for_update_of = supports_for_update_of
+        self.assertEqual(nowait(session=session), expected_return_value)
+
     def tearDown(self):
         self.session.close()
         settings.engine.dispose()
diff --git a/tests/www/test_views.py b/tests/www/test_views.py
index 98f3bca1633c9..47982e02e2671 100644
--- a/tests/www/test_views.py
+++ b/tests/www/test_views.py
@@ -52,7 +52,6 @@
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.operators.bash import BashOperator
 from airflow.operators.dummy_operator import DummyOperator
-from airflow.settings import Session
 from airflow.ti_deps.dependencies_states import QUEUEABLE_STATES, RUNNABLE_STATES
 from airflow.utils import dates, timezone
 from airflow.utils.log.logging_mixin import ExternalLoggingMixin
@@ -410,7 +409,7 @@ def setUpClass(cls):
         super().setUpClass()
         cls.dagbag = models.DagBag(include_examples=True)
         cls.app.dag_bag = cls.dagbag
-        DAG.bulk_sync_to_db(cls.dagbag.dags.values())
+        DAG.bulk_write_to_db(cls.dagbag.dags.values())
 
     def setUp(self):
         super().setUp()
@@ -639,10 +638,10 @@ def test_view_uses_existing_dagbag(self, endpoint):
         self.check_content_in_response('example_bash_operator', resp)
 
     @parameterized.expand([
-        ("hello\nworld", r'\"conf\":{\"abc\":\"hello\\nworld\"}}'),
-        ("hello'world", r'\"conf\":{\"abc\":\"hello\\u0027world\"}}'),
-        ("