diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f2221ed27e15e..1d57e98ba8d37 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/models/test_trigger.py$| ^dev/airflow_perf/scheduler_dag_execution_timing.py$| ^providers/openlineage/.*\.py$| ^task_sdk.*\.py$ diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py index fe2fbeb6b98f2..df169852cb52c 100644 --- a/airflow-core/tests/unit/models/test_trigger.py +++ b/airflow-core/tests/unit/models/test_trigger.py @@ -26,6 +26,7 @@ import pytest import pytz from cryptography.fernet import Fernet +from sqlalchemy import delete, func, select from airflow._shared.timezones import timezone from airflow.jobs.job import Job @@ -61,21 +62,21 @@ def session(): @pytest.fixture(autouse=True) def clear_db(session): - session.query(TaskInstance).delete() - session.query(AssetWatcherModel).delete() - session.query(Callback).delete() - session.query(Trigger).delete() - session.query(AssetModel).delete() - session.query(AssetEvent).delete() - session.query(Job).delete() + session.execute(delete(TaskInstance)) + session.execute(delete(AssetWatcherModel)) + session.execute(delete(Callback)) + session.execute(delete(Trigger)) + session.execute(delete(AssetModel)) + session.execute(delete(AssetEvent)) + session.execute(delete(Job)) yield session - session.query(TaskInstance).delete() - session.query(AssetWatcherModel).delete() - session.query(Callback).delete() - session.query(Trigger).delete() - session.query(AssetModel).delete() - session.query(AssetEvent).delete() - session.query(Job).delete() + session.execute(delete(TaskInstance)) + session.execute(delete(AssetWatcherModel)) + session.execute(delete(Callback)) + session.execute(delete(Trigger)) + session.execute(delete(AssetModel)) + session.execute(delete(AssetEvent)) + session.execute(delete(Job)) session.commit() @@ -121,7 +122,7 @@ def test_clean_unused(session, create_task_instance): session.add(trigger5) session.add(trigger6) session.commit() - assert session.query(Trigger).count() == 6 + assert session.scalar(select(func.count()).select_from(Trigger)) == 6 # Tie one to a fake TaskInstance that is not deferred, and one to one that is task_instance = create_task_instance( session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow() @@ -150,7 +151,7 @@ def test_clean_unused(session, create_task_instance): asset.add_trigger(trigger5, "test_asset_watcher2") session.add(asset) session.commit() - assert session.query(AssetModel).count() == 1 + assert session.scalar(select(func.count()).select_from(AssetModel)) == 1 # Create callback with trigger callback = TriggererCallback( @@ -162,7 +163,7 @@ def test_clean_unused(session, create_task_instance): # Run clear operation Trigger.clean_unused() - results = session.query(Trigger).all() + results = session.scalars(select(Trigger)).all() assert len(results) == 4 assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id} @@ -196,7 +197,10 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance) session.commit() # Check that the asset has 0 event prior to sending an event to the trigger - assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id)) + == 0 + ) # Create event payload = "payload" @@ -210,8 +214,11 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance) assert task_instance.state == State.SCHEDULED assert task_instance.next_kwargs == {"event": payload, "cheesecake": True} # Check that the asset has received an event - assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1 - asset_event = session.query(AssetEvent).filter_by(asset_id=asset.id).first() + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id)) + == 1 + ) + asset_event = session.scalar(select(AssetEvent).where(AssetEvent.asset_id == asset.id)) assert asset_event.extra == {"from_trigger": True, "payload": payload} # Check that the callback's handle_event was called @@ -233,7 +240,7 @@ def test_submit_failure(session, create_task_instance): # Call submit_event Trigger.submit_failure(trigger.id, session=session) # Check that the task instance is now scheduled to fail - updated_task_instance = session.query(TaskInstance).one() + updated_task_instance = session.scalar(select(TaskInstance)) assert updated_task_instance.state == State.SCHEDULED assert updated_task_instance.next_method == "__fail__" @@ -272,7 +279,7 @@ def get_xcoms(ti): # now for the real test # first check initial state - ti: TaskInstance = session.query(TaskInstance).one() + ti: TaskInstance = session.scalar(select(TaskInstance)) assert ti.state == "deferred" assert get_xcoms(ti) == [] @@ -285,7 +292,7 @@ def get_xcoms(ti): # commit changes made by submit event and expire all cache to read from db. session.flush() # Check that the task instance is now correct - ti = session.query(TaskInstance).one() + ti = session.scalar(select(TaskInstance)) assert ti.state == expected assert ti.next_kwargs is None assert ti.end_date == now @@ -370,26 +377,26 @@ def test_assign_unassigned(session, create_task_instance): session.add(ti_trigger_unassigned_to_triggerer) assert trigger_unassigned_to_triggerer.triggerer_id is None session.commit() - assert session.query(Trigger).count() == 4 + assert session.scalar(select(func.count()).select_from(Trigger)) == 4 Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30) session.expire_all() # Check that trigger on killed triggerer and unassigned trigger are assigned to new triggerer assert ( - session.query(Trigger).filter(Trigger.id == trigger_on_killed_triggerer.id).one().triggerer_id + session.scalar(select(Trigger).where(Trigger.id == trigger_on_killed_triggerer.id)).triggerer_id == new_triggerer.id ) assert ( - session.query(Trigger).filter(Trigger.id == trigger_unassigned_to_triggerer.id).one().triggerer_id + session.scalar(select(Trigger).where(Trigger.id == trigger_unassigned_to_triggerer.id)).triggerer_id == new_triggerer.id ) # Check that trigger on healthy triggerer still assigned to existing triggerer assert ( - session.query(Trigger).filter(Trigger.id == trigger_on_healthy_triggerer.id).one().triggerer_id + session.scalar(select(Trigger).where(Trigger.id == trigger_on_healthy_triggerer.id)).triggerer_id == healthy_triggerer.id ) # Check that trigger on unhealthy triggerer is assigned to new triggerer assert ( - session.query(Trigger).filter(Trigger.id == trigger_on_unhealthy_triggerer.id).one().triggerer_id + session.scalar(select(Trigger).where(Trigger.id == trigger_on_unhealthy_triggerer.id)).triggerer_id == new_triggerer.id ) @@ -453,7 +460,7 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance) ) session.add(trigger_callback) session.commit() - assert session.query(Trigger).count() == 5 + assert session.scalar(select(func.count()).select_from(Trigger)) == 5 # Create assets asset = AssetModel("test") asset.add_trigger(trigger_asset, "test_asset_watcher") @@ -534,7 +541,7 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins session.add(TI_new) session.commit() - assert session.query(Trigger).count() == 5 + assert session.scalar(select(func.count()).select_from(Trigger)) == 5 trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session) @@ -605,7 +612,7 @@ def test_get_sorted_triggers_dont_starve_for_ha(session, create_task_instance): asset_triggers.append(trigger) session.commit() - assert session.query(Trigger).count() == 60 + assert session.scalar(select(func.count()).select_from(Trigger)) == 60 # Mock max_trigger_to_select_per_loop to 5 for testing with patch.object(Trigger, "max_trigger_to_select_per_loop", 5):