diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ed5ecaa9ed1f3..27bb45db68859 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,16 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py$| + ^airflow-core/tests/unit/models/test_renderedtifields.py$| + ^airflow-core/tests/unit/models/test_timestamp.py$| + ^airflow-core/tests/unit/utils/test_cli_util.py$| + ^airflow-core/tests/unit/timetables/test_assets_timetable.py$| + ^airflow-core/tests/unit/assets/test_manager.py$| + ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py$| + ^airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py$| + ^airflow-core/tests/unit/models/test_dagwarning.py$| + ^airflow-core/tests/integration/otel/test_otel.py$| ^airflow-core/tests/unit/utils/test_db_cleanup.py$| ^dev/airflow_perf/scheduler_dag_execution_timing.py$| ^providers/openlineage/.*\.py$| diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py index 81ab540c501dc..8d6f665811570 100644 --- a/airflow-core/tests/integration/otel/test_otel.py +++ b/airflow-core/tests/integration/otel/test_otel.py @@ -24,7 +24,7 @@ import time import pytest -from sqlalchemy import select +from sqlalchemy import func, select from airflow._shared.timezones import timezone from airflow.dag_processing.bundles.manager import DagBundlesManager @@ -87,13 +87,11 @@ def wait_for_dag_run_and_check_span_status( while timezone.utcnow().timestamp() - start_time < max_wait_time: with create_session() as session: - dag_run = ( - session.query(DagRun) - .filter( + dag_run = session.scalar( + select(DagRun).where( DagRun.dag_id == dag_id, DagRun.run_id == run_id, ) - .first() ) if dag_run is None: @@ -121,13 +119,11 @@ def wait_for_dag_run_and_check_span_status( def check_dag_run_state_and_span_status(dag_id: str, run_id: str, state: str, span_status: str): with create_session() as session: - dag_run = ( - session.query(DagRun) - .filter( + dag_run = session.scalar( + select(DagRun).where( DagRun.dag_id == dag_id, DagRun.run_id == run_id, ) - .first() ) assert dag_run is not None @@ -139,13 +135,11 @@ def check_dag_run_state_and_span_status(dag_id: str, run_id: str, state: str, sp def check_ti_state_and_span_status(task_id: str, run_id: str, state: str, span_status: str | None): with create_session() as session: - ti = ( - session.query(TaskInstance) - .filter( + ti = session.scalar( + select(TaskInstance).where( TaskInstance.task_id == task_id, TaskInstance.run_id == run_id, ) - .first() ) assert ti is not None @@ -668,7 +662,12 @@ def serialize_and_get_dags(cls) -> dict[str, SerializedDAG]: if AIRFLOW_V_3_0_PLUS: from airflow.models.dagbundle import DagBundleModel - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + count = session.scalar( + select(func.count()) + .select_from(DagBundleModel) + .where(DagBundleModel.name == "testing") + ) + if count == 0: session.add(DagBundleModel(name="testing")) session.commit() SerializedDAG.bulk_write_to_db( diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py index 785904d3d8100..4abfd6ee8085c 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py @@ -23,7 +23,7 @@ import pytest import time_machine -from sqlalchemy import select +from sqlalchemy import func, select from airflow._shared.timezones import timezone from airflow.api_fastapi.core_api.datamodels.dag_versions import DagVersionResponse @@ -340,11 +340,9 @@ def test_get_dag_runs(self, test_client, session, dag_id, total_entries): body = response.json() assert body["total_entries"] == total_entries for each in body["dag_runs"]: - run = ( - session.query(DagRun) - .where(DagRun.dag_id == each["dag_id"], DagRun.run_id == each["dag_run_id"]) - .one() - ) + run = session.scalars( + select(DagRun).where(DagRun.dag_id == each["dag_id"], DagRun.run_id == each["dag_run_id"]) + ).one() assert each == get_dag_run_dict(run) @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") @@ -821,7 +819,7 @@ def test_list_dag_runs_return_200(self, test_client, session): body = response.json() assert body["total_entries"] == 4 for each in body["dag_runs"]: - run = session.query(DagRun).where(DagRun.run_id == each["dag_run_id"]).one() + run = session.scalars(select(DagRun).where(DagRun.run_id == each["dag_run_id"])).one() expected = get_dag_run_dict(run) assert each == expected @@ -1344,7 +1342,7 @@ def test_should_respond_200(self, test_client, dag_maker, session): dr = dag_maker.create_dagrun() ti = dr.task_instances[0] - asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar() + asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri)) event = AssetEvent( asset_id=asset1_id, source_task_id=ti.task_id, @@ -1473,13 +1471,15 @@ def test_clear_dag_run_dry_run(self, test_client, session, body, dag_run_id, exp assert body["total_entries"] == len(expected_state) for index, each in enumerate(sorted(body["task_instances"], key=lambda x: x["task_id"])): assert each["state"] == expected_state[index] - dag_run = session.scalar(select(DagRun).filter_by(dag_id=DAG1_ID, run_id=DAG1_RUN1_ID)) + dag_run = session.scalar( + select(DagRun).where(DagRun.dag_id == DAG1_ID, DagRun.run_id == DAG1_RUN1_ID) + ) assert dag_run.state == DAG1_RUN1_STATE - logs = ( - session.query(Log) - .filter(Log.dag_id == DAG1_ID, Log.run_id == dag_run_id, Log.event == "clear_dag_run") - .count() + logs = session.scalar( + select(func.count()) + .select_from(Log) + .where(Log.dag_id == DAG1_ID, Log.run_id == dag_run_id, Log.event == "clear_dag_run") ) assert logs == 0 @@ -1572,9 +1572,9 @@ def test_should_respond_200( expected_data_interval_end = data_interval_end.replace("+00:00", "Z") expected_logical_date = fixed_now.replace("+00:00", "Z") - run = ( - session.query(DagRun).where(DagRun.dag_id == DAG1_ID, DagRun.run_id == expected_dag_run_id).one() - ) + run = session.scalars( + select(DagRun).where(DagRun.dag_id == DAG1_ID, DagRun.run_id == expected_dag_run_id) + ).one() expected_response_json = { "bundle_version": None, @@ -1907,7 +1907,7 @@ def test_custom_timetable_generate_run_id_for_manual_trigger(self, dag_maker, te run_id_with_logical_date = response.json()["dag_run_id"] assert run_id_with_logical_date.startswith("custom_") - run = session.query(DagRun).filter(DagRun.run_id == run_id_with_logical_date).one() + run = session.scalars(select(DagRun).where(DagRun.run_id == run_id_with_logical_date)).one() assert run.dag_id == custom_dag_id response = test_client.post( @@ -1918,7 +1918,7 @@ def test_custom_timetable_generate_run_id_for_manual_trigger(self, dag_maker, te run_id_without_logical_date = response.json()["dag_run_id"] assert run_id_without_logical_date.startswith("custom_manual_") - run = session.query(DagRun).filter(DagRun.run_id == run_id_without_logical_date).one() + run = session.scalars(select(DagRun).where(DagRun.run_id == run_id_without_logical_date)).one() assert run.dag_id == custom_dag_id diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index f91ae7f285f17..d9f4e1f42424f 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -1144,7 +1144,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance session.expire_all() - tis = session.query(TaskInstance).all() + tis = session.scalars(select(TaskInstance)).all() assert len(tis) == 1 assert tis[0].state == TaskInstanceState.DEFERRED @@ -1155,7 +1155,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance } assert tis[0].trigger_timeout == timezone.make_aware(datetime(2024, 11, 23), timezone=timezone.utc) - t = session.query(Trigger).all() + t = session.scalars(select(Trigger)).all() assert len(t) == 1 assert t[0].created_date == instant assert t[0].classpath == "my-classpath" @@ -1192,14 +1192,14 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan session.expire_all() - tis = session.query(TaskInstance).all() + tis = session.scalars(select(TaskInstance)).all() assert len(tis) == 1 assert tis[0].state == TaskInstanceState.UP_FOR_RESCHEDULE assert tis[0].next_method is None assert tis[0].next_kwargs is None assert tis[0].duration == 129600 - trs = session.query(TaskReschedule).all() + trs = session.scalars(select(TaskReschedule)).all() assert len(trs) == 1 assert trs[0].task_instance.dag_id == "dag" assert trs[0].task_instance.task_id == "test_ti_update_state_to_reschedule" @@ -1273,11 +1273,11 @@ def test_ti_update_state_handle_retry(self, client, session, create_task_instanc assert ti.next_method is None assert ti.next_kwargs is None - tih = ( - session.query(TaskInstanceHistory) - .where(TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.run_id == ti.run_id) - .one() - ) + tih = session.scalars( + select(TaskInstanceHistory).where( + TaskInstanceHistory.task_id == ti.task_id, TaskInstanceHistory.run_id == ti.run_id + ) + ).one() assert tih.task_instance_id assert tih.task_instance_id != ti.id @@ -1680,7 +1680,7 @@ def test_ti_put_rtif_success(self, client, session, create_task_instance, payloa session.expire_all() - rtifs = session.query(RenderedTaskInstanceFields).all() + rtifs = session.scalars(select(RenderedTaskInstanceFields)).all() assert len(rtifs) == 1 assert rtifs[0].dag_id == "dag" diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 46a59198d8018..3d83ef34db9cb 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -21,7 +21,7 @@ from unittest import mock import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from airflow.assets.manager import AssetManager @@ -105,8 +105,11 @@ def test_register_asset_change(self, session, dag_maker, mock_task_instance, tes session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 2 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 2 @pytest.mark.usefixtures("clear_assets") def test_register_asset_change_with_alias( @@ -145,8 +148,11 @@ def test_register_asset_change_with_alias( session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 2 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 2 def test_register_asset_change_no_downstreams(self, session, mock_task_instance): asset_manager = AssetManager() @@ -161,8 +167,11 @@ def test_register_asset_change_no_downstreams(self, session, mock_task_instance) session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 0 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 0 def test_register_asset_change_notifies_asset_listener( self, session, mock_task_instance, testing_dag_bundle diff --git a/airflow-core/tests/unit/models/test_dagwarning.py b/airflow-core/tests/unit/models/test_dagwarning.py index 18a61e15d72eb..167244afbc1f2 100644 --- a/airflow-core/tests/unit/models/test_dagwarning.py +++ b/airflow-core/tests/unit/models/test_dagwarning.py @@ -21,6 +21,7 @@ from unittest.mock import MagicMock import pytest +from sqlalchemy import select from sqlalchemy.exc import OperationalError from airflow.models import DagModel @@ -55,7 +56,7 @@ def test_purge_inactive_dag_warnings(self, session, testing_dag_bundle): DagWarning.purge_inactive_dag_warnings(session) - remaining_dag_warnings = session.query(DagWarning).all() + remaining_dag_warnings = session.scalars(select(DagWarning)).all() assert len(remaining_dag_warnings) == 1 assert remaining_dag_warnings[0].dag_id == "dag_2" diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py index f31ed7722c019..02b705f4ce949 100644 --- a/airflow-core/tests/unit/models/test_renderedtifields.py +++ b/airflow-core/tests/unit/models/test_renderedtifields.py @@ -261,7 +261,9 @@ def test_delete_old_records( session.add_all(rtif_list) session.flush() - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + result = session.scalars( + select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id) + ).all() for rtif in rtif_list: assert rtif in result @@ -270,7 +272,9 @@ def test_delete_old_records( with assert_queries_count(expected_query_count): RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() + result = session.scalars( + select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id) + ).all() assert remaining_rtifs == len(result) @pytest.mark.parametrize( @@ -302,14 +306,16 @@ def test_delete_old_records_mapped( session.add(RTIF(ti)) session.flush() - result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all() + result = session.scalars(select(RTIF).where(RTIF.dag_id == dag.dag_id)).all() assert len(result) == num_runs * 2 with assert_queries_count(expected_query_count): RTIF.delete_old_records( task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session ) - result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all() + result = session.scalars( + select(RTIF).where(RTIF.dag_id == dag.dag_id, RTIF.task_id == mapped.task_id) + ).all() rtif_num_runs = Counter(rtif.run_id for rtif in result) assert len(rtif_num_runs) == remaining_rtifs # Check that we have _all_ the data for each row @@ -322,7 +328,7 @@ def test_write(self, dag_maker): Variable.set(key="test_key", value="test_val") session = settings.Session() - result = session.query(RTIF).all() + result = session.scalars(select(RTIF)).all() assert result == [] with dag_maker("test_write"): @@ -334,15 +340,13 @@ def test_write(self, dag_maker): rtif = RTIF(ti) rtif.write() - result = ( - session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) - .filter( + result = session.execute( + select(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).where( RTIF.dag_id == rtif.dag_id, RTIF.task_id == rtif.task_id, RTIF.run_id == rtif.run_id, ) - .first() - ) + ).first() assert result == ("test_write", "test", {"bash_command": "echo test_val", "env": None, "cwd": None}) # Test that overwrite saves new values to the DB @@ -357,15 +361,13 @@ def test_write(self, dag_maker): rtif_updated = RTIF(ti) rtif_updated.write() - result_updated = ( - session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) - .filter( + result_updated = session.execute( + select(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields).where( RTIF.dag_id == rtif_updated.dag_id, RTIF.task_id == rtif_updated.task_id, RTIF.run_id == rtif_updated.run_id, ) - .first() - ) + ).first() assert result_updated == ( "test_write", "test", diff --git a/airflow-core/tests/unit/models/test_timestamp.py b/airflow-core/tests/unit/models/test_timestamp.py index 529fd32afc504..b200fe3ecdb5a 100644 --- a/airflow-core/tests/unit/models/test_timestamp.py +++ b/airflow-core/tests/unit/models/test_timestamp.py @@ -19,6 +19,7 @@ import pendulum import pytest import time_machine +from sqlalchemy import select from airflow._shared.timezones import timezone from airflow.models import Log @@ -60,7 +61,7 @@ def test_timestamp_behaviour(dag_maker, session): current_time = timezone.utcnow() old_log = add_log(execdate, session, dag_maker) session.expunge(old_log) - log_time = session.query(Log).one().dttm + log_time = session.scalars(select(Log)).one().dttm assert log_time == current_time assert log_time.tzinfo.name == "UTC" @@ -73,7 +74,7 @@ def test_timestamp_behaviour_with_timezone(dag_maker, session): old_log = add_log(execdate, session, dag_maker, timezone_override=pendulum.timezone("Europe/Warsaw")) session.expunge(old_log) # No matter what timezone we set - we should always get back UTC - log_time = session.query(Log).one().dttm + log_time = session.scalars(select(Log)).one().dttm assert log_time == current_time assert old_log.dttm.tzinfo.name != "UTC" assert log_time.tzinfo.name == "UTC" diff --git a/airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py index bd0a2b6836a8d..771d55b57a8d5 100644 --- a/airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py +++ b/airflow-core/tests/unit/ti_deps/deps/test_runnable_exec_date_dep.py @@ -21,6 +21,7 @@ import pytest import time_machine +from sqlalchemy import delete from airflow._shared.timezones.timezone import datetime from airflow.models import DagRun, TaskInstance @@ -33,8 +34,8 @@ @pytest.fixture(autouse=True) def clean_db(session): yield - session.query(DagRun).delete() - session.query(TaskInstance).delete() + session.execute(delete(DagRun)) + session.execute(delete(TaskInstance)) @time_machine.travel("2016-11-01") diff --git a/airflow-core/tests/unit/timetables/test_assets_timetable.py b/airflow-core/tests/unit/timetables/test_assets_timetable.py index 3026cace6327b..537fab7ed90eb 100644 --- a/airflow-core/tests/unit/timetables/test_assets_timetable.py +++ b/airflow-core/tests/unit/timetables/test_assets_timetable.py @@ -274,7 +274,7 @@ def test_asset_dag_run_queue_processing(self, session, dag_maker, create_test_as from airflow.assets.evaluation import AssetEvaluator assets = create_test_assets - asset_models = session.query(AssetModel).all() + asset_models = session.scalars(select(AssetModel)).all() evaluator = AssetEvaluator(session) with dag_maker(schedule=AssetAny(*assets)) as dag: diff --git a/airflow-core/tests/unit/utils/test_cli_util.py b/airflow-core/tests/unit/utils/test_cli_util.py index 6f237dca3424b..9cc8a7350c40d 100644 --- a/airflow-core/tests/unit/utils/test_cli_util.py +++ b/airflow-core/tests/unit/utils/test_cli_util.py @@ -27,6 +27,7 @@ from unittest import mock import pytest +from sqlalchemy import select import airflow from airflow import settings @@ -175,7 +176,7 @@ def test_cli_create_user_supplied_password_is_masked( mock_create_session.return_value.bulk_insert_mappings = session.bulk_insert_mappings cli_action_loggers.default_action_log(**metrics) - log = session.query(Log).order_by(Log.dttm.desc()).first() + log = session.scalar(select(Log).order_by(Log.dttm.desc())) assert metrics.get("start_datetime") <= timezone.utcnow() @@ -234,7 +235,7 @@ def test_cli_set_variable_supplied_sensitive_value_is_masked( mock_create_session.return_value.bulk_insert_mappings = session.bulk_insert_mappings cli_action_loggers.default_action_log(**metrics) - log = session.query(Log).order_by(Log.dttm.desc()).first() + log = session.scalar(select(Log).order_by(Log.dttm.desc())) assert metrics.get("start_datetime") <= timezone.utcnow()