Skip to content

Commit

Permalink
Allow aliasing builtin types by name in set_builtin_type_codec()
Browse files Browse the repository at this point in the history
Currently, `Connection.set_builtin_type_codec()` accepts either a
"contrib" codec name, such as "pg_contrib.hstore", or an OID of
a core type.  The latter is undocumented and not very useful.

Make `set_builtin_type_codec()` accept any core type name as the
"codec_name" argument.  Generally, the name of the core type can be
found in the pg_types PostgreSQL system catalog.  SQL standard names
for certain types are also accepted (such as "smallint" or
"timestamp with timezone").  This may be useful for extension types
or user-defined types which are wire-compatible with the target
builtin type.
  • Loading branch information
elprans committed Sep 18, 2018
1 parent cc053fe commit 687127e
Show file tree
Hide file tree
Showing 9 changed files with 242 additions and 53 deletions.
50 changes: 39 additions & 11 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,29 +1005,57 @@ async def reset_type_codec(self, typename, *, schema='public'):
self._drop_local_statement_cache()

async def set_builtin_type_codec(self, typename, *,
schema='public', codec_name):
"""Set a builtin codec for the specified data type.
schema='public', codec_name,
format=None):
"""Set a builtin codec for the specified scalar data type.
:param typename: Name of the data type the codec is for.
:param schema: Schema name of the data type the codec is for
(defaults to 'public')
:param codec_name: The name of the builtin codec.
This method has two uses. The first is to register a builtin
codec for an extension type without a stable OID, such as 'hstore'.
The second use is to declare that an extension type or a
user-defined type is wire-compatible with a certain builtin
data type and should be exchanged as such.
:param typename:
Name of the data type the codec is for.
:param schema:
Schema name of the data type the codec is for
(defaults to ``'public'``).
:param codec_name:
The name of the builtin codec to use for the type.
This should be either the name of a known core type
(such as ``"int"``), or the name of a supported extension
type. Currently, the only supported extension type is
``"pg_contrib.hstore"``.
:param format:
If *format* is ``None`` (the default), all formats supported
by the target codec are declared to be supported for *typename*.
If *format* is ``'text'`` or ``'binary'``, then only the
specified format is declared to be supported for *typename*.
.. versionchanged:: 0.18.0
The *codec_name* argument can be the name of any known
core data type. Added the *format* keyword argument.
"""
self._check_open()

typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))
raise exceptions.InterfaceError(
'unknown type: {}.{}'.format(schema, typename))

oid = typeinfo['oid']
if typeinfo['kind'] != b'b' or typeinfo['elemtype']:
raise ValueError(
if not introspection.is_scalar_type(typeinfo):
raise exceptions.InterfaceError(
'cannot alias non-scalar type {}.{}'.format(
schema, typename))

oid = typeinfo['oid']

self._protocol.get_settings().set_builtin_type_codec(
oid, typename, schema, 'scalar', codec_name)
oid, typename, schema, 'scalar', codec_name, format)

# Statement cache is no longer valid due to codec changes.
self._drop_local_statement_cache()
Expand Down
2 changes: 1 addition & 1 deletion asyncpg/protocol/codecs/base.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -172,4 +172,4 @@ cdef class DataCodecConfig:
dict _custom_type_codecs

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
cdef inline Codec get_local_codec(self, uint32_t oid)
cdef inline Codec get_any_local_codec(self, uint32_t oid)
93 changes: 63 additions & 30 deletions asyncpg/protocol/codecs/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ cdef class Codec:
elif type == CODEC_RANGE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
'cannot encode type "{}"."{}": text encoding of '
'cannot decode type "{}"."{}": text encoding of '
'range types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_range
self.decoder = <codec_decode_func>&self.decode_range
elif type == CODEC_COMPOSITE:
if format != PG_FORMAT_BINARY:
raise NotImplementedError(
'cannot encode type "{}"."{}": text encoding of '
'cannot decode type "{}"."{}": text encoding of '
'composite types is not supported'.format(schema, name))
self.encoder = <codec_encode_func>&self.encode_composite
self.decoder = <codec_decode_func>&self.decode_composite
Expand Down Expand Up @@ -243,8 +243,10 @@ cdef class Codec:
'{!r}, expected {!r}'
.format(
i,
TYPEMAP.get(received_elem_typ, received_elem_typ),
TYPEMAP.get(elem_typ, elem_typ)
BUILTIN_TYPE_OID_MAP.get(
received_elem_typ, received_elem_typ),
BUILTIN_TYPE_OID_MAP.get(
elem_typ, elem_typ)
),
schema=self.schema,
data_type=self.name,
Expand Down Expand Up @@ -567,27 +569,38 @@ cdef class DataCodecConfig:
encode_func c_encoder = NULL
decode_func c_decoder = NULL
uint32_t oid = pylong_as_oid(typeoid)

if xformat == PG_XFORMAT_TUPLE:
core_codec = get_any_core_codec(oid, format, xformat)
if core_codec is None:
raise exceptions.InterfaceError(
"{} type does not support 'tuple' exchange format".format(
typename))
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder
format = core_codec.format
bint codec_set = False

# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)

self._custom_type_codecs[typeoid] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
format, xformat)
if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_TEXT, PG_FORMAT_BINARY)
else:
formats = (format,)

for fmt in formats:
if xformat == PG_XFORMAT_TUPLE:
core_codec = get_core_codec(oid, fmt, xformat)
if core_codec is None:
continue
c_encoder = core_codec.c_encoder
c_decoder = core_codec.c_decoder

self._custom_type_codecs[typeoid, fmt] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
fmt, xformat)
codec_set = True

if not codec_set:
raise exceptions.InterfaceError(
"{} type does not support the 'tuple' exchange format".format(
typename))

def remove_python_codec(self, typeoid, typename, typeschema):
self._custom_type_codecs.pop(typeoid, None)
for fmt in (PG_FORMAT_BINARY, PG_FORMAT_TEXT):
self._custom_type_codecs.pop((typeoid, fmt), None)
self.clear_type_cache()

def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
Expand All @@ -596,16 +609,21 @@ cdef class DataCodecConfig:
Codec codec
Codec target_codec
uint32_t oid = pylong_as_oid(typeoid)
uint32_t alias_pid
uint32_t alias_oid = 0
bint codec_set = False

if format == PG_FORMAT_ANY:
formats = (PG_FORMAT_BINARY, PG_FORMAT_TEXT)
else:
formats = (format,)

if isinstance(alias_to, int):
alias_oid = pylong_as_oid(alias_to)
else:
alias_oid = BUILTIN_TYPE_NAME_MAP.get(alias_to, 0)

for format in formats:
if isinstance(alias_to, int):
alias_oid = pylong_as_oid(alias_to)
if alias_oid != 0:
target_codec = self.get_codec(alias_oid, format)
else:
target_codec = get_extra_codec(alias_to, format)
Expand All @@ -619,11 +637,20 @@ cdef class DataCodecConfig:
codec.schema = typeschema
codec.kind = typekind

self._custom_type_codecs[typeoid] = codec
break
else:
self._custom_type_codecs[typeoid, format] = codec
codec_set = True

if not codec_set:
if format == PG_FORMAT_BINARY:
codec_str = 'binary'
elif format == PG_FORMAT_TEXT:
codec_str = 'text'
else:
codec_str = 'text or binary'

raise exceptions.InterfaceError(
'invalid builtin codec reference: {}'.format(alias_to))
f'cannot alias {typename} to {alias_to}: '
f'there is no {codec_str} codec for {alias_to}')

def set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
alias_to, format=PG_FORMAT_ANY):
Expand Down Expand Up @@ -667,7 +694,7 @@ cdef class DataCodecConfig:
cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format):
cdef Codec codec

codec = self.get_local_codec(oid)
codec = self.get_any_local_codec(oid)
if codec is not None:
if codec.format != format:
# The codec for this OID has been overridden by
Expand All @@ -686,8 +713,14 @@ cdef class DataCodecConfig:
except KeyError:
return None

cdef inline Codec get_local_codec(self, uint32_t oid):
return self._custom_type_codecs.get(oid)
cdef inline Codec get_any_local_codec(self, uint32_t oid):
cdef Codec codec

codec = self._custom_type_codecs.get((oid, PG_FORMAT_BINARY))
if codec is None:
return self._custom_type_codecs.get((oid, PG_FORMAT_TEXT))
else:
return codec


cdef inline Codec get_core_codec(
Expand Down Expand Up @@ -746,7 +779,7 @@ cdef register_core_codec(uint32_t oid,
str name
str kind

name = TYPEMAP[oid]
name = BUILTIN_TYPE_OID_MAP[oid]
kind = 'array' if oid in ARRAY_TYPES else 'scalar'

codec = Codec(oid)
Expand Down
34 changes: 32 additions & 2 deletions asyncpg/protocol/pgtypes.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ DEF REGROLEOID = 4096

cdef ARRAY_TYPES = (_TEXTOID, _OIDOID,)

TYPEMAP = {
BUILTIN_TYPE_OID_MAP = {
ABSTIMEOID: 'abstime',
ACLITEMOID: 'aclitem',
ANYARRAYOID: 'anyarray',
Expand Down Expand Up @@ -187,4 +187,34 @@ TYPEMAP = {
XIDOID: 'xid',
XMLOID: 'xml',
_OIDOID: 'oid[]',
_TEXTOID: 'text[]'}
_TEXTOID: 'text[]'
}

BUILTIN_TYPE_NAME_MAP = {v: k for k, v in BUILTIN_TYPE_OID_MAP.items()}

BUILTIN_TYPE_NAME_MAP['smallint'] = \
BUILTIN_TYPE_NAME_MAP['int2']

BUILTIN_TYPE_NAME_MAP['int'] = \
BUILTIN_TYPE_NAME_MAP['int4']

BUILTIN_TYPE_NAME_MAP['integer'] = \
BUILTIN_TYPE_NAME_MAP['int4']

BUILTIN_TYPE_NAME_MAP['bigint'] = \
BUILTIN_TYPE_NAME_MAP['int8']

BUILTIN_TYPE_NAME_MAP['decimal'] = \
BUILTIN_TYPE_NAME_MAP['numeric']

BUILTIN_TYPE_NAME_MAP['real'] = \
BUILTIN_TYPE_NAME_MAP['float4']

BUILTIN_TYPE_NAME_MAP['double precision'] = \
BUILTIN_TYPE_NAME_MAP['float8']

BUILTIN_TYPE_NAME_MAP['timestamp with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timestamptz']

BUILTIN_TYPE_NAME_MAP['time with timezone'] = \
BUILTIN_TYPE_NAME_MAP['timetz']
2 changes: 1 addition & 1 deletion asyncpg/protocol/settings.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ cdef class ConnectionSettings:
self, typeoid, typename, typeschema)
cpdef inline clear_type_cache(self)
cpdef inline set_builtin_type_codec(
self, typeoid, typename, typeschema, typekind, alias_to)
self, typeoid, typename, typeschema, typekind, alias_to, format)
cpdef inline Codec get_data_codec(
self, uint32_t oid, ServerDataFormat format=*)
24 changes: 21 additions & 3 deletions asyncpg/protocol/settings.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


from asyncpg import exceptions


@cython.final
cdef class ConnectionSettings:

Expand Down Expand Up @@ -48,7 +51,7 @@ cdef class ConnectionSettings:
_format = PG_FORMAT_ANY
xformat = PG_XFORMAT_TUPLE
else:
raise ValueError(
raise exceptions.InterfaceError(
'invalid `format` argument, expected {}, got {!r}'.format(
"'text', 'binary' or 'tuple'", format
))
Expand All @@ -64,9 +67,24 @@ cdef class ConnectionSettings:
self._data_codecs.clear_type_cache()

cpdef inline set_builtin_type_codec(self, typeoid, typename, typeschema,
typekind, alias_to):
typekind, alias_to, format):
cdef:
ServerDataFormat _format

if format is None:
_format = PG_FORMAT_ANY
elif format == 'binary':
_format = PG_FORMAT_BINARY
elif format == 'text':
_format = PG_FORMAT_TEXT
else:
raise exceptions.InterfaceError(
'invalid `format` argument, expected {}, got {!r}'.format(
"'text' or 'binary'", format
))

self._data_codecs.set_builtin_type_codec(typeoid, typename, typeschema,
typekind, alias_to)
typekind, alias_to, _format)

cpdef inline Codec get_data_codec(self, uint32_t oid,
ServerDataFormat format=PG_FORMAT_ANY):
Expand Down
27 changes: 27 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,33 @@ shows how to instruct asyncpg to use floats instead.
asyncio.get_event_loop().run_until_complete(main())
Example: decoding hstore values
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

hstore_ is an extension data type used for storing key/value pairs.
asyncpg includes a codec to decode and encode hstore values as ``dict``
objects. Because ``hstore`` is not a builtin type, the codec must
be registered on a connection using :meth:`Connection.set_builtin_type_codec()
<asyncpg.connection.Connection.set_builtin_type_codec>`:

.. code-block:: python
import asyncpg
import asyncio
async def run():
conn = await asyncpg.connect()
# Assuming the hstore extension exists in the public schema.
await con.set_builtin_type_codec(
'hstore', codec_name='pg_contrib.hstore')
result = await con.fetchval("SELECT 'a=>1,b=>2'::hstore")
assert result == {'a': 1, 'b': 2}
asyncio.get_event_loop().run_until_complete(run())
.. _hstore: https://www.postgresql.org/docs/current/static/hstore.html


Transactions
------------

Expand Down
Loading

0 comments on commit 687127e

Please sign in to comment.