Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
^airflow-core/tests/unit/cli/commands/test_pool_command\.py$|
^task_sdk.*\.py$
pass_filenames: true
- id: update-supported-versions
Expand Down
29 changes: 22 additions & 7 deletions airflow-core/tests/unit/cli/commands/test_pool_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import json

import pytest
from sqlalchemy import func, select

from airflow import models, settings
from airflow.cli import cli_parser
Expand Down Expand Up @@ -47,7 +48,9 @@ def tearDown(self):
def _cleanup(session=None):
if session is None:
session = Session()
session.query(Pool).filter(Pool.pool != Pool.DEFAULT_POOL_NAME).delete()
session.execute(
Pool.__table__.delete().where(Pool.pool != Pool.DEFAULT_POOL_NAME),
)
session.commit()
add_default_pool_if_not_exists()
session.close()
Expand All @@ -64,19 +67,28 @@ def test_pool_list_with_args(self):

def test_pool_create(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).count() == 2
assert self.session.execute(select(func.count()).select_from(Pool)).scalar_one() == 2

def test_pool_update_deferred(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert (
self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred
is False
)

pool_command.pool_set(
self.parser.parse_args(["pools", "set", "foo", "1", "test", "--include-deferred"])
)
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is True
assert (
self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred
is True
)

pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert (
self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred
is False
)

def test_pool_get(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
Expand All @@ -85,7 +97,7 @@ def test_pool_get(self):
def test_pool_delete(self):
pool_command.pool_set(self.parser.parse_args(["pools", "set", "foo", "1", "test"]))
pool_command.pool_delete(self.parser.parse_args(["pools", "delete", "foo"]))
assert self.session.query(Pool).count() == 1
assert self.session.execute(select(func.count()).select_from(Pool)).scalar_one() == 1

def test_pool_import_nonexistent(self):
with pytest.raises(SystemExit):
Expand Down Expand Up @@ -123,7 +135,10 @@ def test_pool_import_backwards_compatibility(self, tmp_path):

pool_command.pool_import(self.parser.parse_args(["pools", "import", str(pool_import_file_path)]))

assert self.session.query(Pool).filter(Pool.pool == "foo").first().include_deferred is False
assert (
self.session.execute(select(Pool).where(Pool.pool == "foo")).scalars().first().include_deferred
is False
)

def test_pool_import_export(self, tmp_path):
pool_import_file_path = tmp_path / "pools_import.json"
Expand Down