Skip to content

api: custom packer and unpacker factories #268

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased

### Added
- Support custom packer and unpacker factories (#191).

### Changed

Expand Down
52 changes: 50 additions & 2 deletions tarantool/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@

import msgpack

from tarantool.response import Response
from tarantool.response import (
unpacker_factory as default_unpacker_factory,
Response,
)
from tarantool.request import (
packer_factory as default_packer_factory,
Request,
# RequestOK,
RequestCall,
Expand Down Expand Up @@ -357,7 +361,9 @@ def __init__(self, host, port,
ssl_key_file=DEFAULT_SSL_KEY_FILE,
ssl_cert_file=DEFAULT_SSL_CERT_FILE,
ssl_ca_file=DEFAULT_SSL_CA_FILE,
ssl_ciphers=DEFAULT_SSL_CIPHERS):
ssl_ciphers=DEFAULT_SSL_CIPHERS,
packer_factory=default_packer_factory,
unpacker_factory=default_unpacker_factory):
"""
:param host: Server hostname or IP address. Use ``None`` for
Unix sockets.
Expand Down Expand Up @@ -395,6 +401,16 @@ def __init__(self, host, port,
:param encoding: ``'utf-8'`` or ``None``. Use ``None`` to work
with non-UTF8 strings.

If non-default
:paramref:`~tarantool.Connection.packer_factory` option is
used, :paramref:`~tarantool.Connection.encoding` option
value is ignored on encode until the factory explicitly uses
its value. If non-default
:paramref:`~tarantool.Connection.unpacker_factory` option is
used, :paramref:`~tarantool.Connection.encoding` option
value is ignored on decode until the factory explicitly uses
its value.

If ``'utf-8'``, pack Unicode string (:obj:`str`) to
MessagePack string (`mp_str`_) and unpack MessagePack string
(`mp_str`_) Unicode string (:obj:`str`), pack :obj:`bytes`
Expand Down Expand Up @@ -429,6 +445,13 @@ def __init__(self, host, port,
:param use_list:
If ``True``, unpack MessagePack array (`mp_array`_) to
:obj:`list`. Otherwise, unpack to :obj:`tuple`.

If non-default
:paramref:`~tarantool.Connection.unpacker_factory` option is
used,
:paramref:`~tarantool.Connection.use_list` option value is
ignored on decode until the factory explicitly uses its
value.
:type use_list: :obj:`bool`, optional

:param call_16:
Expand Down Expand Up @@ -463,6 +486,23 @@ def __init__(self, host, port,
suites the connection can use.
:type ssl_ciphers: :obj:`str` or :obj:`None`, optional

:param packer_factory: Request MessagePack packer factory.
Supersedes :paramref:`~tarantool.Connection.encoding`. See
:func:`~tarantool.request.packer_factory` for example of
a packer factory.
:type packer_factory:
callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Packer`],
optional

:param unpacker_factory: Response MessagePack unpacker factory.
Supersedes :paramref:`~tarantool.Connection.encoding` and
:paramref:`~tarantool.Connection.use_list`. See
:func:`~tarantool.response.unpacker_factory` for example of
an unpacker factory.
:type unpacker_factory:
callable[[:obj:`~tarantool.Connection`], :obj:`~msgpack.Unpacker`],
optional

:raise: :exc:`~tarantool.error.ConfigurationError`,
:meth:`~tarantool.Connection.connect` exceptions

Expand Down Expand Up @@ -514,6 +554,8 @@ def __init__(self, host, port,
IPROTO_FEATURE_ERROR_EXTENSION: False,
IPROTO_FEATURE_WATCHERS: False,
}
self._packer_factory_impl = packer_factory
self._unpacker_factory_impl = unpacker_factory

if connect_now:
self.connect()
Expand Down Expand Up @@ -1749,3 +1791,9 @@ def _check_features(self):
features_list = [val for val in CONNECTOR_FEATURES if val in server_features]
for val in features_list:
self._features[val] = True

def _packer_factory(self):
return self._packer_factory_impl(self)

def _unpacker_factory(self):
return self._unpacker_factory_impl(self)
4 changes: 2 additions & 2 deletions tarantool/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@

from tarantool.msgpack_ext.packer import default as packer_default

def build_packer(conn):
def packer_factory(conn):
"""
Build packer to pack request.

Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self, conn):
self._body = ''
self.response_class = Response

self.packer = build_packer(conn)
self.packer = conn._packer_factory()

def _dumps(self, src):
"""
Expand Down
4 changes: 2 additions & 2 deletions tarantool/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook

def build_unpacker(conn):
def unpacker_factory(conn):
"""
Build unpacker to unpack request response.

Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(self, conn, response):
# created in the __new__().
# super(Response, self).__init__()

unpacker = build_unpacker(conn)
unpacker = conn._unpacker_factory()

unpacker.feed(response)
header = unpacker.unpack()
Expand Down
4 changes: 3 additions & 1 deletion test/suites/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
from .test_package import TestSuite_Package
from .test_error_ext import TestSuite_ErrorExt
from .test_push import TestSuite_Push
from .test_connection import TestSuite_Connection

test_cases = (TestSuite_Schema_UnicodeConnection,
TestSuite_Schema_BinaryConnection,
TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect,
TestSuite_Mesh, TestSuite_Execute, TestSuite_DBAPI,
TestSuite_Encoding, TestSuite_Pool, TestSuite_Ssl,
TestSuite_Decimal, TestSuite_UUID, TestSuite_Datetime,
TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push,)
TestSuite_Interval, TestSuite_ErrorExt, TestSuite_Push,
TestSuite_Connection,)

def load_tests(loader, tests, pattern):
suite = unittest.TestSuite()
Expand Down
161 changes: 161 additions & 0 deletions test/suites/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import sys
import unittest

import decimal
import msgpack

import tarantool
import tarantool.msgpack_ext.decimal as ext_decimal

from .lib.skip import skip_or_run_decimal_test, skip_or_run_varbinary_test
from .lib.tarantool_server import TarantoolServer

class TestSuite_Connection(unittest.TestCase):
@classmethod
def setUpClass(self):
print(' CONNECTION '.center(70, '='), file=sys.stderr)
print('-' * 70, file=sys.stderr)
self.srv = TarantoolServer()
self.srv.script = 'test/suites/box.lua'
self.srv.start()

self.adm = self.srv.admin
self.adm(r"""
box.schema.user.create('test', {password = 'test', if_not_exists = true})
box.schema.user.grant('test', 'read,write,execute', 'universe')

box.schema.create_space('space_varbin')

box.space['space_varbin']:format({
{
'id',
type = 'number',
is_nullable = false
},
{
'varbin',
type = 'varbinary',
is_nullable = false,
}
})

box.space['space_varbin']:create_index('id', {
type = 'tree',
parts = {1, 'number'},
unique = true})

box.space['space_varbin']:create_index('varbin', {
type = 'tree',
parts = {2, 'varbinary'},
unique = true})
""")

def setUp(self):
# prevent a remote tarantool from clean our session
if self.srv.is_started():
self.srv.touch_lock()

@skip_or_run_decimal_test
def test_custom_packer(self):
def my_ext_type_encoder(obj):
if isinstance(obj, decimal.Decimal):
obj = obj + 1
return msgpack.ExtType(ext_decimal.EXT_ID, ext_decimal.encode(obj, None))
raise TypeError("Unknown type: %r" % (obj,))

def my_packer_factory(_):
return msgpack.Packer(default=my_ext_type_encoder)

self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
user='test', password='test',
packer_factory=my_packer_factory)

resp = self.con.eval("return ...", (decimal.Decimal('27756'),))
self.assertSequenceEqual(resp, [decimal.Decimal('27757')])

def test_custom_packer_supersedes_encoding(self):
def my_packer_factory(_):
return msgpack.Packer(use_bin_type=False)

self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
user='test', password='test',
encoding='utf-8',
packer_factory=my_packer_factory)

# bytes -> mp_str (string) for encoding=None
# bytes -> mp_bin (varbinary) for encoding='utf-8'
resp = self.con.eval("return type(...)", (bytes(bytearray.fromhex('DEADBEAF0103')),))
self.assertSequenceEqual(resp, ['string'])

@skip_or_run_decimal_test
def test_custom_unpacker(self):
def my_ext_type_decoder(code, data):
if code == ext_decimal.EXT_ID:
return ext_decimal.decode(data, None) - 1
raise NotImplementedError("Unknown msgpack extension type code %d" % (code,))

def my_unpacker_factory(_):
if msgpack.version >= (1, 0, 0):
return msgpack.Unpacker(ext_hook=my_ext_type_decoder, strict_map_key=False)
return msgpack.Unpacker(ext_hook=my_ext_type_decoder)


self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
user='test', password='test',
unpacker_factory=my_unpacker_factory)

resp = self.con.eval("return require('decimal').new('27756')")
self.assertSequenceEqual(resp, [decimal.Decimal('27755')])

@skip_or_run_varbinary_test
def test_custom_unpacker_supersedes_encoding(self):
def my_unpacker_factory(_):
if msgpack.version >= (0, 5, 2):
if msgpack.version >= (1, 0, 0):
return msgpack.Unpacker(raw=True, strict_map_key=False)

return msgpack.Unpacker(raw=True)
return msgpack.Unpacker(encoding=None)

self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
user='test', password='test',
encoding='utf-8',
unpacker_factory=my_unpacker_factory)

data_id = 1
data_hex = 'DEADBEAF'
data = bytes(bytearray.fromhex(data_hex))
space = 'space_varbin'

self.con.execute("""
INSERT INTO "%s" VALUES (%d, x'%s');
""" % (space, data_id, data_hex))

resp = self.con.execute("""
SELECT * FROM "%s" WHERE "varbin" == x'%s';
""" % (space, data_hex))
self.assertSequenceEqual(resp, [[data_id, data]])

def test_custom_unpacker_supersedes_use_list(self):
def my_unpacker_factory(_):
if msgpack.version >= (1, 0, 0):
return msgpack.Unpacker(use_list=False, strict_map_key=False)
return msgpack.Unpacker(use_list=False)

self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'],
user='test', password='test',
use_list=True,
unpacker_factory=my_unpacker_factory)

resp = self.con.eval("return {1, 2, 3}")
self.assertIsInstance(resp[0], tuple)

@classmethod
def tearDown(self):
if hasattr(self, 'con'):
self.con.close()

@classmethod
def tearDownClass(self):
self.srv.stop()
self.srv.clean()
6 changes: 2 additions & 4 deletions test/suites/test_error_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from tarantool.msgpack_ext.packer import default as packer_default
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook
from tarantool.request import build_packer
from tarantool.response import build_unpacker

from .lib.tarantool_server import TarantoolServer
from .lib.skip import skip_or_run_error_ext_type_test
Expand Down Expand Up @@ -273,7 +271,7 @@ def test_msgpack_decode(self):
unpacker_ext_hook(
3,
case['msgpack'],
build_unpacker(conn)
conn._unpacker_factory(),
),
case['python'])

Expand Down Expand Up @@ -330,7 +328,7 @@ def test_msgpack_encode(self):
case = self.cases[name]
conn = getattr(self, case['conn'])

self.assertEqual(packer_default(case['python'], build_packer(conn)),
self.assertEqual(packer_default(case['python'], conn._packer_factory()),
msgpack.ExtType(code=3, data=case['msgpack']))

@skip_or_run_error_ext_type_test
Expand Down
7 changes: 3 additions & 4 deletions test/suites/test_interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from tarantool.msgpack_ext.packer import default as packer_default
from tarantool.msgpack_ext.unpacker import ext_hook as unpacker_ext_hook
from tarantool.response import build_unpacker

from .lib.tarantool_server import TarantoolServer
from .lib.skip import skip_or_run_datetime_test
Expand Down Expand Up @@ -154,7 +153,7 @@ def test_msgpack_decode(self):
self.assertEqual(unpacker_ext_hook(
6,
case['msgpack'],
build_unpacker(self.con),
self.con._unpacker_factory(),
),
case['python'])

Expand Down Expand Up @@ -206,13 +205,13 @@ def test_unknown_field_decode(self):
case = b'\x01\x09\xce\x00\x98\x96\x80'
self.assertRaisesRegex(
MsgpackError, 'Unknown interval field id 9',
lambda: unpacker_ext_hook(6, case, build_unpacker(self.con)))
lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory()))

def test_unknown_adjust_decode(self):
case = b'\x02\x07\xce\x00\x98\x96\x80\x08\x03'
self.assertRaisesRegex(
MsgpackError, '3 is not a valid Adjust',
lambda: unpacker_ext_hook(6, case, build_unpacker(self.con)))
lambda: unpacker_ext_hook(6, case, self.con._unpacker_factory()))


arithmetic_cases = {
Expand Down