Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: should call connect manually if close database #246

Merged
merged 2 commits into from
May 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions load-testing/app.py
Original file line number Diff line number Diff line change
@@ -123,8 +123,8 @@ async def atomic():

@app.get("/recreate_pool")
async def atomic():
await manager.database.close_async()
await manager.database.connect_async()
await manager.database.aio_close()
await manager.database.aio_connect()


if __name__ == "__main__":
3 changes: 3 additions & 0 deletions load-testing/load.yaml
Original file line number Diff line number Diff line change
@@ -2,6 +2,9 @@ phantom:
address: 127.0.0.1:8000 # [Target's address]:[target's port]
uris:
- /select
- /atomic
- /transaction
- /recreate_pool
load_profile:
load_type: rps # schedule load by defining requests per second
schedule: const(100, 10m) # starting from 1rps growing linearly to 10rps during 10 minutes
130 changes: 65 additions & 65 deletions peewee_async.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@
import warnings
from contextvars import ContextVar
from importlib.metadata import version
from typing import Optional
from typing import Optional, Type

import peewee
from playhouse.db_url import register_database
@@ -216,37 +216,84 @@ async def rollback(self):


class ConnectionContextManager:
def __init__(self, aio_pool):
self.aio_pool = aio_pool
def __init__(self, pool_backend):
self.pool_backend = pool_backend
self.connection_context = connection_context.get()
self.resuing_connection = self.connection_context is not None

async def __aenter__(self):
if self.connection_context is not None:
connection = self.connection_context.connection
else:
connection = await self.aio_pool.acquire()
connection = await self.pool_backend.acquire()
self.connection_context = ConnectionContext(connection)
connection_context.set(self.connection_context)
return connection

async def __aexit__(self, *args):
if self.resuing_connection is False:
self.aio_pool.release(self.connection_context.connection)
self.pool_backend.release(self.connection_context.connection)
connection_context.set(None)


class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, **kwargs):
self.pool = None
self.database = database
self.connect_params = kwargs
self._connection_lock = asyncio.Lock()

@property
def is_connected(self):
return self.pool is not None and self.pool.closed is False

def has_acquired_connections(self):
return self.pool is not None and len(self.pool._used) > 0

async def connect(self):
async with self._connection_lock:
if self.is_connected is False:
await self.create()

async def acquire(self):
"""Acquire connection from pool.
"""
if self.pool is None:
await self.connect()
return await self.pool.acquire()

def release(self, conn):
"""Release connection to pool.
"""
self.pool.release(conn)

@abc.abstractmethod
async def create(self):
"""Create connection pool asynchronously.
"""
raise NotImplementedError

async def terminate(self):
"""Terminate all pool connections.
"""
self.pool.terminate()
await self.pool.wait_closed()


############
# Database #
############


class AioDatabase:
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]

def __init__(self, database, **kwargs):
super().__init__(database, **kwargs)
self.aio_pool = self.aio_pool_cls(
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
)
@@ -257,12 +304,16 @@ async def aio_connect(self):
if self.deferred:
raise Exception("Error, database not properly initialized "
"before opening connection")
await self.aio_pool.connect()
await self.pool_backend.connect()

@property
def is_connected(self):
return self.pool_backend.is_connected

async def aio_close(self):
"""Close async connection.
"""
await self.aio_pool.terminate()
await self.pool_backend.terminate()

@contextlib.asynccontextmanager
async def aio_atomic(self):
@@ -323,7 +374,7 @@ def execute_sql(self, *args, **kwargs):
return super().execute_sql(*args, **kwargs)

def aio_connection(self) -> ConnectionContextManager:
return ConnectionContextManager(self.aio_pool)
return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params=None, fetch_results=None):
__log__.debug(sql, params)
@@ -405,63 +456,12 @@ def transaction_async(self):
)
return self.aio_atomic()


class AioPool(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""
def __init__(self, *, database=None, **kwargs):
self.pool = None
self.database = database
self.connect_params = kwargs
self._connection_lock = asyncio.Lock()

def has_acquired_connections(self):
return self.pool is not None and len(self.pool._used) > 0

async def connect(self):
async with self._connection_lock:
if self.pool is not None:
return
await self.create()

async def acquire(self):
"""Acquire connection from pool.
"""
if self.pool is None:
await self.connect()
return await self.pool.acquire()

def release(self, conn):
"""Release connection to pool.
"""
if self.pool is not None:
self.pool.release(conn)

@abc.abstractmethod
async def create(self):
"""Create connection pool asynchronously.
"""
raise NotImplementedError

async def terminate(self):
"""Terminate all pool connections.
"""
async with self._connection_lock:
if self.pool is not None:
pool = self.pool
self.pool = None
pool.terminate()
await pool.wait_closed()




##############
# PostgreSQL #
##############


class AioPostgresqlPool(AioPool):
class PostgresqlPoolBackend(PoolBackend):
"""Asynchronous database connection pool.
"""

@@ -481,7 +481,7 @@ class AioPostgresqlMixin(AioDatabase):
for managing async connection.
"""

aio_pool_cls = AioPostgresqlPool
pool_backend_cls = PostgresqlPoolBackend

if psycopg2:
Error = psycopg2.Error
@@ -573,7 +573,7 @@ def init(self, database, **kwargs):
#########


class AioMysqlPool(AioPool):
class MysqlPoolBackend(PoolBackend):
"""Asynchronous database connection pool.
"""

@@ -596,7 +596,7 @@ class MySQLDatabase(AioDatabase, peewee.MySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
aio_pool_cls = AioMysqlPool
pool_backend_cls = MysqlPoolBackend

if pymysql:
Error = pymysql.Error
9 changes: 4 additions & 5 deletions peewee_async_compat.py
Original file line number Diff line number Diff line change
@@ -214,7 +214,7 @@ def __init__(self, database=None):
def is_connected(self):
"""Check if database is connected.
"""
return self.database.aio_pool.pool is not None
return self.database.is_connected

async def get(self, source_, *args, **kwargs):
"""Get the model instance.
@@ -472,8 +472,7 @@ def transaction(db):
"`transaction` is deprecated, use `database.aio_atomic` or `Transaction` class instead.",
DeprecationWarning
)
from peewee_async import TransactionContextManager
return TransactionContextManager(db.aio_pool)
return db.aio_atomic()


def atomic(db):
@@ -484,8 +483,8 @@ def atomic(db):
"`atomic` is deprecated, use `database.aio_atomic` or `Transaction` class instead.",
DeprecationWarning
)
from peewee_async import TransactionContextManager
return TransactionContextManager(db.aio_pool)
return db.aio_atomic()


class _savepoint:
"""Asynchronous context manager (`async with`), similar to
29 changes: 28 additions & 1 deletion tests/compat/test_transaction.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio

from peewee_async import transaction, atomic
from tests.conftest import manager_for_all_dbs
from tests.models import TestModel

@@ -11,7 +12,7 @@ class FakeUpdateError(Exception):


@manager_for_all_dbs
async def test_atomic_success(manager):
async def test_manager_atomic_success(manager):
obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id

@@ -23,6 +24,32 @@ async def test_atomic_success(manager):
assert res.text == 'BAR'


@manager_for_all_dbs
async def test_transaction_success(manager):
obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id

async with transaction(manager.database):
obj.text = 'BAR'
await manager.update(obj)

res = await manager.get(TestModel, id=obj_id)
assert res.text == 'BAR'


@manager_for_all_dbs
async def test_atomic_success(manager):
obj = await manager.create(TestModel, text='FOO')
obj_id = obj.id

async with atomic(manager.database):
obj.text = 'BAR'
await manager.update(obj)

res = await manager.get(TestModel, id=obj_id)
assert res.text == 'BAR'


@manager_for_all_dbs
async def test_atomic_failed(manager):
"""Failed update in transaction.
2 changes: 1 addition & 1 deletion tests/test_common.py
Original file line number Diff line number Diff line change
@@ -146,7 +146,7 @@ async def get_conn(manager):
await manager.connect()
# await asyncio.sleep(0.05, loop=self.loop)
# NOTE: "private" member access
return manager.database.aio_pool
return manager.database.pool_backend


c1 = await get_conn(manager)
27 changes: 26 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from peewee_async import connection_context
from tests.conftest import dbs_all
from tests.models import TestModel
@@ -16,4 +18,27 @@ async def test_nested_connection(db):
async with connection_2.cursor() as cursor:
await cursor.execute("SELECT 1")
assert connection_context.get() is None
assert db.aio_pool.has_acquired_connections() is False
assert db.pool_backend.has_acquired_connections() is False


@dbs_all
async def test_db_should_connect_manually_after_close(db):
await TestModel.aio_create(text='test')

await db.aio_close()
with pytest.raises(RuntimeError):
await TestModel.aio_get_or_none(text='test')
await db.aio_connect()

assert await TestModel.aio_get_or_none(text='test') is not None


@dbs_all
async def test_is_connected(db):
assert db.is_connected is False

await db.aio_connect()
assert db.is_connected is True

await db.aio_close()
assert db.is_connected is False
Loading