Skip to content

Commit

Permalink
fixed tests for AsyncAbstractRepository
Browse files Browse the repository at this point in the history
  • Loading branch information
KozyrevIvan committed Oct 4, 2024
1 parent e34fa63 commit c4b62f0
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions test/test_async_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def database():
return client.db


class AsyncTestRepository:
class TestAsyncRepository:
@pytest.mark.asyncio
async def test_save(self, database):
spam_repository = SpamRepository(database=database)
Expand All @@ -50,20 +50,22 @@ async def test_save(self, database):
spam = Spam(foo=foo, bars=[bar])
await spam_repository.save(spam)

result = await database["spams"].find().to_list(length=None)
assert {
"_id": ObjectId(spam.id),
"foo": {"count": 1, "size": 1.0},
"bars": [{"apple": "x", "banana": "y"}],
} == database["spams"].find()[0]
} == result[0]

cast(Foo, spam.foo).count = 2
await spam_repository.save(spam)

result = await database["spams"].find().to_list(length=None)
assert {
"_id": ObjectId(spam.id),
"foo": {"count": 2, "size": 1.0},
"bars": [{"apple": "x", "banana": "y"}],
} == database["spams"].find()[0]
} == result[0]

@pytest.mark.asyncio
async def test_save_upsert(self, database):
Expand All @@ -73,11 +75,12 @@ async def test_save_upsert(self, database):
)
await spam_repository.save(spam)

result = await database["spams"].find().to_list(length=None)
assert {
"_id": ObjectId(spam.id),
"foo": {"count": 1, "size": 1.0},
"bars": [],
} == database["spams"].find()[0]
} == result[0]

@pytest.mark.asyncio
async def test_save_many(self, database):
Expand All @@ -101,11 +104,13 @@ async def test_save_many(self, database):
},
]

assert initial_data == list(database["spams"].find())
result = await database["spams"].find().to_list(length=None)
assert initial_data == result

# Calling save_many again will only update
await spam_repository.save_many(spams)
assert initial_data == list(database["spams"].find())
result = await database["spams"].find().to_list(length=None)
assert initial_data == result

# Calling save_many with only a new model will only insert
new_span = Spam()
Expand Down

0 comments on commit c4b62f0

Please sign in to comment.