Skip to content

Commit

Permalink
fix pylance warnings in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dantownsend committed Mar 11, 2024
1 parent 0e2ec8a commit 2def0ed
Show file tree
Hide file tree
Showing 51 changed files with 328 additions and 233 deletions.
4 changes: 3 additions & 1 deletion tests/apps/app/commands/test_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,7 @@ def test_new_with_clashing_name(self):

exception = context.exception
self.assertTrue(
exception.code.startswith("A module called sys already exists")
str(exception.code).startswith(
"A module called sys already exists"
)
)
2 changes: 0 additions & 2 deletions tests/apps/fixtures/commands/test_dump_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,5 +276,3 @@ def test_on_conflict(self):

run_sync(load(path=json_file_path, on_conflict="DO NOTHING"))
run_sync(load(path=json_file_path, on_conflict="DO UPDATE"))
run_sync(load(path=json_file_path, on_conflict="do nothing"))
run_sync(load(path=json_file_path, on_conflict="do update"))
4 changes: 2 additions & 2 deletions tests/apps/fixtures/commands/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ def test_shared(self):
}

model = pydantic_model(**data)
self.assertEqual(model.mega.SmallTable[0].id, 1)
self.assertEqual(model.mega.MegaTable[0].id, 1)
self.assertEqual(model.mega.SmallTable[0].id, 1) # type: ignore
self.assertEqual(model.mega.MegaTable[0].id, 1) # type: ignore
15 changes: 8 additions & 7 deletions tests/apps/migrations/auto/test_migration_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import random
import typing as t
from io import StringIO
from unittest import TestCase
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -267,7 +268,7 @@ def test_add_table(self, get_app_config: MagicMock):
self.assertEqual(self.table_exists("musician"), False)

@engines_only("postgres", "cockroach")
def test_add_column(self):
def test_add_column(self) -> None:
"""
Test adding a column to a MigrationManager.
"""
Expand Down Expand Up @@ -304,21 +305,21 @@ def test_add_column(self):
response = self.run_sync("SELECT * FROM manager;")
self.assertEqual(response, [{"id": 1, "name": "Dave"}])

id = 0
row_id: t.Optional[int] = None
if engine_is("cockroach"):
id = self.run_sync(
row_id = self.run_sync(
"INSERT INTO manager VALUES (default, 'Dave', 'dave@me.com') RETURNING id;" # noqa: E501
)
)[0]["id"]
response = self.run_sync("SELECT * FROM manager;")
self.assertEqual(
response,
[{"id": id[0]["id"], "name": "Dave", "email": "dave@me.com"}],
[{"id": row_id, "name": "Dave", "email": "dave@me.com"}],
)

# Reverse
asyncio.run(manager.run(backwards=True))
response = self.run_sync("SELECT * FROM manager;")
self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}])
self.assertEqual(response, [{"id": row_id, "name": "Dave"}])

# Preview
manager.preview = True
Expand All @@ -333,7 +334,7 @@ def test_add_column(self):
if engine_is("postgres"):
self.assertEqual(response, [{"id": 1, "name": "Dave"}])
if engine_is("cockroach"):
self.assertEqual(response, [{"id": id[0]["id"], "name": "Dave"}])
self.assertEqual(response, [{"id": row_id, "name": "Dave"}])

@engines_only("postgres", "cockroach")
def test_add_column_with_index(self):
Expand Down
16 changes: 12 additions & 4 deletions tests/apps/user/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,15 +221,21 @@ def test_long_password_error(self):

def test_no_username_error(self):
with self.assertRaises(ValueError) as manager:
BaseUser.create_user_sync(username=None, password="abc123")
BaseUser.create_user_sync(
username=None, # type: ignore
password="abc123",
)

self.assertEqual(
manager.exception.__str__(), "A username must be provided."
)

def test_no_password_error(self):
with self.assertRaises(ValueError) as manager:
BaseUser.create_user_sync(username="bob", password=None)
BaseUser.create_user_sync(
username="bob",
password=None, # type: ignore
)

self.assertEqual(
manager.exception.__str__(), "A password must be provided."
Expand Down Expand Up @@ -272,12 +278,14 @@ def test_hash_update(self):
BaseUser.login_sync(username=username, password=password)
)

hashed_password = (
user_data = (
BaseUser.select(BaseUser.password)
.where(BaseUser.id == user.id)
.first()
.run_sync()["password"]
.run_sync()
)
assert user_data is not None
hashed_password = user_data["password"]

algorithm, iterations_, salt, hashed = BaseUser.split_stored_password(
hashed_password
Expand Down
18 changes: 13 additions & 5 deletions tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
ENGINE = engine_finder()


def engine_version_lt(version: float):
return ENGINE and run_sync(ENGINE.get_version()) < version
def engine_version_lt(version: float) -> bool:
return ENGINE is not None and run_sync(ENGINE.get_version()) < version


def is_running_postgres():
def is_running_postgres() -> bool:
return type(ENGINE) is PostgresEngine


def is_running_sqlite():
def is_running_sqlite() -> bool:
return type(ENGINE) is SQLiteEngine


def is_running_cockroach():
def is_running_cockroach() -> bool:
return type(ENGINE) is CockroachEngine


Expand Down Expand Up @@ -228,6 +228,8 @@ def get_postgres_varchar_length(
###########################################################################

def create_tables(self):
assert ENGINE is not None

if ENGINE.engine_type in ("postgres", "cockroach"):
self.run_sync(
"""
Expand Down Expand Up @@ -308,6 +310,8 @@ def create_tables(self):
raise Exception("Unrecognised engine")

def insert_row(self):
assert ENGINE is not None

if ENGINE.engine_type == "cockroach":
id = self.run_sync(
"""
Expand Down Expand Up @@ -352,6 +356,8 @@ def insert_row(self):
)

def insert_rows(self):
assert ENGINE is not None

if ENGINE.engine_type == "cockroach":
id = self.run_sync(
"""
Expand Down Expand Up @@ -428,6 +434,8 @@ def insert_many_rows(self, row_count=10000):
self.run_sync(f"INSERT INTO manager (name) VALUES {values_string};")

def drop_tables(self):
assert ENGINE is not None

if ENGINE.engine_type in ("postgres", "cockroach"):
self.run_sync("DROP TABLE IF EXISTS band CASCADE;")
self.run_sync("DROP TABLE IF EXISTS manager CASCADE;")
Expand Down
8 changes: 4 additions & 4 deletions tests/columns/foreign_key/test_all_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,17 @@ def test_all_columns_deep(self):
"""
Make sure ``all_columns`` works when the joins are several layers deep.
"""
all_columns = Concert.band_1.manager.all_columns()
self.assertEqual(all_columns, [Band.manager.id, Band.manager.name])
all_columns = Concert.band_1._.manager.all_columns()
self.assertEqual(all_columns, [Band.manager._.id, Band.manager._.name])

# Make sure the call chains are also correct.
self.assertEqual(
all_columns[0]._meta.call_chain,
Concert.band_1.manager.id._meta.call_chain,
Concert.band_1._.manager._.id._meta.call_chain,
)
self.assertEqual(
all_columns[1]._meta.call_chain,
Concert.band_1.manager.name._meta.call_chain,
Concert.band_1._.manager._.name._meta.call_chain,
)

def test_all_columns_exclude(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/columns/foreign_key/test_all_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def test_all_related_deep(self):
"""
Make sure ``all_related`` works when the joins are several layers deep.
"""
all_related = Ticket.concert.band_1.all_related()
self.assertEqual(all_related, [Ticket.concert.band_1.manager])
all_related = Ticket.concert._.band_1.all_related()
self.assertEqual(all_related, [Ticket.concert._.band_1._.manager])

# Make sure the call chains are also correct.
self.assertEqual(
all_related[0]._meta.call_chain,
Ticket.concert.band_1.manager._meta.call_chain,
Ticket.concert._.band_1._.manager._meta.call_chain,
)

def test_all_related_exclude(self):
Expand All @@ -57,6 +57,6 @@ def test_all_related_exclude(self):
)

self.assertEqual(
Ticket.concert.all_related(exclude=[Ticket.concert.venue]),
Ticket.concert.all_related(exclude=[Ticket.concert._.venue]),
[Ticket.concert.band_1, Ticket.concert.band_2],
)
2 changes: 1 addition & 1 deletion tests/columns/foreign_key/test_attribute_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ def test_recursion_time(self):
Make sure that a really large call chain doesn't take too long.
"""
start = time.time()
Manager.manager.manager.manager.manager.manager.manager.name
Manager.manager._.manager._.manager._.manager._.manager._.manager._.name # noqa: E501
end = time.time()
self.assertLess(end - start, 1.0)
3 changes: 2 additions & 1 deletion tests/columns/foreign_key/test_foreign_key_self.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from unittest import TestCase

from piccolo.columns import ForeignKey, Varchar
from piccolo.columns import ForeignKey, Serial, Varchar
from piccolo.table import Table


class Manager(Table, tablename="manager"):
id: Serial
name = Varchar()
manager: ForeignKey["Manager"] = ForeignKey("self", null=True)

Expand Down
2 changes: 1 addition & 1 deletion tests/columns/foreign_key/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_with_schema(self):
query = Concert.select(
Concert.start_date,
Concert.band.name.as_alias("band_name"),
Concert.band.manager.name.as_alias("manager_name"),
Concert.band._.manager._.name.as_alias("manager_name"),
)
self.assertIn('"schema_1"."concert"', query.__str__())
self.assertIn('"schema_1"."band"', query.__str__())
Expand Down
Loading

0 comments on commit 2def0ed

Please sign in to comment.