Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return deferred db init (#278) #280

Merged
merged 2 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 94 additions & 70 deletions peewee_async/databases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
import logging
import warnings
from typing import Type, Optional, Any, AsyncIterator, Iterator
from typing import Dict, Type, Optional, Any, AsyncIterator, Iterator

import peewee
from playhouse import postgres_ext as ext
Expand All @@ -13,15 +13,25 @@
from .utils import psycopg2, aiopg, pymysql, aiomysql, __log__


class AioDatabase:
class AioDatabase(peewee.Database):
_allow_sync = True # whether sync queries are allowed

pool_backend_cls: Type[PoolBackend]
pool_backend: PoolBackend

def __init__(self, database: Optional[str], **kwargs: Any) -> None:
super().__init__(database, **kwargs)
if not database:
raise Exception("Deferred initialization is not supported")
@property
def connect_params_async(self) -> Dict[str, Any]:
...

def init(self, database: Optional[str], **kwargs: Any) -> None:
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
self.pool_backend = self.pool_backend_cls(
database=self.database,
**self.connect_params_async
Expand All @@ -30,6 +40,8 @@ def __init__(self, database: Optional[str], **kwargs: Any) -> None:
async def aio_connect(self) -> None:
"""Creates a connection pool
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')
await self.pool_backend.connect()

@property
Expand All @@ -41,6 +53,9 @@ def is_connected(self) -> bool:
async def aio_close(self) -> None:
"""Terminate pool backend. The pool is closed until you run aio_connect manually
"""
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

await self.pool_backend.terminate()

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -93,12 +108,17 @@ def execute_sql(self, *args, **kwargs):
"Error, sync query is not allowed! Call the `.set_allow_sync()` "
"or use the `.allow_sync()` context manager.")
if self._allow_sync in (logging.ERROR, logging.WARNING):
logging.log(self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs)))
logging.log(
self._allow_sync,
"Error, sync query is not allowed: %s %s" %
(str(args), str(kwargs))
)
return super().execute_sql(*args, **kwargs)

def aio_connection(self) -> ConnectionContextManager:
if self.deferred:
raise Exception('Error, database must be initialized before creating a connection pool')

return ConnectionContextManager(self.pool_backend)

async def aio_execute_sql(self, sql: str, params=None, fetch_results=None):
Expand Down Expand Up @@ -183,38 +203,28 @@ def transaction_async(self) -> Any:
return self.aio_atomic()


class AioPostgresqlMixin(AioDatabase):
"""Mixin for `peewee.PostgresqlDatabase` providing extra methods
class AioPostgresqlMixin(AioDatabase, peewee.PostgresqlDatabase):
"""Extension for `peewee.PostgresqlDatabase` providing extra methods
for managing async connection.
"""

_enable_json: bool
_enable_hstore: bool

pool_backend_cls = PostgresqlPoolBackend

if psycopg2:
Error = psycopg2.Error

def init_async(self, enable_json: bool = False, enable_hstore: bool =False) -> None:
def init_async(self, enable_json: bool = False, enable_hstore: bool = False) -> None:
if not aiopg:
raise Exception("Error, aiopg is not installed!")
self._enable_json = enable_json
self._enable_hstore = enable_hstore

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
})
return kwargs


class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync**
"""PostgreSQL database driver providing **single drop-in sync**
connection and **async connections pool** interface.

:param max_connections: connections pool size
Expand All @@ -226,22 +236,37 @@ class PooledPostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""
min_connections: int = 1
max_connections: int = 20

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

self.init_async()
super().init(database, **kwargs)

@property
def connect_params_async(self):
"""Connection parameters for `aiopg.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'enable_json': self._enable_json,
'enable_hstore': self._enable_hstore,
}
)
return kwargs


class PooledPostgresqlExtDatabase(
AioPostgresqlMixin,
PooledPostgresqlDatabase,
ext.PostgresqlExtDatabase
):
"""PosgreSQL database extended driver providing **single drop-in sync**
Expand All @@ -250,8 +275,6 @@ class PooledPostgresqlExtDatabase(
JSON fields support is always enabled, HStore supports is enabled by
default, but can be disabled with ``register_hstore=False`` argument.

:param max_connections: connections pool size

Example::

database = PooledPostgresqlExtDatabase('test', register_hstore=False,
Expand All @@ -262,20 +285,11 @@ class PooledPostgresqlExtDatabase(
"""

def init(self, database: Optional[str], **kwargs: Any) -> None:
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)
connection_timeout = kwargs.pop('connection_timeout', None)
if connection_timeout is not None:
warnings.warn(
"`connection_timeout` is deprecated, use `connect_timeout` instead.",
DeprecationWarning
)
kwargs['connect_timeout'] = connection_timeout
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
super().init(database, **kwargs)


class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
Expand All @@ -291,6 +305,9 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""
min_connections: int = 1
max_connections: int = 20

pool_backend_cls = MysqlPoolBackend

if pymysql:
Expand All @@ -299,27 +316,34 @@ class PooledMySQLDatabase(AioDatabase, peewee.MySQLDatabase):
def init(self, database: Optional[str], **kwargs: Any) -> None:
if not aiomysql:
raise Exception("Error, aiomysql is not installed!")
self.min_connections = kwargs.pop('min_connections', 1)
self.max_connections = kwargs.pop('max_connections', 20)

if min_connections := kwargs.pop('min_connections', False):
self.min_connections = min_connections

if max_connections := kwargs.pop('max_connections', False):
self.max_connections = max_connections

super().init(database, **kwargs)

@property
def connect_params_async(self):
def connect_params_async(self) -> Dict[str, Any]:
"""Connection parameters for `aiomysql.Connection`
"""
kwargs = self.connect_params.copy()
kwargs.update({
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
})
kwargs.update(
{
'minsize': self.min_connections,
'maxsize': self.max_connections,
'autocommit': True,
}
)
return kwargs


# DEPRECATED Databases


class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
class PostgresqlDatabase(PooledPostgresqlDatabase):
"""PosgreSQL database driver providing **single drop-in sync** connection
and **single async connection** interface.

Expand All @@ -330,15 +354,16 @@ class PostgresqlDatabase(AioPostgresqlMixin, peewee.PostgresqlDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#PostgresqlDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlDatabase` is deprecated, use `PooledPostgresqlDatabase` instead.",
DeprecationWarning
)
self.min_connections = 1
self.max_connections = 1
super().init(database, **kwargs)
self.init_async()


class MySQLDatabase(PooledMySQLDatabase):
Expand All @@ -352,17 +377,19 @@ class MySQLDatabase(PooledMySQLDatabase):
See also:
http://peewee.readthedocs.io/en/latest/peewee/api.html#MySQLDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`MySQLDatabase` is deprecated, use `PooledMySQLDatabase` instead.",
DeprecationWarning
)
super().init(database, **kwargs)
self.min_connections = 1
self.max_connections = 1


class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
class PostgresqlExtDatabase(PooledPostgresqlExtDatabase):
"""PosgreSQL database extended driver providing **single drop-in sync**
connection and **single async connection** interface.

Expand All @@ -377,15 +404,12 @@ class PostgresqlExtDatabase(AioPostgresqlMixin, ext.PostgresqlExtDatabase):
https://peewee.readthedocs.io/en/latest/peewee/playhouse.html#PostgresqlExtDatabase
"""

min_connections: int = 1
max_connections: int = 1

def init(self, database: Optional[str], **kwargs: Any) -> None:
warnings.warn(
"`PostgresqlExtDatabase` is deprecated, use `PooledPostgresqlExtDatabase` instead.",
DeprecationWarning
)
self.min_connections = 1
self.max_connections = 1
super().init(database, **kwargs)
self.init_async(
enable_json=True,
enable_hstore=self._register_hstore
)
1 change: 1 addition & 0 deletions peewee_async/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
class PoolBackend(metaclass=abc.ABCMeta):
"""Asynchronous database connection pool.
"""

def __init__(self, *, database: str, **kwargs: Any) -> None:
self.pool: Optional[PoolProtocol] = None
self.database = database
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def db(request):

params = DB_DEFAULTS[db]
database = DB_CLASSES[db](**params)

database._allow_sync = False
with database.allow_sync():
for model in ALL_MODELS:
Expand Down
4 changes: 2 additions & 2 deletions tests/db_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'port': int(os.environ.get('POSTGRES_PORT', 5432)),
'password': 'postgres',
'user': 'postgres',
'connect_timeout': 30
'connection_timeout': 30
}

MYSQL_DEFAULTS = {
Expand All @@ -18,7 +18,7 @@
'port': int(os.environ.get('MYSQL_PORT', 3306)),
'user': 'root',
'password': 'mysql',
'connect_timeout': 30
'connection_timeout': 30
}

DB_DEFAULTS = {
Expand Down
16 changes: 15 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest

from peewee_async import connection_context
from tests.conftest import dbs_all
from peewee_async.databases import AioDatabase
from tests.conftest import dbs_all, MYSQL_DBS, PG_DBS
from tests.db_config import DB_DEFAULTS, DB_CLASSES
from tests.models import TestModel


Expand Down Expand Up @@ -53,3 +55,15 @@ async def test_aio_close_idempotent(db):

await db.aio_close()
assert db.is_connected is False


@pytest.mark.parametrize('db_name', PG_DBS + MYSQL_DBS)
async def test_deferred_init(db_name):
database: AioDatabase = DB_CLASSES[db_name](None)

with pytest.raises(Exception, match='Error, database must be initialized before creating a connection pool'):
await database.aio_execute_sql(sql='SELECT 1;')

database.init(**DB_DEFAULTS[db_name])

await database.aio_execute_sql(sql='SELECT 1;')
Loading