From a3bacdbb4cc67456d509bbdbd562f1cf890f41fc Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Sun, 2 Jul 2023 15:35:22 +0200 Subject: [PATCH 1/4] Order triggers by - TI priority_weight when assign unassigned triggers Signed-off-by: Hussein Awala --- airflow/models/trigger.py | 3 +- tests/models/test_trigger.py | 88 ++++++++++++++++++++++++++++++++---- 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index c0d749eb5905e..f2a653075a4dd 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -241,8 +241,9 @@ def assign_unassigned(cls, triggerer_id, capacity, heartrate, session: Session = def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session): return with_row_locks( session.query(cls.id) + .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) .filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids))) - .order_by(cls.created_date) + .order_by(-TaskInstance.priority_weight, cls.created_date) .limit(capacity), session, skip_locked=True, diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 0c5e5eeadcb8f..3e65aeb5af22b 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -243,31 +243,101 @@ def test_assign_unassigned_missing_heartbeat(session, create_task_instance, chec second_triggerer.latest_heartbeat += datetime.timedelta(seconds=check_triggerer_heartrate) -def test_get_sorted_triggers(session, create_task_instance): +def test_get_sorted_triggers_same_priority_weight(session, create_task_instance): """ - Tests that triggers are sorted by the creation_date. + Tests that triggers are sorted by the creation_date if they have the same priority. 
""" + old_execution_date = datetime.datetime( + 2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") + ) trigger_old = Trigger( classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, - created_date=datetime.datetime( - 2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") - ), + created_date=old_execution_date + datetime.timedelta(seconds=30), ) trigger_old.id = 1 + session.add(trigger_old) + TI_old = create_task_instance( + task_id="old", + execution_date=old_execution_date, + run_id="old_run_id", + ) + TI_old.priority_weight = 1 + TI_old.trigger_id = trigger_old.id + session.add(TI_old) + + new_execution_date = datetime.datetime( + 2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") + ) trigger_new = Trigger( classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}, - created_date=datetime.datetime( - 2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") - ), + created_date=new_execution_date + datetime.timedelta(seconds=30), ) trigger_new.id = 2 - session.add(trigger_old) session.add(trigger_new) + TI_new = create_task_instance( + task_id="new", + execution_date=new_execution_date, + run_id="new_run_id", + ) + TI_new.priority_weight = 1 + TI_new.trigger_id = trigger_new.id + session.add(TI_new) + session.commit() assert session.query(Trigger).count() == 2 trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session) assert trigger_ids_query == [(1,), (2,)] + + +def test_get_sorted_triggers_different_priority_weights(session, create_task_instance): + """ + Tests that triggers are sorted by the priority_weight. 
+ """ + old_execution_date = datetime.datetime( + 2023, 5, 9, 12, 16, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") + ) + trigger_old = Trigger( + classpath="airflow.triggers.testing.SuccessTrigger", + kwargs={}, + created_date=old_execution_date + datetime.timedelta(seconds=30), + ) + trigger_old.id = 1 + session.add(trigger_old) + TI_old = create_task_instance( + task_id="old", + execution_date=old_execution_date, + run_id="old_run_id", + ) + TI_old.priority_weight = 1 + TI_old.trigger_id = trigger_old.id + session.add(TI_old) + + new_execution_date = datetime.datetime( + 2023, 5, 9, 12, 17, 14, 474415, tzinfo=pytz.timezone("Africa/Abidjan") + ) + trigger_new = Trigger( + classpath="airflow.triggers.testing.SuccessTrigger", + kwargs={}, + created_date=new_execution_date + datetime.timedelta(seconds=30), + ) + trigger_new.id = 2 + session.add(trigger_new) + TI_new = create_task_instance( + task_id="new", + execution_date=new_execution_date, + run_id="new_run_id", + ) + TI_new.priority_weight = 2 + TI_new.trigger_id = trigger_new.id + session.add(TI_new) + + session.commit() + assert session.query(Trigger).count() == 2 + + trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session) + + assert trigger_ids_query == [(2,), (1,)] From c266d1cf239937200db9da3b6971603468fdd308 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Mon, 10 Jul 2023 21:24:48 +0200 Subject: [PATCH 2/4] Update airflow/models/trigger.py Co-authored-by: Tzu-ping Chung --- airflow/models/trigger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index f2a653075a4dd..ad505ec80f43d 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -243,7 +243,7 @@ def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session): session.query(cls.id) .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) .filter(or_(cls.triggerer_id.is_(None), 
cls.triggerer_id.notin_(alive_triggerer_ids))) - .order_by(-TaskInstance.priority_weight, cls.created_date) + .order_by(TaskInstance.priority_weight.desc(), cls.created_date) .limit(capacity), session, skip_locked=True, From b0fe3b82bf7388418532238a3a7d1619ca4fcd64 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 11 Aug 2023 00:34:06 +0200 Subject: [PATCH 3/4] Replace outer join by inner join and use coalesce to handle None values --- airflow/models/trigger.py | 5 +++-- tests/models/test_trigger.py | 36 ++++++++++++++++++++++++++++++++---- 2 files changed, 35 insertions(+), 6 deletions(-) diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py index 9969145da05f2..79bd373aad14b 100644 --- a/airflow/models/trigger.py +++ b/airflow/models/trigger.py @@ -22,6 +22,7 @@ from sqlalchemy import Column, Integer, String, delete, func, or_, select, update from sqlalchemy.orm import Session, joinedload, relationship +from sqlalchemy.sql.functions import coalesce from airflow.api_internal.internal_api_call import internal_api_call from airflow.models.base import Base @@ -244,9 +245,9 @@ def assign_unassigned( def get_sorted_triggers(cls, capacity, alive_triggerer_ids, session): query = with_row_locks( select(cls.id) - .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=True) + .join(TaskInstance, cls.id == TaskInstance.trigger_id, isouter=False) .where(or_(cls.triggerer_id.is_(None), cls.triggerer_id.not_in(alive_triggerer_ids))) - .order_by(TaskInstance.priority_weight.desc(), cls.created_date) + .order_by(coalesce(TaskInstance.priority_weight, 0).desc(), cls.created_date) .limit(capacity), session, skip_locked=True, diff --git a/tests/models/test_trigger.py b/tests/models/test_trigger.py index 5fc5b8190f282..3626c946707c9 100644 --- a/tests/models/test_trigger.py +++ b/tests/models/test_trigger.py @@ -171,19 +171,47 @@ def test_assign_unassigned(session, create_task_instance): trigger_on_healthy_triggerer = 
Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) trigger_on_healthy_triggerer.id = 1 trigger_on_healthy_triggerer.triggerer_id = healthy_triggerer.id + session.add(trigger_on_healthy_triggerer) + ti_trigger_on_healthy_triggerer = create_task_instance( + task_id="ti_trigger_on_healthy_triggerer", + execution_date=time_now, + run_id="trigger_on_healthy_triggerer_run_id", + ) + ti_trigger_on_healthy_triggerer.trigger_id = trigger_on_healthy_triggerer.id + session.add(ti_trigger_on_healthy_triggerer) trigger_on_unhealthy_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) trigger_on_unhealthy_triggerer.id = 2 trigger_on_unhealthy_triggerer.triggerer_id = unhealthy_triggerer.id + session.add(trigger_on_unhealthy_triggerer) + ti_trigger_on_unhealthy_triggerer = create_task_instance( + task_id="ti_trigger_on_unhealthy_triggerer", + execution_date=time_now + datetime.timedelta(hours=1), + run_id="trigger_on_unhealthy_triggerer_run_id", + ) + ti_trigger_on_unhealthy_triggerer.trigger_id = trigger_on_unhealthy_triggerer.id + session.add(ti_trigger_on_unhealthy_triggerer) trigger_on_killed_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) trigger_on_killed_triggerer.id = 3 trigger_on_killed_triggerer.triggerer_id = finished_triggerer.id + session.add(trigger_on_killed_triggerer) + ti_trigger_on_killed_triggerer = create_task_instance( + task_id="ti_trigger_on_killed_triggerer", + execution_date=time_now + datetime.timedelta(hours=2), + run_id="trigger_on_killed_triggerer_run_id", + ) + ti_trigger_on_killed_triggerer.trigger_id = trigger_on_killed_triggerer.id + session.add(ti_trigger_on_killed_triggerer) trigger_unassigned_to_triggerer = Trigger(classpath="airflow.triggers.testing.SuccessTrigger", kwargs={}) trigger_unassigned_to_triggerer.id = 4 - assert trigger_unassigned_to_triggerer.triggerer_id is None - session.add(trigger_on_healthy_triggerer) - 
session.add(trigger_on_unhealthy_triggerer) - session.add(trigger_on_killed_triggerer) session.add(trigger_unassigned_to_triggerer) + ti_trigger_unassigned_to_triggerer = create_task_instance( + task_id="ti_trigger_unassigned_to_triggerer", + execution_date=time_now + datetime.timedelta(hours=3), + run_id="trigger_unassigned_to_triggerer_run_id", + ) + ti_trigger_unassigned_to_triggerer.trigger_id = trigger_unassigned_to_triggerer.id + session.add(ti_trigger_unassigned_to_triggerer) + assert trigger_unassigned_to_triggerer.triggerer_id is None session.commit() assert session.query(Trigger).count() == 4 Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30) From 9fa2ed4e600faf1b12690eae1c06f0b3f0c9a648 Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 11 Aug 2023 18:48:47 +0200 Subject: [PATCH 4/4] fix unit tests --- tests/jobs/test_triggerer_job.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py index 35ebe99b31ad1..bd71c3c9ed6fa 100644 --- a/tests/jobs/test_triggerer_job.py +++ b/tests/jobs/test_triggerer_job.py @@ -414,7 +414,7 @@ def handle_events(self): assert len(instances) == 1 -def test_trigger_from_dead_triggerer(session): +def test_trigger_from_dead_triggerer(session, create_task_instance): """ Checks that the triggerer will correctly claim a Trigger that is assigned to a triggerer that does not exist. 
@@ -425,6 +425,13 @@ def test_trigger_from_dead_triggerer(session): trigger_orm.id = 1 trigger_orm.triggerer_id = 999 # Non-existent triggerer session.add(trigger_orm) + ti_orm = create_task_instance( + task_id="ti_orm", + execution_date=datetime.datetime.utcnow(), + run_id="orm_run_id", + ) + ti_orm.trigger_id = trigger_orm.id + session.add(ti_orm) session.commit() # Make a TriggererJobRunner and have it retrieve DB tasks job = Job() @@ -434,7 +441,7 @@ def test_trigger_from_dead_triggerer(session): assert [x for x, y in job_runner.trigger_runner.to_create] == [1] -def test_trigger_from_expired_triggerer(session): +def test_trigger_from_expired_triggerer(session, create_task_instance): """ Checks that the triggerer will correctly claim a Trigger that is assigned to a triggerer that has an expired heartbeat. @@ -445,6 +452,13 @@ def test_trigger_from_expired_triggerer(session): trigger_orm.id = 1 trigger_orm.triggerer_id = 42 session.add(trigger_orm) + ti_orm = create_task_instance( + task_id="ti_orm", + execution_date=datetime.datetime.utcnow(), + run_id="orm_run_id", + ) + ti_orm.trigger_id = trigger_orm.id + session.add(ti_orm) # Use a TriggererJobRunner with an expired heartbeat triggerer_job_orm = Job(TriggererJobRunner.job_type) triggerer_job_orm.id = 42