Fix race conditions in task callback invocations (#10917)
This race condition resulted in task success and failure callbacks being
called more than once. Here is the order of events that could lead to
this issue:

* the task started running within process 2
* (process 1) local_task_job checked for the task return code, which was
  still None
* (process 2) the task exited with a failure state; its state was updated
  to failed in the DB
* (process 2) the task failure callback was invoked through the
  taskinstance.handle_failure method
* (process 1) the local_task_job heartbeat noticed the task state was set
  to failed, mistook it for the state being updated externally, and invoked
  the task failure callback a second time

To avoid this race condition, we need to make sure task callbacks are
only invoked within a single process.
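
For context, the callbacks in question are the user-supplied
on_success_callback / on_failure_callback hooks attached to an operator.
A minimal sketch of a task wired up with one (the DAG and callback body
are illustrative, not part of this change):

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator


def notify_failure(context):
    # receives the task context; before this fix it could fire twice for
    # a single failed task instance
    print(f"task failed: {context['task_instance'].task_id}")


with DAG("callback_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
    PythonOperator(
        task_id="boom",
        python_callable=lambda: 1 / 0,  # raises, so on_failure_callback runs
        on_failure_callback=notify_failure,
    )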

(cherry picked from commit f1d4f54)
QP Hou authored and kaxil committed Jan 21, 2021
1 parent 1166937 commit efe163a
Showing 16 changed files with 343 additions and 136 deletions.
2 changes: 2 additions & 0 deletions airflow/cli/cli_parser.py
@@ -405,6 +405,7 @@ def positive_int(value):
("--ship-dag",), help="Pickles (serializes) the DAG and ships it to the worker", action="store_true"
)
ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
ARG_ERROR_FILE = Arg(("--error-file",), help="File to store task failure error")
ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
ARG_MIGRATION_TIMEOUT = Arg(
@@ -954,6 +955,7 @@ class GroupCommand(NamedTuple):
ARG_PICKLE,
ARG_JOB_ID,
ARG_INTERACTIVE,
ARG_ERROR_FILE,
ARG_SHUT_DOWN_LOGGING,
),
),
5 changes: 3 additions & 2 deletions airflow/cli/commands/task_command.py
@@ -47,7 +47,7 @@
from airflow.utils.session import create_session


def _run_task_by_selected_method(args, dag, ti):
def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
"""
Runs the task in one of 3 modes
@@ -132,7 +132,7 @@ def _run_task_by_local_task_job(args, ti):
]


def _run_raw_task(args, ti):
def _run_raw_task(args, ti: TaskInstance) -> None:
"""Runs the main task handling code"""
unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]

@@ -149,6 +149,7 @@ def _run_raw_task(args, ti):
mark_success=args.mark_success,
job_id=args.job_id,
pool=args.pool,
error_file=args.error_file,
)


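The new --error-file plumbing is what lets the supervising process recover
the failure that happened inside the task process. A hypothetical sketch of
that round trip, assuming a pickle-based format; the helper names are
illustrative, not the exact Airflow API:

import pickle


def write_error_file(path: str, error: BaseException) -> None:
    # task process: record why the run failed so the supervisor can hand
    # it to the failure callback later
    with open(path, "wb") as fd:
        pickle.dump(str(error), fd)


def read_error_file(path: str):
    # supervising process: None means nothing was recorded, e.g. the task
    # succeeded or was killed before it could write
    try:
        with open(path, "rb") as fd:
            return pickle.load(fd)
    except (OSError, EOFError):
        return None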
4 changes: 4 additions & 0 deletions airflow/executors/debug_executor.py
@@ -66,6 +66,7 @@ def sync(self) -> None:
self.log.info("Executor is terminated! Stopping %s to %s", ti.key, State.FAILED)
ti.set_state(State.FAILED)
self.change_state(ti.key, State.FAILED)
ti._run_finished_callback() # pylint: disable=protected-access
continue

task_succeeded = self._run_task(ti)
@@ -77,9 +78,12 @@ def _run_task(self, ti: TaskInstance) -> bool:
params = self.tasks_params.pop(ti.key, {})
ti._run_raw_task(job_id=ti.job_id, **params) # pylint: disable=protected-access
self.change_state(key, State.SUCCESS)
ti._run_finished_callback() # pylint: disable=protected-access
return True
except Exception as e: # pylint: disable=broad-except
ti.set_state(State.FAILED)
self.change_state(key, State.FAILED)
ti._run_finished_callback() # pylint: disable=protected-access
self.log.exception("Failed to execute task: %s.", str(e))
return False

2 changes: 1 addition & 1 deletion airflow/jobs/backfill_job.py
@@ -280,7 +280,7 @@ def _manage_executor_state(self, running):
"killed externally? Info: {}".format(ti, state, ti.state, info)
)
self.log.error(msg)
ti.handle_failure(msg)
ti.handle_failure_with_callback(error=msg)

@provide_session
def _get_dag_run(self, run_date: datetime, dag: DAG, session: Session = None):
36 changes: 28 additions & 8 deletions airflow/jobs/local_task_job.py
@@ -98,7 +98,14 @@ def signal_handler(signum, frame):

heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

while True:
# task callback invocation happens either here or in
# self.heartbeat() instead of taskinstance._run_raw_task to
# avoid race conditions
#
# When self.terminating is set to True by heartbeat_callback, this
# loop should not be restarted. Otherwise self.handle_task_exit
# will be invoked and we will end up with duplicated callbacks
while not self.terminating:
# Monitor the task to see if it's done. Wait in a syscall
# (`os.wait`) for as long as possible so we notice the
# subprocess finishing as quick as we can
@@ -115,7 +122,7 @@ def signal_handler(signum, frame):

return_code = self.task_runner.return_code(timeout=max_wait_time)
if return_code is not None:
self.log.info("Task exited with return code %s", return_code)
self.handle_task_exit(return_code)
return

self.heartbeat()
@@ -134,6 +141,17 @@ def signal_handler(signum, frame):
finally:
self.on_kill()

def handle_task_exit(self, return_code: int) -> None:
"""Handle case where self.task_runner exits by itself"""
self.log.info("Task exited with return code %s", return_code)
self.task_instance.refresh_from_db()
# the task exited by itself, so we need to check the error file
# in case it failed due to a runtime exception/error
error = None
if self.task_instance.state != State.SUCCESS:
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error) # pylint: disable=protected-access

def on_kill(self):
self.task_runner.terminate()
self.task_runner.on_finish()
@@ -169,11 +187,13 @@ def heartbeat_callback(self, session=None):
self.log.warning(
"State of this instance has been externally set to %s. " "Terminating instance.", ti.state
)
if ti.state == State.FAILED and ti.task.on_failure_callback:
context = ti.get_template_context()
ti.task.on_failure_callback(context)
if ti.state == State.SUCCESS and ti.task.on_success_callback:
context = ti.get_template_context()
ti.task.on_success_callback(context)
self.task_runner.terminate()
if ti.state == State.SUCCESS:
error = None
else:
# if ti.state was not set by taskinstance.handle_failure, the
# error file will not be populated and the state must have been
# updated by an external source such as the web UI
error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
ti._run_finished_callback(error=error) # pylint: disable=protected-access
self.terminating = True
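
Taken together, the two hunks above give _run_finished_callback exactly one
owner per run: either handle_task_exit (the child exited on its own) or
heartbeat_callback (the state was changed externally), with self.terminating
guaranteeing the monitor loop can never reach the first path after the
second has fired. A condensed sketch of that control flow (not the full
Airflow source; task_runner and task_instance are assumed to be set up by
the job's constructor):

from airflow.utils.state import State


class LocalTaskJobSketch:
    def _execute(self):
        # heartbeat_callback may flip self.terminating; once it does, the
        # loop exits without calling handle_task_exit, so at most one of
        # the two paths below runs the finished callback
        while not self.terminating:
            return_code = self.task_runner.return_code(timeout=1)
            if return_code is not None:
                self.handle_task_exit(return_code)  # path 1: child exited
                return
            self.heartbeat()  # may invoke heartbeat_callback

    def handle_task_exit(self, return_code):
        self.task_instance.refresh_from_db()
        error = None
        if self.task_instance.state != State.SUCCESS:
            error = self.task_runner.deserialize_run_error()
        self.task_instance._run_finished_callback(error=error)

    def heartbeat_callback(self):
        # path 2: state changed externally, e.g. marked failed in the UI
        self.task_runner.terminate()
        error = None
        if self.task_instance.state != State.SUCCESS:
            error = (self.task_runner.deserialize_run_error()
                     or "task marked as failed externally")
        self.task_instance._run_finished_callback(error=error)
        self.terminating = True  # stop the monitor loop; path 1 never runs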
6 changes: 3 additions & 3 deletions airflow/jobs/scheduler_job.py
@@ -589,7 +589,7 @@ def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
ti.state = simple_ti.state
ti.test_mode = self.UNIT_TEST_MODE
if request.is_failure_callback:
ti.handle_failure(request.msg, ti.test_mode, ti.get_template_context())
ti.handle_failure_with_callback(error=request.msg, test_mode=ti.test_mode)
self.log.info('Executed failure callback for %s in state %s', ti, ti.state)

@provide_session
@@ -1731,8 +1731,8 @@ def _emit_pool_metrics(self, session: Session = None) -> None:
pools = models.Pool.slots_stats(session=session)
for pool_name, slot_stats in pools.items():
Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED])
Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING])
Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats[State.QUEUED]) # type: ignore
Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats[State.RUNNING]) # type: ignore

@provide_session
def heartbeat_callback(self, session: Session = None) -> None:
2 changes: 1 addition & 1 deletion airflow/models/dag.py
@@ -584,7 +584,7 @@ def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.D
next_run_date = None
if not date_last_automated_dagrun:
# First run
task_start_dates = [t.start_date for t in self.tasks]
task_start_dates = [t.start_date for t in self.tasks if t.start_date]
if task_start_dates:
next_run_date = self.normalize_schedule(min(task_start_dates))
self.log.debug("Next run date based on tasks %s", next_run_date)
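
Incidentally, the added start_date filter in dag.py matters because in
Python 3 None does not compare with datetimes, so min() over start dates
raises whenever a task lacks one. A quick illustration:

from datetime import datetime

task_start_dates = [datetime(2021, 1, 1), None]
# min(task_start_dates)  # TypeError: '<' not supported between instances
#                        # of 'NoneType' and 'datetime.datetime'
min(d for d in task_start_dates if d)  # -> datetime(2021, 1, 1)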