Skip to content

Commit

Permalink
Merge pull request #22 from aioworkers/#12_extra_parameters_pool
Browse files Browse the repository at this point in the history
#12 extra parameters pool
  • Loading branch information
abogushov authored Nov 8, 2021
2 parents 34b7779 + fdec988 commit 0b46495
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 16 deletions.
15 changes: 14 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Asyncpg plugin for `aioworkers`.

## Usage

### Connection

Add this to aioworkers config.yaml:

```yaml
Expand All @@ -25,8 +27,19 @@ await context.db.execute('CREATE TABLE users(id serial PRIMARY KEY, name text)')
await context.db.execute(users.insert().values(name='Bob'))
```

### Connection additional

```yaml
db:
cls: aioworkers_pg.base.Connector
dsn: postgresql:///test
pool:
min_size: 1
max_size: 100
```
## Storage
### Storage
```yaml
storage:
Expand Down
59 changes: 48 additions & 11 deletions aioworkers_pg/base.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,60 @@
from typing import Awaitable, Callable, Optional
from typing import Awaitable, Callable, Optional, Type

import asyncpg
import warnings
from aioworkers.core.base import AbstractConnector
from aioworkers.core.config import ValueExtractor


class Connector(AbstractConnector):
_connection_init: Optional[Callable[[asyncpg.Connection], Awaitable]]
_pool_init: Optional[Callable[[asyncpg.Connection], Awaitable]] = None
_pool_setup: Optional[Callable[[asyncpg.Connection], Awaitable]] = None
_connection_class: Optional[
Type[asyncpg.connection.Connection]
] = asyncpg.connection.Connection
_record_class: type = asyncpg.protocol.Record
_connect_kwargs: dict = {}

def __init__(self, *args, **kwargs):
self._pool = None
self._pool_init = self._default_pool_init
super().__init__(*args, **kwargs)

def set_config(self, config: ValueExtractor) -> None:
cfg = config.new_parent(logger=__package__)
super().set_config(cfg)
connection_init: Optional[str] = self.config.get(
"connection.init",
)
if connection_init:
self._connection_init = self.context.get_object(connection_init)
else:
self._connection_init = self._default_connection_init

# TODO: Remove deprecated code.
if self.config.get("connection.init"):
warnings.warn(
"Do not use connection.init config. Use pool.init", DeprecationWarning
)
self._pool_init = self.context.get_object(
self.config.get("connection.init")
)

pool_config = dict(self.config.get("pool", {}))
if not pool_config:
# Do not process pool config if there is no any parameter
return

pool_init: Optional[str] = pool_config.pop("init", None)
if pool_init:
self._pool_init = self.context.get_object(pool_init)

pool_setup: Optional[str] = pool_config.pop("setup", None)
if pool_setup:
self._pool_setup = self.context.get_object(pool_setup)

connection_class: Optional[str] = pool_config.pop("connection_class", None)
if connection_class:
self._connection_class = self.context.get_object(connection_class)

record_class: Optional[str] = pool_config.pop("record_class", None)
if record_class:
self._record_class = self.context.get_object(record_class)

self._connect_kwargs = pool_config

@property
def pool(self) -> asyncpg.pool.Pool:
Expand All @@ -39,7 +72,11 @@ async def connect(self):
async def pool_factory(self, config) -> asyncpg.pool.Pool:
pool = await asyncpg.create_pool(
config.dsn,
init=self._connection_init,
init=self._pool_init,
setup=self._pool_setup,
connection_class=self._connection_class,
record_class=self._record_class,
**self._connect_kwargs,
)
self.logger.debug("Create pool with address %s", config.dsn)
return pool
Expand All @@ -50,7 +87,7 @@ async def disconnect(self):
await self._pool.close()
self._pool = None

async def _default_connection_init(self, connection: asyncpg.Connection):
async def _default_pool_init(self, connection: asyncpg.Connection):
import json

# TODO: Need general solution to add codecs
Expand Down
4 changes: 2 additions & 2 deletions aioworkers_pg/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ async def pool_factory(self, config: ValueExtractor) -> asyncpg.pool.Pool:

pool = await asyncpgsa.create_pool(
config.dsn,
init=self._connection_init,
init=self._pool_init,
)
self.logger.debug("Create pool with address %s", config.dsn)
return pool

async def _default_connection_init(self, connection: asyncpg.Connection):
async def _default_pool_init(self, connection: asyncpg.Connection):
import json

# TODO: Need general solution to add codecs
Expand Down
34 changes: 34 additions & 0 deletions tests/test_connector_pool_custom_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from aioworkers.utils import import_uri
from asyncpg.connection import Connection
from asyncpg.protocol.protocol import Record


class CustomConnection(Connection):
async def execute(self, query: str, *args, timeout: float = None) -> str:
query += "SELECT 1"
return await super().execute(query, *args, timeout=timeout)


class CustomRecord(Record):
pass


@pytest.fixture
def config(config, dsn):
config.update(
db={
"cls": "aioworkers_pg.base.Connector",
"dsn": dsn,
"pool": {
"connection_class": import_uri(CustomConnection),
"record_class": import_uri(CustomRecord),
},
},
)
return config


async def test_custom_classes(context):
# Run an empty query. The final query will be updated in the CustomConnection.
await context.db.execute("")
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,19 @@ def config(config, dsn):
db={
"cls": "aioworkers_pg.base.Connector",
"dsn": dsn,
"connection": {
"pool": {
"init": import_uri(custom_init),
},
},
)
return config


async def test_custom_connection_init(context):
async def test_custom_pool_init(context):
await context.db.execute("SELECT 1")

# The custom_init function is empty so there is no codec
# to translate dict into str.
with pytest.raises(asyncpg.exceptions.DataError):
await context.db.execute(
"SELECT $1",
Expand Down
21 changes: 21 additions & 0 deletions tests/test_connector_pool_kwargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest


@pytest.fixture
def config(config, dsn):
config.update(
db={
"cls": "aioworkers_pg.base.Connector",
"dsn": dsn,
"pool": {
"min_size": 1,
"max_size": 1,
},
},
)
return config


async def test_custom_pool_setup(context):
assert context.db._pool._minsize == 1
assert context.db._pool._maxsize == 1
24 changes: 24 additions & 0 deletions tests/test_connector_pool_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import pytest
from aioworkers.utils import import_uri


async def custom_setup(connection):
pass


@pytest.fixture
def config(config, dsn):
config.update(
db={
"cls": "aioworkers_pg.base.Connector",
"dsn": dsn,
"pool": {
"setup": import_uri(custom_setup),
},
},
)
return config


async def test_custom_pool_setup(context):
await context.db.execute("SELECT 1")

0 comments on commit 0b46495

Please sign in to comment.