diff --git a/airflow-core/tests/unit/dag_processing/test_manager.py b/airflow-core/tests/unit/dag_processing/test_manager.py index 91692e6d0fa44..456d925935bd2 100644 --- a/airflow-core/tests/unit/dag_processing/test_manager.py +++ b/airflow-core/tests/unit/dag_processing/test_manager.py @@ -172,7 +172,7 @@ def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_b manager.run() with create_session() as session: - import_errors = session.query(ParseImportError).all() + import_errors = session.scalars(select(ParseImportError)).all() assert len(import_errors) == 1 path_to_parse.unlink() @@ -181,7 +181,7 @@ def test_remove_file_clears_import_error(self, tmp_path, configure_testing_dag_b manager.run() with create_session() as session: - import_errors = session.query(ParseImportError).all() + import_errors = session.scalars(select(ParseImportError)).all() assert len(import_errors) == 0 session.rollback() @@ -435,7 +435,7 @@ def test_parsing_requests_only_bundles_being_parsed(self, testing_dag_bundle): manager._queue_requested_files_for_parsing() assert manager._file_queue == deque([file1]) with create_session() as session2: - parsing_request_after = session2.query(DagPriorityParsingRequest).all() + parsing_request_after = session2.scalars(select(DagPriorityParsingRequest)).all() assert len(parsing_request_after) == 1 assert parsing_request_after[0].relative_fileloc == "file_x.py" @@ -480,34 +480,28 @@ def test_scan_stale_dags(self, session): manager._files = [test_dag_path] manager._file_stats[test_dag_path] = stat - active_dag_count = ( - session.query(func.count(DagModel.dag_id)) - .filter( + active_dag_count = session.scalar( + select(func.count(DagModel.dag_id)).where( ~DagModel.is_stale, DagModel.relative_fileloc == str(test_dag_path.rel_path), DagModel.bundle_name == test_dag_path.bundle_name, ) - .scalar() ) assert active_dag_count == 1 manager._scan_stale_dags() - active_dag_count = ( - session.query(func.count(DagModel.dag_id)) - .filter( + active_dag_count = session.scalar( + select(func.count(DagModel.dag_id)).where( ~DagModel.is_stale, DagModel.relative_fileloc == str(test_dag_path.rel_path), DagModel.bundle_name == test_dag_path.bundle_name, ) - .scalar() ) assert active_dag_count == 0 - serialized_dag_count = ( - session.query(func.count(SerializedDagModel.dag_id)) - .filter(SerializedDagModel.dag_id == dag.dag_id) - .scalar() + serialized_dag_count = session.scalar( + select(func.count(SerializedDagModel.dag_id)).where(SerializedDagModel.dag_id == dag.dag_id) ) # Deactivating the DagModel should not delete the SerializedDagModel # SerializedDagModel gives history about Dags @@ -776,7 +770,7 @@ def test_fetch_callbacks_from_database(self, configure_testing_dag_bundle): assert callbacks[0].run_id == "123" assert callbacks[1].run_id == "456" - assert session.query(DbCallbackRequest).count() == 0 + assert len(session.scalars(select(DbCallbackRequest)).all()) == 0 @conf_vars( { @@ -805,11 +799,11 @@ def test_fetch_callbacks_from_database_max_per_loop(self, tmp_path, configure_te with create_session() as session: manager.run() - assert session.query(DbCallbackRequest).count() == 3 + assert len(session.scalars(select(DbCallbackRequest)).all()) == 3 with create_session() as session: manager.run() - assert session.query(DbCallbackRequest).count() == 1 + assert len(session.scalars(select(DbCallbackRequest)).all()) == 1 @conf_vars({("core", "load_examples"): "False"}) def test_fetch_callbacks_ignores_other_bundles(self, configure_testing_dag_bundle): @@ -850,7 +844,7 @@ def test_fetch_callbacks_ignores_other_bundles(self, configure_testing_dag_bundl assert [c.run_id for c in callbacks] == ["match"] # The non-matching callback should remain in the DB - remaining = session.query(DbCallbackRequest).all() + remaining = session.scalars(select(DbCallbackRequest)).all() assert len(remaining) == 1 # Decode remaining request and verify it's for the other bundle remaining_req = remaining[0].get_callback_request() diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index fc95923cf02ae..63e63a57643d1 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -32,6 +32,7 @@ import pytest import structlog from pydantic import TypeAdapter +from sqlalchemy import select from structlog.typing import FilteringBoundLogger from airflow._shared.timezones import timezone @@ -242,7 +243,7 @@ def dag_in_a_fn(): assert result.import_errors == {} assert result.serialized_dags[0].dag_id == "test_myvalue" - all_vars = session.query(VariableORM).all() + all_vars = session.scalars(select(VariableORM)).all() assert len(all_vars) == 1 assert all_vars[0].key == "mykey" @@ -285,7 +286,7 @@ def dag_in_a_fn(): assert result.import_errors == {} assert result.serialized_dags[0].dag_id == "not-found" - all_vars = session.query(VariableORM).all() + all_vars = session.scalars(select(VariableORM)).all() assert len(all_vars) == 0 def test_top_level_connection_access( diff --git a/airflow-core/tests/unit/utils/test_types.py b/airflow-core/tests/unit/utils/test_types.py index 4a6831f40354d..277a0d8460723 100644 --- a/airflow-core/tests/unit/utils/test_types.py +++ b/airflow-core/tests/unit/utils/test_types.py @@ -19,6 +19,7 @@ from datetime import timedelta import pytest +from sqlalchemy import select from airflow.models.dagrun import DagRun from airflow.utils.state import State @@ -36,22 +37,22 @@ def test_runtype_enum_escape(dag_maker, session): pass dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED) - query = session.query( + query = select( DagRun.dag_id, DagRun.state, DagRun.run_type, - ).filter( + ).where( DagRun.dag_id == "test_enum_dags", # make sure enum value can be used in filter queries DagRun.run_type == DagRunType.SCHEDULED, ) - assert str(query.statement.compile(compile_kwargs={"literal_binds": True})) == ( + rows = session.execute(query).all() + assert str(query.compile(compile_kwargs={"literal_binds": True})) == ( "SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n" "FROM dag_run \n" "WHERE dag_run.dag_id = 'test_enum_dags' AND dag_run.run_type = 'scheduled'" ) - rows = query.all() assert len(rows) == 1 assert rows[0].dag_id == "test_enum_dags" assert rows[0].state == State.RUNNING