diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81e24f2e8a3b3..959ace70bc402 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/assets/test_manager.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/tests/unit/assets/test_manager.py b/airflow-core/tests/unit/assets/test_manager.py index 46a59198d8018..3d83ef34db9cb 100644 --- a/airflow-core/tests/unit/assets/test_manager.py +++ b/airflow-core/tests/unit/assets/test_manager.py @@ -21,7 +21,7 @@ from unittest import mock import pytest -from sqlalchemy import delete +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from airflow.assets.manager import AssetManager @@ -105,8 +105,11 @@ def test_register_asset_change(self, session, dag_maker, mock_task_instance, tes session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 2 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 2 @pytest.mark.usefixtures("clear_assets") def test_register_asset_change_with_alias( @@ -145,8 +148,11 @@ def test_register_asset_change_with_alias( session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 2 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 2 def test_register_asset_change_no_downstreams(self, session, mock_task_instance): asset_manager = AssetManager() @@ -161,8 +167,11 @@ def test_register_asset_change_no_downstreams(self, session, mock_task_instance) session.flush() # Ensure we've created an asset - assert session.query(AssetEvent).filter_by(asset_id=asm.id).count() == 1 - assert session.query(AssetDagRunQueue).count() == 0 + assert ( + session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asm.id)) + == 1 + ) + assert session.scalar(select(func.count()).select_from(AssetDagRunQueue)) == 0 def test_register_asset_change_notifies_asset_listener( self, session, mock_task_instance, testing_dag_bundle