From fc3226957e4f1a1a7d619ac268ed91c2d7af2f4a Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Thu, 6 Apr 2017 09:03:58 -0400 Subject: [PATCH] Fix SSL support in connection pool. Closes: #119. --- asyncpg/pool.py | 11 +++++++---- tests/test_connect.py | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index d311fcbf..a0399f48 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -115,6 +115,7 @@ async def connect(self): **self._connect_kwargs) self._pool._working_addr = con._addr self._pool._working_opts = con._opts + self._pool._working_ssl_context = con._ssl_context else: # We've connected before and have a resolved address @@ -126,9 +127,10 @@ async def connect(self): else: host, port = self._pool._working_addr - con = await self._pool._connect(host=host, port=port, - loop=self._pool._loop, - **self._pool._working_opts) + con = await self._pool._connect( + host=host, port=port, loop=self._pool._loop, + ssl=self._pool._working_ssl_context, + **self._pool._working_opts) if self._init is not None: await self._init(con) @@ -248,7 +250,7 @@ class Pool: """ __slots__ = ('_queue', '_loop', '_minsize', '_maxsize', - '_working_addr', '_working_opts', + '_working_addr', '_working_opts', '_working_ssl_context', '_holders', '_initialized', '_closed') def __init__(self, *connect_args, @@ -292,6 +294,7 @@ def __init__(self, *connect_args, self._working_addr = None self._working_opts = None + self._working_ssl_context = None self._closed = False diff --git a/tests/test_connect.py b/tests/test_connect.py index 61c20db5..740b904f 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -579,3 +579,28 @@ async def test_ssl_connection_default_context(self): database='postgres', loop=self.loop, ssl=True) + + async def test_ssl_connection_pool(self): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + ssl_context.load_verify_locations(SSL_CA_CERT_FILE) + + pool = await self.create_pool( + host='localhost', + user='ssl_user', + database='postgres', + min_size=5, + max_size=10, + ssl=ssl_context) + + async def worker(): + async with pool.acquire() as con: + self.assertEqual(await con.fetchval('SELECT 42'), 42) + + with self.assertRaises(asyncio.TimeoutError): + await con.execute('SELECT pg_sleep(5)', timeout=0.5) + + self.assertEqual(await con.fetchval('SELECT 43'), 43) + + tasks = [worker() for _ in range(100)] + await asyncio.gather(*tasks, loop=self.loop) + await pool.close()