diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81e24f2e8a3b3..f7dbab3ec4e99 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -413,6 +413,7 @@ repos: ^airflow-ctl.*\.py$| ^airflow-core/src/airflow/models/.*\.py$| ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$| + ^airflow-core/tests/unit/cli/commands/test_pool_command\.py$| ^task_sdk.*\.py$ pass_filenames: true - id: update-supported-versions diff --git a/airflow-core/tests/unit/cli/commands/test_pool_command.py b/airflow-core/tests/unit/cli/commands/test_pool_command.py index 98248abd5a669..dc87cf8cc34d4 100644 --- a/airflow-core/tests/unit/cli/commands/test_pool_command.py +++ b/airflow-core/tests/unit/cli/commands/test_pool_command.py @@ -20,6 +20,7 @@ import json import pytest +from sqlalchemy import func, select from airflow import models, settings from airflow.cli import cli_parser @@ -47,7 +48,9 @@ def tearDown(self): def _cleanup(session=None): if session is None: session = Session() - session.query(Pool).filter(Pool.pool != Pool.DEFAULT_POOL_NAME).delete() + session.execute( + Pool.__table__.delete().where(Pool.pool != Pool.DEFAULT_POOL_NAME), + ) session.commit() add_default_pool_if_not_exists() session.close() @@ -64,19 +67,28 @@ def test_pool_list_with_args(self): def test_pool_create(self): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.query(Pool).count() == 2 + assert self.session.execute(select(func.count()).select_from(Pool)).scalar_one() == 2 def test_pool_update_deferred(self): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False + assert ( + self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred + is False + ) pool_command.pool_set( self.parser.parse_args(["pools", "set", "foo", "1", "test", "--include-deferred"]) ) - assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is True + assert ( + self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred + is True + ) pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) - assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False + assert ( + self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred + is False + ) def test_pool_get(self): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) @@ -85,7 +97,7 @@ def test_pool_get(self): def test_pool_delete(self): pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"])) pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"])) - assert self.session.query(Pool).count() == 1 + assert self.session.execute(select(func.count()).select_from(Pool)).scalar_one() == 1 def test_pool_import_nonexistent(self): with pytest.raises(SystemExit): @@ -123,7 +135,10 @@ def test_pool_import_backwards_compatibility(self, tmp_path): pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)])) - assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False + assert ( + self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred + is False + ) def test_pool_import_export(self, tmp_path): pool_import_file_path = tmp_path / "pools_import.json"