diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index 96b41c89..b0cef053 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -31,4 +31,4 @@ # snapshots will automatically include the git revision # in __version__, for example: '0.16.0.dev0+ge06ad03' -__version__ = '0.20.1' +__version__ = '0.21.0.dev0' diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index ec3d1090..2678b358 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -21,6 +21,7 @@ import typing import urllib.parse import warnings +import inspect from . import compat from . import exceptions @@ -601,6 +602,16 @@ async def _connect_addr(*, addr, loop, timeout, params, config, raise asyncio.TimeoutError connected = _create_future(loop) + + params_input = params + if callable(params.password): + if inspect.iscoroutinefunction(params.password): + password = await params.password() + else: + password = params.password() + + params = params._replace(password=password) + proto_factory = lambda: protocol.Protocol( addr, connected, params, loop) @@ -633,7 +644,7 @@ async def _connect_addr(*, addr, loop, timeout, params, config, tr.close() raise - con = connection_class(pr, tr, loop, addr, config, params) + con = connection_class(pr, tr, loop, addr, config, params_input) pr.set_connection(con) return con diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ef1b595d..76a5a9cf 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1566,6 +1566,10 @@ async def connect(dsn=None, *, other users and applications may be able to read it without needing specific privileges. It is recommended to use *passfile* instead. + Password may be either a string, or a callable that returns a string. + If a callable is provided, it will be called each time a new connection + is established. + :param passfile: The name of the file used to store passwords (defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf`` diff --git a/tests/test_connect.py b/tests/test_connect.py index f0767827..116b8ad9 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -204,6 +204,44 @@ async def test_auth_password_cleartext(self): user='password_user', password='wrongpassword') + async def test_auth_password_cleartext_callable(self): + def get_correctpassword(): + return 'correctpassword' + + def get_wrongpassword(): + return 'wrongpassword' + + conn = await self.connect( + user='password_user', + password=get_correctpassword) + await conn.close() + + with self.assertRaisesRegex( + asyncpg.InvalidPasswordError, + 'password authentication failed for user "password_user"'): + await self._try_connect( + user='password_user', + password=get_wrongpassword) + + async def test_auth_password_cleartext_callable_coroutine(self): + async def get_correctpassword(): + return 'correctpassword' + + async def get_wrongpassword(): + return 'wrongpassword' + + conn = await self.connect( + user='password_user', + password=get_correctpassword) + await conn.close() + + with self.assertRaisesRegex( + asyncpg.InvalidPasswordError, + 'password authentication failed for user "password_user"'): + await self._try_connect( + user='password_user', + password=get_wrongpassword) + async def test_auth_password_md5(self): conn = await self.connect( user='md5_user', password='correctpassword')