Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: aio prefix for database, rename manager fixture #238

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions peewee_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ async def __aexit__(self, exc_type, exc_value, exc_tb):
############


class AsyncDatabase:
class AioDatabase:
_allow_sync = True # whether sync queries are allowed

def __init__(self, database, **kwargs):
Expand All @@ -283,19 +283,33 @@ def __setattr__(self, name, value):
else:
super().__setattr__(name, value)

async def connect_async(self):
async def aio_connect(self):
"""Set up async connection on default event loop.
"""
if self.deferred:
raise Exception("Error, database not properly initialized "
"before opening connection")
await self.aio_pool.connect()

async def close_async(self):
async def aio_close(self):
"""Close async connection.
"""
await self.aio_pool.terminate()

async def connect_async(self):
warnings.warn(
"`connect_async` is deprecated, use `aio_connect` instead.",
DeprecationWarning
)
await self.aio_connect()

async def close_async(self):
warnings.warn(
"`close_async` is deprecated, use `aio_close` instead.",
DeprecationWarning
)
await self.aio_close()

def transaction_async(self):
"""Similar to peewee `Database.transaction()` method, but returns
asynchronous context manager.
Expand Down Expand Up @@ -472,7 +486,7 @@ async def create(self):
)


class AsyncPostgresqlMixin(AsyncDatabase):
class AioPostgresqlMixin(AioDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection.
"""
Expand Down Expand Up @@ -513,7 +527,7 @@ async def last_insert_id_async(self, cursor):
return cursor.lastrowid


class PostgresqlDatabase(AsyncPostgresqlMixin, peewee.PostgresqlDatabase):
class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync** connection
and **single async connection** interface.

Expand All @@ -534,7 +548,7 @@ def init(self, database, **kwargs):
register_database(PostgresqlDatabase, 'postgres+async', 'postgresql+async')


class PooledPostgresqlDatabase(AsyncPostgresqlMixin, peewee.PostgresqlDatabase):
class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.

Expand Down Expand Up @@ -581,7 +595,7 @@ async def create(self):
)


class MySQLDatabase(AsyncDatabase, peewee.MySQLDatabase):
class MySQLDatabase(AioDatabase, peewee.MySQLDatabase):
"""MySQL database driver providing **single drop-in sync** connection
and **single async connection** interface.

Expand Down
6 changes: 3 additions & 3 deletions peewee_asyncext.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from playhouse import postgres_ext as ext
from playhouse.db_url import register_database

from peewee_async import AsyncPostgresqlMixin
from peewee_async import AioPostgresqlMixin


class PostgresqlExtDatabase(AsyncPostgresqlMixin, ext.PostgresqlExtDatabase):
class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
"""PosgreSQL database extended driver providing **single drop-in sync**
connection and **single async connection** interface.

Expand Down Expand Up @@ -50,7 +50,7 @@ def init(self, database, **kwargs):


class PooledPostgresqlExtDatabase(
AsyncPostgresqlMixin,
AioPostgresqlMixin,
ext.PostgresqlExtDatabase
):
"""PosgreSQL database extended driver providing **single drop-in sync**
Expand Down
6 changes: 3 additions & 3 deletions tests/aio_model/test_deleting.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid

from tests.conftest import postgres_only, all_dbs
from tests.conftest import postgres_only, manager_for_all_dbs
from tests.models import TestModel
from tests.utils import model_has_fields


@all_dbs
@manager_for_all_dbs
async def test_delete__count(manager):
query = TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -18,7 +18,7 @@ async def test_delete__count(manager):
assert count == 2


@all_dbs
@manager_for_all_dbs
async def test_delete__by_condition(manager):
expected_text = "text1"
deleted_text = "text2"
Expand Down
6 changes: 3 additions & 3 deletions tests/aio_model/test_inserting.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid

from tests.conftest import postgres_only, all_dbs
from tests.conftest import postgres_only, manager_for_all_dbs
from tests.models import TestModel, UUIDTestModel
from tests.utils import model_has_fields


@all_dbs
@manager_for_all_dbs
async def test_insert_many(manager):
last_id = await TestModel.insert_many([
{'text': "Test %s" % uuid.uuid4()},
Expand All @@ -18,7 +18,7 @@ async def test_insert_many(manager):
assert last_id in [m.id for m in res]


@all_dbs
@manager_for_all_dbs
async def test_insert__return_id(manager):
last_id = await TestModel.insert(text="Test %s" % uuid.uuid4()).aio_execute()

Expand Down
6 changes: 3 additions & 3 deletions tests/aio_model/test_selecting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import peewee
from tests.conftest import all_dbs
from tests.conftest import manager_for_all_dbs
from tests.models import TestModel, TestModelAlpha, TestModelBeta


@all_dbs
@manager_for_all_dbs
async def test_select_w_join(manager):
alpha = await TestModelAlpha.aio_create(text="Test 1")
beta = await TestModelBeta.aio_create(alpha_id=alpha.id, text="text")
Expand All @@ -17,7 +17,7 @@ async def test_select_w_join(manager):
assert result.joined_alpha.id == alpha.id


@all_dbs
@manager_for_all_dbs
async def test_select_compound(manager):
obj1 = await manager.create(TestModel, text="Test 1")
obj2 = await manager.create(TestModel, text="Test 2")
Expand Down
6 changes: 3 additions & 3 deletions tests/aio_model/test_shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

import pytest

from tests.conftest import all_dbs, postgres_only
from tests.conftest import manager_for_all_dbs, postgres_only
from tests.models import TestModel, TestModelAlpha, TestModelBeta



@all_dbs
@manager_for_all_dbs
async def test_aio_get(manager):
obj1 = await TestModel.aio_create(text="Test 1")
obj2 = await TestModel.aio_create(text="Test 2")
Expand All @@ -22,7 +22,7 @@ async def test_aio_get(manager):
await TestModel.aio_get(TestModel.text == "unknown")


@all_dbs
@manager_for_all_dbs
async def test_aio_get_or_none(manager):
obj1 = await TestModel.aio_create(text="Test 1")

Expand Down
6 changes: 3 additions & 3 deletions tests/aio_model/test_updating.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import uuid

from tests.conftest import postgres_only, all_dbs
from tests.conftest import postgres_only, manager_for_all_dbs
from tests.models import TestModel


@all_dbs
@manager_for_all_dbs
async def test_update__count(manager):
for n in range(3):
await TestModel.aio_create(text=f"{n}")
Expand All @@ -13,7 +13,7 @@ async def test_update__count(manager):
assert count == 3


@all_dbs
@manager_for_all_dbs
async def test_update__field_updated(manager):
text = "Test %s" % uuid.uuid4()
obj1 = await TestModel.aio_create(text=text)
Expand Down
14 changes: 7 additions & 7 deletions tests/compat/test_compat.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import uuid

import peewee
from tests.conftest import all_dbs
from tests.conftest import manager_for_all_dbs
from tests.models import CompatTestModel


@all_dbs
@manager_for_all_dbs
async def test_create_select_compat_mode(manager):
obj1 = await manager.create(CompatTestModel, text="Test 1")
obj2 = await manager.create(CompatTestModel, text="Test 2")
Expand All @@ -15,7 +15,7 @@ async def test_create_select_compat_mode(manager):
assert list(result) == [obj1, obj2]


@all_dbs
@manager_for_all_dbs
async def test_compound_select_compat_mode(manager):
obj1 = await manager.create(CompatTestModel, text="Test 1")
obj2 = await manager.create(CompatTestModel, text="Test 2")
Expand All @@ -30,7 +30,7 @@ async def test_compound_select_compat_mode(manager):
assert obj2 in list(result)


@all_dbs
@manager_for_all_dbs
async def test_raw_select_compat_mode(manager):
obj1 = await manager.create(CompatTestModel, text="Test 1")
obj2 = await manager.create(CompatTestModel, text="Test 2")
Expand All @@ -42,7 +42,7 @@ async def test_raw_select_compat_mode(manager):
assert list(result) == [obj1, obj2]


@all_dbs
@manager_for_all_dbs
async def test_update_compat_mode(manager):
obj_draft = await manager.create(CompatTestModel, text="Draft 1")
obj_draft.text = "Final result"
Expand All @@ -51,15 +51,15 @@ async def test_update_compat_mode(manager):
assert obj.text == "Final result"


@all_dbs
@manager_for_all_dbs
async def test_count_compat_mode(manager):
obj = await manager.create(CompatTestModel, text="Unique title %s" % uuid.uuid4())
search = CompatTestModel.select().where(CompatTestModel.text == obj.text)
count = await manager.count(search)
assert count == 1


@all_dbs
@manager_for_all_dbs
async def test_delete_compat_mode(manager):
obj = await manager.create(CompatTestModel, text="Expired item %s" % uuid.uuid4())
search = CompatTestModel.select().where(CompatTestModel.id == obj.id)
Expand Down
8 changes: 4 additions & 4 deletions tests/compat/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio

from tests.conftest import all_dbs
from tests.conftest import manager_for_all_dbs
from tests.models import TestModel


Expand All @@ -10,7 +10,7 @@ class FakeUpdateError(Exception):
pass


@all_dbs
@manager_for_all_dbs
async def test_atomic_success(manager):
obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id
Expand All @@ -23,7 +23,7 @@ async def test_atomic_success(manager):
assert res.text == 'BAR'


@all_dbs
@manager_for_all_dbs
async def test_atomic_failed(manager):
"""Failed update in transaction.
"""
Expand All @@ -44,7 +44,7 @@ async def test_atomic_failed(manager):
assert res.text == 'FOO'


@all_dbs
@manager_for_all_dbs
async def test_acid_when_connetion_has_been_brooken(manager):
# TODO REWRITE FOR NEW STYLE TRANSACTIONS
async def restart_connections(event_for_lock: asyncio.Event) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def db(request):
for model in reversed(sort_models(ALL_MODELS)):
model.delete().execute()
model._meta.database = None
await database.close_async()
await database.aio_close()


PG_DBS = [
Expand All @@ -74,7 +74,7 @@ async def db(request):
"manager", MYSQL_DBS, indirect=["manager"]
)

all_dbs = pytest.mark.parametrize(
manager_for_all_dbs = pytest.mark.parametrize(
"manager", PG_DBS + MYSQL_DBS, indirect=["manager"]
)

Expand Down
Loading
Loading