Skip to content

Commit b9ad3c3

Browse files
committed
Add rudimentary server capability detection.
Add basic server capability detection mechanism based on server version and parameters reported by the server through ParameterStatus messages. This allows altering certain asyncpg behaviour based on the connected server. Specifically, this allows asyncpg to connect to CochroachDB servers. Fixes #87.
1 parent 8d17ecc commit b9ad3c3

File tree

3 files changed

+88
-19
lines changed

3 files changed

+88
-19
lines changed

asyncpg/connection.py

+77-18
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from . import protocol
2020
from . import serverversion
2121
from . import transaction
22+
from . import types
2223

2324

2425
class Connection:
@@ -31,7 +32,8 @@ class Connection:
3132
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
3233
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
3334
'_addr', '_opts', '_command_timeout', '_listeners',
34-
'_server_version', '_intro_query')
35+
'_server_version', '_server_caps', '_intro_query',
36+
'_reset_query')
3537

3638
def __init__(self, protocol, transport, loop, addr, opts, *,
3739
statement_cache_size, command_timeout):
@@ -55,15 +57,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
5557

5658
self._listeners = {}
5759

58-
ver_string = self._protocol.get_settings().server_version
60+
settings = self._protocol.get_settings()
61+
ver_string = settings.server_version
5962
self._server_version = \
6063
serverversion.split_server_version_string(ver_string)
6164

65+
self._server_caps = _detect_server_capabilities(
66+
self._server_version, settings)
67+
6268
if self._server_version < (9, 2):
6369
self._intro_query = introspection.INTRO_LOOKUP_TYPES_91
6470
else:
6571
self._intro_query = introspection.INTRO_LOOKUP_TYPES
6672

73+
self._reset_query = None
74+
6775
async def add_listener(self, channel, callback):
6876
"""Add a listener for Postgres notifications.
6977
@@ -107,9 +115,26 @@ def get_server_version(self):
107115
ServerVersion(major=9, minor=6, micro=1,
108116
releaselevel='final', serial=0)
109117
118+
.. versionadded:: 0.8.0
110119
"""
111120
return self._server_version
112121

122+
def get_server_capabilities(self):
123+
"""Return the capabilities supported by the server as detected.
124+
125+
The returned value is a named tuple:
126+
127+
.. code-block:: pycon
128+
129+
>>> con.get_server_capabilities()
130+
ServerCapabilities(advisory_locks=True, cursors=True,
131+
notifications=True, plpgsql=True,
132+
sql_reset=True)
133+
134+
.. versionadded:: 0.10.0
135+
"""
136+
return self._server_caps
137+
113138
def get_settings(self):
114139
"""Return connection settings.
115140
@@ -394,22 +419,10 @@ def terminate(self):
394419
self._protocol.abort()
395420

396421
async def reset(self):
397-
self._listeners = {}
398-
399-
await self.execute('''
400-
DO $$
401-
BEGIN
402-
PERFORM * FROM pg_listening_channels() LIMIT 1;
403-
IF FOUND THEN
404-
UNLISTEN *;
405-
END IF;
406-
END;
407-
$$;
408-
SET SESSION AUTHORIZATION DEFAULT;
409-
RESET ALL;
410-
CLOSE ALL;
411-
SELECT pg_advisory_unlock_all();
412-
''')
422+
self._listeners.clear()
423+
reset_query = self._get_reset_query()
424+
if reset_query:
425+
await self.execute(reset_query)
413426

414427
def _get_unique_id(self):
415428
self._uid += 1
@@ -492,6 +505,35 @@ def _notify(self, pid, channel, payload):
492505
'exception': ex
493506
})
494507

508+
def _get_reset_query(self):
509+
if self._reset_query is not None:
510+
return self._reset_query
511+
512+
caps = self.get_server_capabilities()
513+
514+
_reset_query = ''
515+
if caps.advisory_locks:
516+
_reset_query += 'SELECT pg_advisory_unlock_all();\n'
517+
if caps.cursors:
518+
_reset_query += 'CLOSE ALL;\n'
519+
if caps.notifications and caps.plpgsql:
520+
_reset_query += '''
521+
DO $$
522+
BEGIN
523+
PERFORM * FROM pg_listening_channels() LIMIT 1;
524+
IF FOUND THEN
525+
UNLISTEN *;
526+
END IF;
527+
END;
528+
$$;
529+
'''
530+
if caps.sql_reset:
531+
_reset_query += 'RESET ALL;\n'
532+
533+
self._reset_query = _reset_query
534+
535+
return _reset_query
536+
495537

496538
async def connect(dsn=None, *,
497539
host=None, port=None,
@@ -730,3 +772,20 @@ def _create_future(loop):
730772
return asyncio.Future(loop=loop)
731773
else:
732774
return create_future()
775+
776+
777+
def _detect_server_capabilities(server_version, connection_settings):
778+
if hasattr(connection_settings, 'crdb_version'):
779+
# CocroachDB detected.
780+
advisory_locks = cursors = notifications = plpgsql = sql_reset = False
781+
else:
782+
# Standard PostgreSQL server assumed.
783+
advisory_locks = cursors = notifications = plpgsql = sql_reset = True
784+
785+
return types.ServerCapabilities(
786+
advisory_locks=advisory_locks,
787+
cursors=cursors,
788+
notifications=notifications,
789+
plpgsql=plpgsql,
790+
sql_reset=sql_reset
791+
)

asyncpg/protocol/settings.pyx

+3
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,6 @@ cdef class ConnectionSettings:
6060
raise AttributeError(name) from None
6161

6262
return object.__getattr__(self, name)
63+
64+
def __repr__(self):
65+
return '<ConnectionSettings {!r}>'.format(self._settings)

asyncpg/types.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
__all__ = (
1212
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
13-
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion'
13+
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
14+
'ServerCapabilities'
1415
)
1516

1617

@@ -34,6 +35,12 @@
3435
ServerVersion.__doc__ = 'PostgreSQL server version tuple.'
3536

3637

38+
ServerCapabilities = collections.namedtuple(
39+
'ServerCapabilities',
40+
['advisory_locks', 'cursors', 'notifications', 'plpgsql', 'sql_reset'])
41+
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'
42+
43+
3744
class Range:
3845
"""Immutable representation of PostgreSQL `range` type."""
3946

0 commit comments

Comments
 (0)