diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 939039b0827fb..e24116494303c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -425,6 +425,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/models/test_serialized_dag.py$| ^airflow-core/tests/unit/utils/test_db_cleanup.py$| ^dev/airflow_perf/scheduler_dag_execution_timing.py$| ^providers/openlineage/.*\.py$| diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py index 2472c0720461e..ddc98f7206492 100644 --- a/airflow-core/tests/unit/models/test_serialized_dag.py +++ b/airflow-core/tests/unit/models/test_serialized_dag.py @@ -24,7 +24,7 @@ import pendulum import pytest -from sqlalchemy import func, select, update +from sqlalchemy import delete, func, select, update import airflow.example_dags as example_dags_module from airflow.dag_processing.dagbag import DagBag @@ -59,7 +59,12 @@ def make_example_dags(module): from airflow.utils.session import create_session with create_session() as session: - if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0: + if ( + session.scalar( + select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing") + ) + == 0 + ): testing = DagBundleModel(name="testing") session.add(testing) @@ -101,7 +106,7 @@ def test_write_dag(self, testing_dag_bundle): with create_session() as session: for dag in example_dags.values(): assert SDM.has_dag(dag.dag_id) - result = session.query(SDM).filter(SDM.dag_id == dag.dag_id).one() + result = session.scalar(select(SDM).where(SDM.dag_id == dag.dag_id)) assert result.dag_version.dag_code.fileloc == dag.fileloc # Verifies JSON schema. @@ -118,7 +123,7 @@ def my_callable(): with dag_maker("dag1"): PythonOperator(task_id="task1", python_callable=lambda x: None) dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1)) - assert len(session.query(DagVersion).all()) == 2 + assert len(session.scalars(select(DagVersion)).all()) == 2 with dag_maker("dag2"): @@ -136,7 +141,7 @@ def my_callable2(): pass my_callable2() - assert len(session.query(DagVersion).all()) == 4 + assert len(session.scalars(select(DagVersion)).all()) == 4 def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle): """Test Serialized DAG is updated if DAG is changed""" @@ -212,7 +217,7 @@ def test_read_all_dags_only_picks_the_latest_serdags(self, session): dag.doc_md = "new doc string" SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing") serialized_dags2 = SDM.read_all_dags() - sdags = session.query(SDM).all() + sdags = session.scalars(select(SDM)).all() # assert only the latest SDM is returned assert len(sdags) != len(serialized_dags2) @@ -334,7 +339,7 @@ def test_get_latest_serdag_versions(self, dag_maker, session): def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session): with dag_maker("dag1") as dag: PythonOperator(task_id="task1", python_callable=lambda: None) - assert session.query(SDM).count() == 1 + assert session.scalar(select(func.count()).select_from(SDM)) == 1 sdm1 = SDM.get(dag.dag_id, session=session) dag_hash = sdm1.dag_hash created_at = sdm1.created_at @@ -347,21 +352,21 @@ def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session assert sdm2.dag_hash != dag_hash # first recorded serdag assert sdm2.created_at == created_at assert sdm2.last_updated != last_updated - assert session.query(DagVersion).count() == 1 - assert session.query(SDM).count() == 1 + assert session.scalar(select(func.count()).select_from(DagVersion)) == 1 + assert session.scalar(select(func.count()).select_from(SDM)) == 1 def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, dag_maker, session): with dag_maker("dag1") as dag: PythonOperator(task_id="task1", python_callable=lambda: None) dag_maker.create_dagrun(run_id="test3", logical_date=pendulum.datetime(2025, 1, 2)) - assert session.query(SDM).count() == 1 - assert session.query(DagVersion).count() == 1 + assert session.scalar(select(func.count()).select_from(SDM)) == 1 + assert session.scalar(select(func.count()).select_from(DagVersion)) == 1 # new task PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag) SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker") - assert session.query(DagVersion).count() == 2 - assert session.query(SDM).count() == 2 + assert session.scalar(select(func.count()).select_from(DagVersion)) == 2 + assert session.scalar(select(func.count()).select_from(SDM)) == 2 def test_example_dag_sorting_serialised_dag(self, session): """ @@ -517,14 +522,14 @@ def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(sel # Create TIs dag_maker.create_dagrun(run_id="test_run") - assert session.query(DagVersion).count() == 1 + assert session.scalar(select(func.count()).select_from(DagVersion)) == 1 # Write the same DAG (no changes, so hash is the same) with a new bundle_name new_bundle = "bundleB" SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=new_bundle) # There should now be two versions of the DAG - assert session.query(DagVersion).count() == 2 + assert session.scalar(select(func.count()).select_from(DagVersion)) == 2 def test_hash_method_removes_fileloc_and_remains_consistent(self): """Test that the hash method removes fileloc before hashing.""" @@ -632,7 +637,7 @@ def test_dynamic_dag_update_preserves_null_check(self, dag_maker, session): assert dag_version is not None # Manually delete SerializedDagModel (simulates edge case) - session.query(SDM).filter(SDM.dag_id == "test_missing_serdag").delete() + session.execute(delete(SDM).where(SDM.dag_id == "test_missing_serdag")) session.commit() # Verify no SerializedDagModel exists @@ -709,7 +714,9 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session): EmptyOperator(task_id="task1") dag = dag_maker.dag - initial_version_count = session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count() + initial_version_count = session.scalar( + select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id) + ) assert initial_version_count == 1, "Should have one DagVersion after initial write" dag_maker.create_dagrun() # ensure the second dag version is created @@ -732,8 +739,8 @@ def test_write_dag_atomicity_on_dagcode_failure(self, dag_maker, session): # Verify that no new DagVersion was committed # Use a fresh session to ensure we're reading from committed data with create_session() as fresh_session: - final_version_count = ( - fresh_session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count() + final_version_count = fresh_session.scalar( + select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id) ) assert final_version_count == initial_version_count, ( "DagVersion should not be committed when DagCode.write_code fails"