Skip to content

Add rudimentary server capability detection. #90

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 17, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 74 additions & 18 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Connection:
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_intro_query')
'_server_version', '_server_caps', '_intro_query',
'_reset_query')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout):
Expand All @@ -55,15 +56,21 @@ def __init__(self, protocol, transport, loop, addr, opts, *,

self._listeners = {}

ver_string = self._protocol.get_settings().server_version
settings = self._protocol.get_settings()
ver_string = settings.server_version
self._server_version = \
serverversion.split_server_version_string(ver_string)

self._server_caps = _detect_server_capabilities(
self._server_version, settings)

if self._server_version < (9, 2):
self._intro_query = introspection.INTRO_LOOKUP_TYPES_91
else:
self._intro_query = introspection.INTRO_LOOKUP_TYPES

self._reset_query = None

async def add_listener(self, channel, callback):
"""Add a listener for Postgres notifications.

Expand Down Expand Up @@ -107,6 +114,7 @@ def get_server_version(self):
ServerVersion(major=9, minor=6, micro=1,
releaselevel='final', serial=0)

.. versionadded:: 0.8.0
"""
return self._server_version

Expand Down Expand Up @@ -394,22 +402,10 @@ def terminate(self):
self._protocol.abort()

async def reset(self):
self._listeners = {}

await self.execute('''
DO $$
BEGIN
PERFORM * FROM pg_listening_channels() LIMIT 1;
IF FOUND THEN
UNLISTEN *;
END IF;
END;
$$;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
CLOSE ALL;
SELECT pg_advisory_unlock_all();
''')
self._listeners.clear()
reset_query = self._get_reset_query()
if reset_query:
await self.execute(reset_query)

def _get_unique_id(self):
self._uid += 1
Expand Down Expand Up @@ -492,6 +488,35 @@ def _notify(self, pid, channel, payload):
'exception': ex
})

def _get_reset_query(self):
if self._reset_query is not None:
return self._reset_query

caps = self._server_caps

_reset_query = ''
if caps.advisory_locks:
_reset_query += 'SELECT pg_advisory_unlock_all();\n'
if caps.cursors:
_reset_query += 'CLOSE ALL;\n'
if caps.notifications and caps.plpgsql:
_reset_query += '''
DO $$
BEGIN
PERFORM * FROM pg_listening_channels() LIMIT 1;
IF FOUND THEN
UNLISTEN *;
END IF;
END;
$$;
'''
if caps.sql_reset:
_reset_query += 'RESET ALL;\n'

self._reset_query = _reset_query

return _reset_query


async def connect(dsn=None, *,
host=None, port=None,
Expand Down Expand Up @@ -730,3 +755,34 @@ def _create_future(loop):
return asyncio.Future(loop=loop)
else:
return create_future()


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'cursors', 'notifications', 'plpgsql', 'sql_reset'])
ServerCapabilities.__doc__ = 'PostgreSQL server capabilities.'


def _detect_server_capabilities(server_version, connection_settings):
if hasattr(connection_settings, 'crdb_version'):
# CocroachDB detected.
advisory_locks = False
cursors = False
notifications = False
plpgsql = False
sql_reset = False
else:
# Standard PostgreSQL server assumed.
advisory_locks = True
cursors = True
notifications = True
plpgsql = True
sql_reset = True

return ServerCapabilities(
advisory_locks=advisory_locks,
cursors=cursors,
notifications=notifications,
plpgsql=plpgsql,
sql_reset=sql_reset
)
3 changes: 3 additions & 0 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,6 @@ cdef class ConnectionSettings:
raise AttributeError(name) from None

return object.__getattr__(self, name)

def __repr__(self):
return '<ConnectionSettings {!r}>'.format(self._settings)
2 changes: 1 addition & 1 deletion asyncpg/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__all__ = (
'Type', 'Attribute', 'Range', 'BitString', 'Point', 'Path', 'Polygon',
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion'
'Box', 'Line', 'LineSegment', 'Circle', 'ServerVersion',
)


Expand Down