Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIRFLOW-3591] Fix start date, end date, duration for rescheduled tasks #4502

Merged
merged 1 commit into from
Jan 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions airflow/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,13 @@ def _check_and_change_state_before_execution(
msg = "Starting attempt {attempt} of {total}".format(
attempt=self.try_number,
total=self.max_tries + 1)

# Set the task start date. In case it was re-scheduled use the initial
# start date that is recorded in task_reschedule table
self.start_date = timezone.utcnow()
task_reschedules = TaskReschedule.find_for_task_instance(self, session)
if task_reschedules:
self.start_date = task_reschedules[0].start_date

dep_context = DepContext(
deps=RUN_DEPS - QUEUE_DEPS,
Expand Down Expand Up @@ -1362,6 +1368,7 @@ def _run_raw_task(
self.operator = task.__class__.__name__

context = {}
actual_start_date = timezone.utcnow()
try:
if not mark_success:
context = self.get_template_context()
Expand Down Expand Up @@ -1411,7 +1418,7 @@ def signal_handler(signum, frame):
self.state = State.SKIPPED
except AirflowRescheduleException as reschedule_exception:
self.refresh_from_db()
self._handle_reschedule(reschedule_exception, test_mode, context)
self._handle_reschedule(actual_start_date, reschedule_exception, test_mode, context)
return
except AirflowException as e:
self.refresh_from_db()
Expand Down Expand Up @@ -1483,7 +1490,7 @@ def dry_run(self):
task_copy.dry_run()

@provide_session
def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None,
def _handle_reschedule(self, actual_start_date, reschedule_exception, test_mode=False, context=None,
session=None):
# Don't record reschedule request in test mode
if test_mode:
Expand All @@ -1494,7 +1501,7 @@ def _handle_reschedule(self, reschedule_exception, test_mode=False, context=None

# Log reschedule request
session.add(TaskReschedule(self.task, self.execution_date, self._try_number,
self.start_date, self.end_date,
actual_start_date, self.end_date,
reschedule_exception.reschedule_date))

# set state
Expand Down
30 changes: 6 additions & 24 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,33 +1728,15 @@ def gantt(self, session=None):
TF.execution_date == ti.execution_date)
.all()
) for ti in tis]))
TR = models.TaskReschedule
ti_reschedules = list(itertools.chain(*[(
session
.query(TR)
.filter(TR.dag_id == ti.dag_id,
TR.task_id == ti.task_id,
TR.execution_date == ti.execution_date)
.all()
) for ti in tis]))

# determine bars to show in the gantt chart
# all reschedules of one attempt are combined into one bar
gantt_bar_items = []
for task_id, items in itertools.groupby(
sorted(tis + ti_fails + ti_reschedules, key=lambda ti: ti.task_id),
key=lambda ti: ti.task_id):
start_date = None
for i in sorted(items, key=lambda ti: ti.start_date):
start_date = start_date or i.start_date
end_date = i.end_date or timezone.utcnow()
if type(i) == models.TaskInstance:
gantt_bar_items.append((task_id, start_date, end_date, i.state))
start_date = None
elif type(i) == TF and (len(gantt_bar_items) == 0 or
end_date != gantt_bar_items[-1][2]):
gantt_bar_items.append((task_id, start_date, end_date, State.FAILED))
start_date = None
for ti in tis:
end_date = ti.end_date or timezone.utcnow()
gantt_bar_items.append((ti.task_id, ti.start_date, end_date, ti.state))
for tf in ti_fails:
end_date = tf.end_date or timezone.utcnow()
gantt_bar_items.append((tf.task_id, tf.start_date, end_date, State.FAILED))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lol so much easier :D


tasks = []
for gantt_bar_item in gantt_bar_items:
Expand Down
105 changes: 105 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,18 @@
import six
from mock import ANY, Mock, mock_open, patch
from parameterized import parameterized
from freezegun import freeze_time

from airflow import AirflowException, configuration, models, settings
from airflow.contrib.sensors.python_sensor import PythonSensor
from airflow.exceptions import AirflowDagCycleException, AirflowSkipException
from airflow.jobs import BackfillJob
from airflow.models import DAG, TaskInstance as TI
from airflow.models import DagModel, DagRun
from airflow.models import KubeResourceVersion, KubeWorkerIdentifier
from airflow.models import SkipMixin
from airflow.models import State as ST
from airflow.models import TaskReschedule as TR
from airflow.models import XCom
from airflow.models import clear_task_instances
from airflow.models.connection import Connection
Expand Down Expand Up @@ -1821,6 +1824,12 @@ def test_deactivate_unknown_dags(self):

class TaskInstanceTest(unittest.TestCase):

def tearDown(self):
# Delete all TaskFail, TaskReschedule and TaskInstance rows so that
# records created by one test are not visible to the next one.
with create_session() as session:
session.query(models.TaskFail).delete()
session.query(models.TaskReschedule).delete()
session.query(models.TaskInstance).delete()

def test_set_task_dates(self):
"""
Test that tasks properly take start/end dates from DAGs
Expand Down Expand Up @@ -2191,6 +2200,102 @@ def test_next_retry_datetime(self):
dt = ti.next_retry_datetime()
self.assertEqual(dt, ti.end_date + max_delay)

@patch.object(TI, 'pool_full')
def test_reschedule_handling(self, mock_pool_full):
"""
Test that task reschedules are handled properly.

A PythonSensor in 'reschedule' mode is run several times at frozen points
in time. After each run the task instance's state, try number, start date,
end date and duration are checked, together with the number of
TaskReschedule rows found for the task instance.
"""
# Mock the pool with a pool with slots open since the pool doesn't actually exist
mock_pool_full.return_value = False

# Return values of the python sensor callable, modified during tests
done = False
fail = False

# NOTE(review): this local function shadows the builtin callable().
# It reads `done` and `fail` from the enclosing scope at call time, so
# rebinding them below changes what the sensor returns or raises.
def callable():
if fail:
raise AirflowException()
return done

dag = models.DAG(dag_id='test_reschedule_handling')
task = PythonSensor(
task_id='test_reschedule_handling_sensor',
poke_interval=0,
mode='reschedule',
python_callable=callable,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0))

ti = TI(task=task, execution_date=timezone.utcnow())
self.assertEqual(ti._try_number, 0)
self.assertEqual(ti.try_number, 1)

# Helper: run the task instance with the clock frozen at `run_date`, then
# compare the refreshed task instance against the expected values.
# An AirflowException from ti.run() is tolerated only while `fail` is set.
def run_ti_and_assert(run_date, expected_start_date, expected_end_date, expected_duration,
expected_state, expected_try_number, expected_task_reschedule_count):
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
self.assertEqual(ti.state, expected_state)
self.assertEqual(ti._try_number, expected_try_number)
self.assertEqual(ti.try_number, expected_try_number + 1)
self.assertEqual(ti.start_date, expected_start_date)
self.assertEqual(ti.end_date, expected_end_date)
self.assertEqual(ti.duration, expected_duration)
trs = TR.find_for_task_instance(ti)
self.assertEqual(len(trs), expected_task_reschedule_count)

# Four points in time, one minute apart, used as frozen run dates.
date1 = timezone.utcnow()
date2 = date1 + datetime.timedelta(minutes=1)
date3 = date2 + datetime.timedelta(minutes=1)
date4 = date3 + datetime.timedelta(minutes=1)

# Run with multiple reschedules.
# During reschedule the try number remains the same, but each reschedule is recorded.
# The start date is expected to remain the initial date, hence the duration increases.
# When finished the try number is incremented and there is no reschedule expected
# for this try.

done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

done, fail = False, False
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)

done, fail = False, False
run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)

done, fail = True, False
run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)

# Clear the task instance.
dag.clear()
ti.refresh_from_db()
self.assertEqual(ti.state, State.NONE)
self.assertEqual(ti._try_number, 1)

# Run again after clearing with reschedules and a retry.
# The retry increments the try number, and for that try no reschedule is expected.
# After the retry the start date is reset, hence the duration is also reset.

done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)

done, fail = False, True
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)

done, fail = False, False
run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)

done, fail = True, False
run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)

def test_depends_on_past(self):
dagbag = models.DagBag()
dag = dagbag.get_dag('test_depends_on_past')
Expand Down
6 changes: 6 additions & 0 deletions tests/sensors/test_base_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ def test_ok_with_reschedule(self):
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
self.assertEquals(ti.state, State.UP_FOR_RESCHEDULE)
# verify task start date is the initial one
self.assertEquals(ti.start_date, date1)
# verify one row in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEquals(len(task_reschedules), 1)
Expand All @@ -202,6 +204,8 @@ def test_ok_with_reschedule(self):
if ti.task_id == SENSOR_OP:
# verify task is re-scheduled, i.e. state set to UP_FOR_RESCHEDULE
self.assertEquals(ti.state, State.UP_FOR_RESCHEDULE)
# verify task start date is the initial one
self.assertEquals(ti.start_date, date1)
# verify two rows in task_reschedule table
task_reschedules = TaskReschedule.find_for_task_instance(ti)
self.assertEquals(len(task_reschedules), 2)
Expand All @@ -220,6 +224,8 @@ def test_ok_with_reschedule(self):
for ti in tis:
if ti.task_id == SENSOR_OP:
self.assertEquals(ti.state, State.SUCCESS)
# verify task start date is the initial one
self.assertEquals(ti.start_date, date1)
if ti.task_id == DUMMY_OP:
self.assertEquals(ti.state, State.NONE)

Expand Down