diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4d675a90c8b20..81d7c6e3f173c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -412,6 +412,7 @@ repos: (?x) ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| + ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py index 618b5b00960ce..036d88263d642 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py @@ -22,6 +22,7 @@ import pytest import time_machine +from sqlalchemy import delete, func, select from airflow._shared.timezones import timezone from airflow.models import DagModel @@ -224,8 +225,8 @@ def _create_dag_run(session, num: int = 2): def _create_asset_dag_run(session, num: int = 2): for i in range(1, 1 + num): - dag_run = session.query(DagRun).filter_by(run_id=f"source_run_id_{i}").first() - asset_event = session.query(AssetEvent).filter_by(id=i).first() + dag_run = session.scalar(select(DagRun).where(DagRun.run_id == f"source_run_id_{i}")) + asset_event = session.scalar(select(AssetEvent).where(AssetEvent.id == i)) if dag_run and asset_event: dag_run.consumed_asset_events.append(asset_event) session.commit() @@ -300,8 +301,8 @@ def test_should_respond_200(self, test_client, session): session.add(AssetModel("inactive", "inactive")) session.commit() - assert len(session.query(AssetModel).all()) == 3 - assert len(session.query(AssetActive).all()) == 2 + assert len(session.scalars(select(AssetModel)).all()) == 3 + assert len(session.scalars(select(AssetActive)).all()) == 2 with assert_queries_count(7): response = test_client.get("/assets") @@ -417,8 +418,8 @@ def test_should_show_inactive(self, test_client, session): ) session.commit() - assert len(session.query(AssetModel).all()) == 3 - assert len(session.query(AssetActive).all()) == 2 + assert len(session.scalars(select(AssetModel)).all()) == 3 + assert len(session.scalars(select(AssetActive)).all()) == 2 response = test_client.get("/assets?only_active=0") assert response.status_code == 200 @@ -567,7 +568,7 @@ def test_filter_assets_by_uri_pattern_works(self, test_client, params, expected_ def test_filter_assets_by_dag_ids_works( self, test_client, dag_ids, expected_num, testing_dag_bundle, session ): - session.query(DagModel).delete() + session.execute(delete(DagModel)) session.commit() bundle_name = "testing" @@ -606,7 +607,7 @@ def test_filter_assets_by_dag_ids_works( def test_filter_assets_by_dag_ids_and_uri_pattern_works( self, test_client, dag_ids, uri_pattern, expected_num, testing_dag_bundle, session ): - session.query(DagModel).delete() + session.execute(delete(DagModel)) session.commit() bundle_name = "testing" @@ -697,7 +698,7 @@ def create_provided_asset_alias(self, asset_alias: AssetAliasModel, session): class TestGetAssetAliases(TestAssetAliases): def test_should_respond_200(self, test_client, session): self.create_asset_aliases() - asset_aliases = session.query(AssetAliasModel).all() + asset_aliases = session.scalars(select(AssetAliasModel)).all() assert len(asset_aliases) == 2 with assert_queries_count(2): @@ -782,7 +783,7 @@ def test_should_respond_200(self, test_client, session): self.create_assets_events(session) self.create_dag_run(session) self.create_asset_dag_run(session) - assets = session.query(AssetEvent).all() + assets = session.scalars(select(AssetEvent)).all() session.commit() assert len(assets) == 2 @@ -1039,7 +1040,7 @@ class TestGetAssetEndpoint(TestAssets): @provide_session def test_should_respond_200(self, test_client, session): self.create_assets(num=1) - assert session.query(AssetModel).count() == 1 + assert session.scalars(select(func.count(AssetModel.id))).one() == 1 tz_datetime_format = from_datetime_to_zulu_without_ms(DEFAULT_DATE) with assert_queries_count(6): response = test_client.get("/assets/1") @@ -1134,7 +1135,7 @@ class TestGetAssetAliasEndpoint(TestAssetAliases): @provide_session def test_should_respond_200(self, test_client, session): self.create_asset_aliases(num=1) - assert session.query(AssetAliasModel).count() == 1 + assert session.scalars(select(func.count(AssetAliasModel.id))).one() == 1 with assert_queries_count(6): response = test_client.get("/assets/aliases/1") assert response.status_code == 200 @@ -1148,7 +1149,7 @@ def test_should_respond_404(self, test_client): class TestQueuedEventEndpoint(TestAssets): def _create_asset_dag_run_queues(self, dag_id, asset_id, session): - session.query(AssetDagRunQueue).delete() + session.execute(delete(AssetDagRunQueue)) session.flush() adrq = AssetDagRunQueue(target_dag_id=dag_id, asset_id=asset_id) session.add(adrq) @@ -1209,7 +1210,7 @@ def test_should_respond_204(self, test_client, session, create_dummy_dag): self.create_assets(session=session, num=1) asset_id = 1 self._create_asset_dag_run_queues(dag_id, asset_id, session) - adrqs = session.query(AssetDagRunQueue).all() + adrqs = session.scalars(select(AssetDagRunQueue)).all() assert len(adrqs) == 1 response = test_client.delete( @@ -1217,7 +1218,7 @@ def test_should_respond_204(self, test_client, session, create_dummy_dag): ) assert response.status_code == 204 - adrqs = session.query(AssetDagRunQueue).all() + adrqs = session.scalars(select(AssetDagRunQueue)).all() assert len(adrqs) == 0 check_last_log(session, dag_id=dag_id, event="delete_dag_asset_queued_events", logical_date=None) @@ -1243,7 +1244,7 @@ def test_should_respond_404_valid_dag_no_adrq(self, test_client, session, create dag, _ = create_dummy_dag() dag_id = dag.dag_id self.create_assets(session=session, num=1) - adrqs = session.query(AssetDagRunQueue).all() + adrqs = session.scalars(select(AssetDagRunQueue)).all() assert len(adrqs) == 0 response = test_client.delete( @@ -1493,7 +1494,7 @@ def test_delete_should_respond_204(self, test_client, session, create_dummy_dag) (asset,) = self.create_assets(session=session, num=1) self._create_asset_dag_run_queues(dag_id, asset.id, session) - adrq = session.query(AssetDagRunQueue).all() + adrq = session.scalars(select(AssetDagRunQueue)).all() assert len(adrq) == 1 response = test_client.delete( @@ -1501,7 +1502,7 @@ def test_delete_should_respond_204(self, test_client, session, create_dummy_dag) ) assert response.status_code == 204 - adrq = session.query(AssetDagRunQueue).all() + adrq = session.scalars(select(AssetDagRunQueue)).all() assert len(adrq) == 0 check_last_log(session, dag_id=dag_id, event="delete_dag_asset_queued_event", logical_date=None)