diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 410a558922ac1..d96c10fcd95f9 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -141,13 +141,26 @@ def fail(self, key):
     def success(self, key):
         self.change_state(key, State.SUCCESS)
 
-    def get_event_buffer(self):
+    def get_event_buffer(self, dag_ids=None):
         """
-        Returns and flush the event buffer
+        Returns and flushes the event buffer. If dag_ids is specified,
+        it will only return and flush events for the given dag_ids.
+        Otherwise it returns and flushes all events.
+
+        :param dag_ids: the dag_ids to return events for; if None, returns all
+        :return: a dict of events
         """
-        d = self.event_buffer
-        self.event_buffer = {}
-        return d
+        cleared_events = dict()
+        if dag_ids is None:
+            cleared_events = self.event_buffer
+            self.event_buffer = dict()
+        else:
+            for key in list(self.event_buffer.keys()):
+                dag_id, _, _ = key
+                if dag_id in dag_ids:
+                    cleared_events[key] = self.event_buffer.pop(key)
+
+        return cleared_events
 
     def execute_async(self, key, command, queue=None):  # pragma: no cover
         """
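For context (not part of the patch): a minimal sketch of how the new `dag_ids` filter on `get_event_buffer` is expected to behave. It relies only on what the hunk above adds — an `event_buffer` keyed by `(dag_id, task_id, execution_date)` tuples and the optional `dag_ids` argument; the dag and task names are made up for illustration.

```python
# Illustrative sketch only -- exercises the get_event_buffer(dag_ids) behaviour added above.
from datetime import datetime

from airflow.executors.base_executor import BaseExecutor
from airflow.utils.state import State

executor = BaseExecutor()
now = datetime.utcnow()

# Keys follow the (dag_id, task_id, execution_date) shape the executor uses.
executor.event_buffer[("dag_a", "task_1", now)] = State.SUCCESS
executor.event_buffer[("dag_b", "task_1", now)] = State.FAILED

# Only dag_a's event is returned and flushed; dag_b's event stays buffered
# for a later call (e.g. a scheduler loop that actually knows about dag_b).
dag_a_events = executor.get_event_buffer(dag_ids=["dag_a"])
assert len(dag_a_events) == 1
assert len(executor.event_buffer) == 1

# With no dag_ids, everything that is left is returned and flushed.
assert len(executor.get_event_buffer()) == 1
assert len(executor.event_buffer) == 0
```

Popping only the matched keys (instead of rebuilding the whole dict) is what keeps events for unknown DAGs buffered until a scheduler loop that owns them picks them up.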
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 2675bd3167d9b..892c5cb813654 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1290,6 +1290,7 @@ def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instanc
         TI = models.TaskInstance
         # actually enqueue them
         for task_instance in task_instances:
+            simple_dag = simple_dag_bag.get_dag(task_instance.dag_id)
             command = " ".join(TI.generate_command(
                 task_instance.dag_id,
                 task_instance.task_id,
@@ -1301,8 +1302,8 @@ def _enqueue_task_instances_with_queued_state(self, simple_dag_bag, task_instanc
                 ignore_task_deps=False,
                 ignore_ti_state=False,
                 pool=task_instance.pool,
-                file_path=simple_dag_bag.get_dag(task_instance.dag_id).full_filepath,
-                pickle_id=simple_dag_bag.get_dag(task_instance.dag_id).pickle_id))
+                file_path=simple_dag.full_filepath,
+                pickle_id=simple_dag.pickle_id))
 
             priority = task_instance.priority_weight
             queue = task_instance.queue
@@ -1412,20 +1413,49 @@ def _process_dags(self, dagbag, dags, tis_out):
 
         models.DagStat.update([d.dag_id for d in dags])
 
-    def _process_executor_events(self):
+    @provide_session
+    def _process_executor_events(self, simple_dag_bag, session=None):
         """
         Respond to executor events.
-
-        :param executor: the executor that's running the task instances
-        :type executor: BaseExecutor
-        :return: None
         """
-        for key, executor_state in list(self.executor.get_event_buffer().items()):
+        # TODO: this shares quite a lot of code with _manage_executor_state
+
+        TI = models.TaskInstance
+        for key, state in list(self.executor.get_event_buffer(simple_dag_bag.dag_ids)
+                               .items()):
             dag_id, task_id, execution_date = key
             self.log.info(
                 "Executor reports %s.%s execution_date=%s as %s",
-                dag_id, task_id, execution_date, executor_state
+                dag_id, task_id, execution_date, state
             )
+            if state == State.FAILED or state == State.SUCCESS:
+                qry = session.query(TI).filter(TI.dag_id == dag_id,
+                                               TI.task_id == task_id,
+                                               TI.execution_date == execution_date)
+                ti = qry.first()
+                if not ti:
+                    self.log.warning("TaskInstance %s went missing from the database", ti)
+                    continue
+
+                # TODO: should we fail RUNNING as well, as we do in Backfills?
+                if ti.state == State.QUEUED:
+                    msg = ("Executor reports task instance {} finished ({}) "
+                           "although the task says it's {}. Was the task "
+                           "killed externally?".format(ti, state, ti.state))
+                    self.log.error(msg)
+                    try:
+                        simple_dag = simple_dag_bag.get_dag(dag_id)
+                        dagbag = models.DagBag(simple_dag.full_filepath)
+                        dag = dagbag.get_dag(dag_id)
+                        ti.task = dag.get_task(task_id)
+                        ti.handle_failure(msg)
+                    except Exception:
+                        self.log.error("Cannot load the dag bag to handle failure for %s"
+                                       ". Setting task to FAILED without callbacks or "
+                                       "retries. Do you have enough resources?", ti)
+                        ti.state = State.FAILED
+                        session.merge(ti)
+                        session.commit()
 
     def _log_file_processing_stats(self,
                                    known_file_paths,
@@ -1626,8 +1656,8 @@ def _execute_helper(self, processor_manager):
                 processor_manager.wait_until_finished()
 
             # Send tasks for execution if available
+            simple_dag_bag = SimpleDagBag(simple_dags)
             if len(simple_dags) > 0:
-                simple_dag_bag = SimpleDagBag(simple_dags)
 
                 # Handle cases where a DAG run state is set (perhaps manually) to
                 # a non-running state. Handle task instances that belong to
@@ -1655,7 +1685,7 @@
             self.executor.heartbeat()
 
             # Process events from the executor
-            self._process_executor_events()
+            self._process_executor_events(simple_dag_bag)
 
             # Heartbeat the scheduler periodically
             time_since_last_heartbeat = (datetime.utcnow() -
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index b80f7016071fa..ebb5ca077e938 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -36,40 +36,13 @@ class SimpleDag(BaseDag):
     required for instantiating and scheduling its associated tasks.
     """
 
-    def __init__(self,
-                 dag_id,
-                 task_ids,
-                 full_filepath,
-                 concurrency,
-                 is_paused,
-                 pickle_id,
-                 task_special_args):
-        """
-        :param dag_id: ID of the DAG
-        :type dag_id: unicode
-        :param task_ids: task IDs associated with the DAG
-        :type task_ids: list[unicode]
-        :param full_filepath: path to the file containing the DAG e.g.
-        /a/b/c.py
-        :type full_filepath: unicode
-        :param concurrency: No more than these many tasks from the
-        dag should run concurrently
-        :type concurrency: int
-        :param is_paused: Whether or not this DAG is paused. Tasks from paused
-        DAGs are not scheduled
-        :type is_paused: bool
+    def __init__(self, dag, pickle_id=None):
+        """
+        :param dag: the DAG
+        :type dag: DAG
         :param pickle_id: ID associated with the pickled version of this DAG.
         :type pickle_id: unicode
         """
-        self._dag_id = dag_id
-        self._task_ids = task_ids
-        self._full_filepath = full_filepath
-        self._is_paused = is_paused
-        self._concurrency = concurrency
-        self._pickle_id = pickle_id
-        self._task_special_args = task_special_args
-
-    def __init__(self, dag, pickle_id=None):
         self._dag_id = dag.dag_id
         self._task_ids = [task.task_id for task in dag.tasks]
         self._full_filepath = dag.full_filepath
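For context (not part of the patch): with the simplified constructor above, a `SimpleDag` is now built directly from a `DAG` object instead of seven positional fields, which is what the `_make_simple_dag_bag` test helper below relies on. A minimal sketch; the dag and task ids are made up for illustration.

```python
# Illustrative sketch only -- SimpleDag now wraps a DAG object directly; pickle_id is optional.
from datetime import datetime

from airflow.models import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.dag_processing import SimpleDag, SimpleDagBag

dag = DAG(dag_id="example_dag", start_date=datetime(2017, 1, 1))
DummyOperator(task_id="example_task", dag=dag)

simple_dag = SimpleDag(dag)  # one-argument construction, pickle_id defaults to None
simple_dag_bag = SimpleDagBag([simple_dag])

assert "example_dag" in simple_dag_bag.dag_ids
assert simple_dag_bag.get_dag("example_dag") is simple_dag
```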
diff --git a/tests/executors/test_base_executor.py b/tests/executors/test_base_executor.py
new file mode 100644
index 0000000000000..fa6123a3d1bd5
--- /dev/null
+++ b/tests/executors/test_base_executor.py
@@ -0,0 +1,40 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow.executors.base_executor import BaseExecutor
+from airflow.utils.state import State
+
+from datetime import datetime
+
+
+class BaseExecutorTest(unittest.TestCase):
+    def test_get_event_buffer(self):
+        executor = BaseExecutor()
+
+        date = datetime.utcnow()
+
+        key1 = ("my_dag1", "my_task1", date)
+        key2 = ("my_dag2", "my_task1", date)
+        key3 = ("my_dag2", "my_task2", date)
+        state = State.SUCCESS
+        executor.event_buffer[key1] = state
+        executor.event_buffer[key2] = state
+        executor.event_buffer[key3] = state
+
+        self.assertEqual(len(executor.get_event_buffer(("my_dag1",))), 1)
+        self.assertEqual(len(executor.get_event_buffer()), 2)
+        self.assertEqual(len(executor.event_buffer), 0)
+
diff --git a/tests/jobs.py b/tests/jobs.py
index 88589d8dee2ff..119e1b4f473fe 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -935,6 +935,51 @@ def run_single_scheduler_loop_with_no_dags(dags_folder):
     def _make_simple_dag_bag(self, dags):
         return SimpleDagBag([SimpleDag(dag) for dag in dags])
 
+    def test_process_executor_events(self):
+        dag_id = "test_process_executor_events"
+        dag_id2 = "test_process_executor_events_2"
+        task_id_1 = 'dummy_task'
+
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
+        dag2 = DAG(dag_id=dag_id2, start_date=DEFAULT_DATE)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        task2 = DummyOperator(dag=dag2, task_id=task_id_1)
+
+        dagbag1 = self._make_simple_dag_bag([dag])
+        dagbag2 = self._make_simple_dag_bag([dag2])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        ti1 = TI(task1, DEFAULT_DATE)
+        ti1.state = State.QUEUED
+        session.merge(ti1)
+        session.commit()
+
+        executor = TestExecutor()
+        executor.event_buffer[ti1.key] = State.FAILED
+
+        scheduler.executor = executor
+
+        # dag bag does not contain dag_id
+        scheduler._process_executor_events(simple_dag_bag=dagbag2)
+        ti1.refresh_from_db()
+        self.assertEqual(ti1.state, State.QUEUED)
+
+        # dag bag does contain dag_id
+        scheduler._process_executor_events(simple_dag_bag=dagbag1)
+        ti1.refresh_from_db()
+        self.assertEqual(ti1.state, State.FAILED)
+
+        ti1.state = State.SUCCESS
+        session.merge(ti1)
+        session.commit()
+        executor.event_buffer[ti1.key] = State.SUCCESS
+
+        scheduler._process_executor_events(simple_dag_bag=dagbag1)
+        ti1.refresh_from_db()
+        self.assertEqual(ti1.state, State.SUCCESS)
+
     def test_execute_task_instances_is_paused_wont_execute(self):
         dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute'
         task_id_1 = 'dummy_task'
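A closing note, not part of the patch: the behaviour the new tests pin down is that executor events are only consumed for DAGs present in the current `SimpleDagBag` (via `get_event_buffer(simple_dag_bag.dag_ids)`), and that a task the executor reports as finished while the database still shows it as QUEUED is treated as externally killed. Below is a simplified, database-free restatement of that decision; the helper name `decide_new_state` is made up for illustration, and the real code goes through `ti.handle_failure(msg)` so callbacks and retries still apply.

```python
# Illustrative sketch only -- a DB-free restatement of the check added to
# SchedulerJob._process_executor_events above. The helper name is hypothetical;
# the real code loads the TaskInstance via SQLAlchemy and calls ti.handle_failure(msg).
from airflow.utils.state import State


def decide_new_state(reported_state, current_ti_state):
    """Return the state the scheduler should record, or None to leave the task alone."""
    if reported_state not in (State.FAILED, State.SUCCESS):
        return None
    if current_ti_state == State.QUEUED:
        # The executor says the task finished, but it never left QUEUED in the
        # database: assume it was killed externally and fail it.
        return State.FAILED
    return None


# The two scenarios exercised by test_process_executor_events:
assert decide_new_state(State.FAILED, State.QUEUED) == State.FAILED   # externally killed
assert decide_new_state(State.SUCCESS, State.SUCCESS) is None         # already consistent, untouched
```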