[AIRFLOW-56] Airflow's scheduler can "lose" queued tasks #1378

Merged (2 commits) on May 9, 2016
33 changes: 24 additions & 9 deletions airflow/executors/base_executor.py
@@ -30,10 +30,11 @@ def start(self):  # pragma: no cover
"""
pass

def queue_command(self, key, command, priority=1, queue=None):
def queue_command(self, task_instance, command, priority=1, queue=None):
key = task_instance.key
if key not in self.queued_tasks and key not in self.running:
self.logger.info("Adding to queue: {}".format(command))
self.queued_tasks[key] = (command, priority, queue)
self.queued_tasks[key] = (command, priority, queue, task_instance)

     def queue_task_instance(
             self,
@@ -54,7 +55,7 @@ def queue_task_instance(
             pool=pool,
             pickle_id=pickle_id)
         self.queue_command(
-            task_instance.key,
+            task_instance,
             command,
             priority=task_instance.task.priority_weight_total,
             queue=task_instance.task.queue)
@@ -67,9 +68,6 @@ def sync(self):
         pass

     def heartbeat(self):
-        # Calling child class sync method
-        self.logger.debug("Calling the {} sync method".format(self.__class__))
-        self.sync()

         # Triggering new jobs
         if not self.parallelism:
@@ -86,10 +84,27 @@ def heartbeat(self):
             key=lambda x: x[1][1],
             reverse=True)
         for i in range(min((open_slots, len(self.queued_tasks)))):
-            key, (command, priority, queue) = sorted_queue.pop(0)
-            self.running[key] = command
+            key, (command, _, queue, ti) = sorted_queue.pop(0)
+            # TODO(jlowin) without a way to know what Job ran which tasks,
+            # there is a danger that another Job started running a task
+            # that was also queued to this executor. This is the last chance
+            # to check if that happened. The most probable way is that a
+            # Scheduler tried to run a task that was originally queued by a
+            # Backfill. This fix reduces the probability of a collision but
+            # does NOT eliminate it.
             self.queued_tasks.pop(key)
-            self.execute_async(key, command=command, queue=queue)
+            ti.refresh_from_db()
+            if ti.state != State.RUNNING:

Review comment (Contributor): Nit: Should log if the state is running.

+                self.running[key] = command
+                self.execute_async(key, command=command, queue=queue)
+            else:
+                self.logger.debug(
+                    'Task is already running, not sending to '
+                    'executor: {}'.format(key))
+
+        # Calling child class sync method
+        self.logger.debug("Calling the {} sync method".format(self.__class__))
+        self.sync()

     def change_state(self, key, state):
         self.running.pop(key)
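The reworked heartbeat() is a check-then-act guard: the executor re-reads each task instance's state from the database immediately before dispatch and drops the command if another job already set it to RUNNING. Below is a minimal sketch of that pattern; SimpleTaskInstance and drain_queue are illustrative stand-ins, not Airflow's API, and as the TODO in the diff notes, the state can still change between the refresh and the dispatch, so the race is narrowed rather than eliminated.

```python
# Minimal sketch of the double-check heartbeat() now performs. The names
# (SimpleTaskInstance, drain_queue, states as plain strings) are
# illustrative stand-ins, not Airflow's API.
QUEUED, RUNNING = "queued", "running"

class SimpleTaskInstance:
    def __init__(self, key, state=QUEUED):
        self.key = key
        self.state = state

    def refresh_from_db(self):
        # Airflow re-reads the row here; this stub keeps state in memory.
        pass

def drain_queue(queued_tasks, running, execute_async):
    """Send queued commands, skipping tasks another job already started."""
    for key in sorted(queued_tasks, reverse=True):
        command, priority, queue, ti = queued_tasks.pop(key)
        ti.refresh_from_db()  # last-chance check against a concurrent job
        if ti.state != RUNNING:
            running[key] = command
            execute_async(key, command, queue)
        # else: drop the duplicate; the other job owns this task now.
        # NOTE: the state can still flip between refresh_from_db() and
        # execute_async(), which is why the TODO says the collision is
        # made less likely, not impossible.
```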
1 change: 1 addition & 0 deletions airflow/executors/celery_executor.py
@@ -95,3 +95,4 @@ def end(self, synchronous=False):
                     async.state not in celery_states.READY_STATES
                     for async in self.tasks.values()]):
                 time.sleep(5)
+        self.sync()
2 changes: 1 addition & 1 deletion airflow/executors/local_executor.py
@@ -73,4 +73,4 @@ def end(self):
         [self.queue.put((None, None)) for w in self.workers]
         # Wait for commands to finish
         self.queue.join()
-
+        self.sync()
46 changes: 9 additions & 37 deletions airflow/jobs.py
@@ -238,7 +238,6 @@ def __init__(

         self.refresh_dags_every = refresh_dags_every
         self.do_pickle = do_pickle
-        self.queued_tis = set()
         super(SchedulerJob, self).__init__(*args, **kwargs)

         self.heartrate = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
@@ -567,47 +566,22 @@ def process_dag(self, dag, queue):

         session.close()

-    def process_events(self, executor, dagbag):
-        """
-        Respond to executor events.
-
-        Used to identify queued tasks and schedule them for further processing.
-        """
-        for key, executor_state in list(executor.get_event_buffer().items()):
-            dag_id, task_id, execution_date = key
-            if dag_id not in dagbag.dags:
-                self.logger.error(
-                    'Executor reported a dag_id that was not found in the '
-                    'DagBag: {}'.format(dag_id))
-                continue
-            elif not dagbag.dags[dag_id].has_task(task_id):
-                self.logger.error(
-                    'Executor reported a task_id that was not found in the '
-                    'dag: {} in dag {}'.format(task_id, dag_id))
-                continue
-            task = dagbag.dags[dag_id].get_task(task_id)
-            ti = models.TaskInstance(task, execution_date)
-            ti.refresh_from_db()
-
-            if executor_state == State.SUCCESS:
-                # collect queued tasks for prioritiztion
-                if ti.state == State.QUEUED:
-                    self.queued_tis.add(ti)
-            else:
-                # special instructions for failed executions could go here
-                pass

     @provide_session
     def prioritize_queued(self, session, executor, dagbag):
         # Prioritizing queued task instances

         pools = {p.pool: p for p in session.query(models.Pool).all()}

+        TI = models.TaskInstance
+        queued_tis = (
+            session.query(TI)
+            .filter(TI.state == State.QUEUED)
+            .all()
+        )
         self.logger.info(
-            "Prioritizing {} queued jobs".format(len(self.queued_tis)))
+            "Prioritizing {} queued jobs".format(len(queued_tis)))
         session.expunge_all()
         d = defaultdict(list)
-        for ti in self.queued_tis:
+        for ti in queued_tis:
             if ti.dag_id not in dagbag.dags:
                 self.logger.info(
                     "DAG no longer in dagbag, deleting {}".format(ti))
@@ -621,8 +595,6 @@ def prioritize_queued(self, session, executor, dagbag):
             else:
                 d[ti.pool].append(ti)

-        self.queued_tis.clear()
-
         dag_blacklist = set(dagbag.paused_dags())
         for pool, tis in list(d.items()):
             if not pool:
@@ -676,6 +648,7 @@ def prioritize_queued(self, session, executor, dagbag):
                     open_slots -= 1
                 else:
                     session.delete(ti)
+                    session.commit()
                     continue
                 ti.task = task

@@ -721,7 +694,6 @@ def _execute(self):
             try:
                 loop_start_dttm = datetime.now()
                 try:
-                    self.process_events(executor=executor, dagbag=dagbag)
                     self.prioritize_queued(executor=executor, dagbag=dagbag)
                 except Exception as e:
                     self.logger.exception(e)
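The net effect in jobs.py: the scheduler no longer accumulates queued task instances in an in-memory set fed by executor events (a set that vanished whenever the scheduler restarted, which is one way tasks got "lost"); it re-derives them from the database on every loop. A hedged sketch of that re-derivation, with the session and TaskInstance model assumed to look like Airflow's, grouped by pool the way prioritize_queued does:

```python
# Sketch of the recovery idea behind the jobs.py change: queued work is
# re-read from the database each scheduler loop instead of being held in
# scheduler memory, so a restarted scheduler cannot lose it. The session
# and TaskInstance arguments are assumed to behave like Airflow's.
from collections import defaultdict

def queued_tis_by_pool(session, TaskInstance, State):
    queued = (
        session.query(TaskInstance)
        .filter(TaskInstance.state == State.QUEUED)
        .all()
    )
    by_pool = defaultdict(list)
    for ti in queued:
        by_pool[ti.pool].append(ti)  # bucket by pool for slot accounting
    return by_pool
```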
19 changes: 14 additions & 5 deletions airflow/models.py
@@ -817,16 +817,25 @@ def error(self, session=None):
         session.commit()

     @provide_session
-    def refresh_from_db(self, session=None):
+    def refresh_from_db(self, session=None, lock_for_update=False):

Review comment (Contributor): document lock_for_update

"""
Refreshes the task instance from the database based on the primary key

:param lock_for_update: if True, indicates that the database should
lock the TaskInstance (issuing a FOR UPDATE clause) until the session
is committed.
"""
         TI = TaskInstance
-        ti = session.query(TI).filter(
+
+        qry = session.query(TI).filter(
             TI.dag_id == self.dag_id,
             TI.task_id == self.task_id,
-            TI.execution_date == self.execution_date,
-        ).first()
+            TI.execution_date == self.execution_date)
+
+        if lock_for_update:
+            ti = qry.with_for_update().first()
+        else:
+            ti = qry.first()
         if ti:
             self.state = ti.state
             self.start_date = ti.start_date
@@ -1159,7 +1168,7 @@ def run(
         self.pool = pool or task.pool
         self.test_mode = test_mode
         self.force = force
-        self.refresh_from_db()
+        self.refresh_from_db(session=session, lock_for_update=True)
         self.clear_xcom_data()
         self.job_id = job_id
         iso = datetime.now().isoformat()
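with_for_update() is SQLAlchemy's SELECT ... FOR UPDATE, so lock_for_update=True makes the refresh hold the row until the session commits. A toy sketch of why that serializes two concurrent claimers; the Task model and connection URL are invented for the example, and only the query/session calls are real SQLAlchemy:

```python
# Toy illustration of what lock_for_update=True buys: the first transaction
# to reach the row holds it until commit, so a second concurrent claim()
# blocks, then re-reads the updated state and backs off. The Task model and
# connection URL are made up for this sketch.
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Task(Base):
    __tablename__ = 'task'
    key = Column(String, primary_key=True)
    state = Column(String)

# FOR UPDATE needs a database that supports row locks (not SQLite).
engine = create_engine('postgresql:///scratch')
Session = sessionmaker(bind=engine)

def claim(key):
    session = Session()
    task = (session.query(Task)
            .filter(Task.key == key)
            .with_for_update()   # SELECT ... FOR UPDATE: row is now locked
            .first())
    if task is not None and task.state == 'queued':
        task.state = 'running'
        session.commit()         # commit releases the lock
        return True
    session.rollback()           # lost the race: someone else claimed it
    return False
```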
16 changes: 15 additions & 1 deletion tests/dags/test_issue_1225.py
@@ -23,6 +23,8 @@
 from airflow.models import DAG
 from airflow.operators import DummyOperator, PythonOperator, SubDagOperator
 from airflow.utils.trigger_rule import TriggerRule
+import time

 DEFAULT_DATE = datetime(2016, 1, 1)
 default_args = dict(
     start_date=DEFAULT_DATE,
@@ -31,6 +33,16 @@
 def fail():
     raise ValueError('Expected failure.')

+def delayed_fail():

Review comment (aoen, Contributor, May 9, 2016): Sleeping in tests is an antipattern; usually you want to use concurrency mechanisms instead. Can you put a TODO documenting this so we can fix it later?

"""
Delayed failure to make sure that processes are running before the error
is raised.

TODO handle more directly (without sleeping)
"""
time.sleep(5)
raise ValueError('Expected failure.')

 # DAG tests backfill with pooled tasks
 # Previously backfill would queue the task but never run it
 dag1 = DAG(dag_id='test_backfill_pooled_task_dag', default_args=default_args)
@@ -123,7 +135,9 @@ def fail():
     end_date=DEFAULT_DATE,
     default_args=default_args)
 dag8_task1 = PythonOperator(
-    python_callable=fail,
+    # use delayed_fail because otherwise LocalExecutor will have a chance to
+    # complete the task
+    python_callable=delayed_fail,
     task_id='test_queued_task',
     dag=dag8,
     pool='test_queued_pool')
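Following the review comment above, a sleep-free version of the helper would block on an explicit signal instead of guessing how long startup takes. A sketch, under a strong assumption: threading.Event only works when signaller and waiter share a process, which Airflow tasks under LocalExecutor do not, so a real fix would need a cross-process primitive (a sentinel file, multiprocessing.Event, or similar).

```python
# Hedged sketch of the TODO: replace blind sleeping with an explicit signal.
# Assumes the signalling test and the task share a process; under
# LocalExecutor they do not, so treat this as the shape of the fix only.
import threading

ready = threading.Event()

def delayed_fail_on_signal():
    # Block until the test signals that its preconditions hold, instead of
    # guessing with time.sleep(5); fail loudly if the signal never comes.
    if not ready.wait(timeout=30):
        raise RuntimeError('test never signalled readiness')
    raise ValueError('Expected failure.')

# Test side, once the processes it cares about are running:
# ready.set()
```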
20 changes: 16 additions & 4 deletions tests/jobs.py
@@ -23,11 +23,13 @@

 from airflow import AirflowException, settings
 from airflow.bin import cli
+from airflow.executors import DEFAULT_EXECUTOR
 from airflow.jobs import BackfillJob, SchedulerJob
-from airflow.models import DagBag, DagRun, Pool, TaskInstance as TI
+from airflow.models import DAG, DagBag, DagRun, Pool, TaskInstance as TI
 from airflow.operators import DummyOperator
-from airflow.utils.db import provide_session
 from airflow.utils.state import State
+from airflow.utils.timeout import timeout
+from airflow.utils.db import provide_session

 from airflow import configuration
 configuration.test_mode()
@@ -283,15 +285,25 @@ def test_scheduler_pooled_tasks(self):
         dag = self.dagbag.get_dag(dag_id)
         dag.clear()

-        scheduler = SchedulerJob(dag_id, num_runs=10)
+        scheduler = SchedulerJob(dag_id, num_runs=1)
         scheduler.run()

         task_1 = dag.tasks[0]
         logging.info("Trying to find task {}".format(task_1))
         ti = TI(task_1, dag.start_date)
         ti.refresh_from_db()
-        self.assertEqual(ti.state, State.FAILED)
+        self.assertEqual(ti.state, State.QUEUED)
+
+        # now we use a DIFFERENT scheduler and executor
+        # to simulate the num-runs CLI arg
+        scheduler2 = SchedulerJob(
+            dag_id,
+            num_runs=5,
+            executor=DEFAULT_EXECUTOR.__class__())
+        scheduler2.run()
+
+        ti.refresh_from_db()
+        self.assertEqual(ti.state, State.FAILED)
         dag.clear()

     def test_dagrun_deadlock_ignore_depends_on_past_advance_ex_date(self):