From 87ba91c3caf275301ff82ffa25ba1f7a0e379628 Mon Sep 17 00:00:00 2001 From: igorcoding Date: Sun, 31 Dec 2023 12:25:39 +0300 Subject: [PATCH 1/2] intervals support (closes #30) --- CHANGELOG.md | 5 + README.md | 3 +- asynctnt/__init__.py | 4 +- asynctnt/iproto/buffer.pxd | 1 + asynctnt/iproto/buffer.pyx | 17 ++ asynctnt/iproto/cmsgpuck.pxd | 4 + asynctnt/iproto/ext.pxd | 32 --- asynctnt/iproto/ext/datetime.pxd | 18 ++ asynctnt/iproto/ext/datetime.pyx | 87 ++++++ asynctnt/iproto/ext/decimal.pxd | 14 + asynctnt/iproto/{ext.pyx => ext/decimal.pyx} | 91 ------ asynctnt/iproto/ext/error.pxd | 15 + asynctnt/iproto/ext/error.pyx | 106 +++++++ asynctnt/iproto/ext/interval.pxd | 29 ++ asynctnt/iproto/ext/interval.pyx | 189 +++++++++++++ asynctnt/iproto/ext/uuid.pxd | 4 + asynctnt/iproto/ext/uuid.pyx | 9 + asynctnt/iproto/protocol.pxd | 6 +- asynctnt/iproto/protocol.pyi | 27 ++ asynctnt/iproto/protocol.pyx | 6 +- asynctnt/iproto/response.pxd | 15 - asynctnt/iproto/response.pyx | 115 +------- asynctnt/iproto/tarantool.pxd | 1 + docs/mpext.md | 38 ++- setup.py | 2 +- tests/test_mp_ext.py | 274 ++++++++++++++++++- 26 files changed, 862 insertions(+), 250 deletions(-) delete mode 100644 asynctnt/iproto/ext.pxd create mode 100644 asynctnt/iproto/ext/datetime.pxd create mode 100644 asynctnt/iproto/ext/datetime.pyx create mode 100644 asynctnt/iproto/ext/decimal.pxd rename asynctnt/iproto/{ext.pyx => ext/decimal.pyx} (55%) create mode 100644 asynctnt/iproto/ext/error.pxd create mode 100644 asynctnt/iproto/ext/error.pyx create mode 100644 asynctnt/iproto/ext/interval.pxd create mode 100644 asynctnt/iproto/ext/interval.pyx create mode 100644 asynctnt/iproto/ext/uuid.pxd create mode 100644 asynctnt/iproto/ext/uuid.pyx diff --git a/CHANGELOG.md b/CHANGELOG.md index daef2f8..c463f0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## v2.3.0 +**New features:** +* Added support for [interval types](https://www.tarantool.io/en/doc/latest/reference/reference_lua/datetime/interval_object/) [#30](https://github.com/igorcoding/asynctnt/issues/30) + + ## v2.2.0 **New features:** * Implemented ability to send update/upsert requests with field names when schema is disabled (`fetch_schema=False`) and when fields are not found in the schema (good example of this case is using json path like `data.inner1.inner2.key1` as a key) diff --git a/README.md b/README.md index 97d5b78..1f4a24b 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,8 @@ Documentation is available [here](https://igorcoding.github.io/asynctnt). * Full support for [SQL](https://www.tarantool.io/en/doc/latest/tutorials/sql_tutorial/), including [prepared statements](https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_sql/prepare/). * Support for [interactive transaction](https://www.tarantool.io/en/doc/latest/book/box/atomic/txn_mode_mvcc/) via Tarantool streams. -* Support of `Decimal`, `UUID` and `datetime` types natively. +* Support of `Decimal`, `UUID`,`datetime` types natively. +* Support for [interval types](https://www.tarantool.io/en/doc/latest/reference/reference_lua/datetime/interval_object/). * Support for parsing [custom errors](https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_error/new/). * **Schema fetching** on connection establishment, so you can use spaces and indexes names rather than their ids, and **auto refetching** if schema in diff --git a/asynctnt/__init__.py b/asynctnt/__init__.py index 70a4944..4c8be61 100644 --- a/asynctnt/__init__.py +++ b/asynctnt/__init__.py @@ -2,12 +2,14 @@ from .connection import Connection, connect from .iproto.protocol import ( + Adjust, Db, Field, IProtoError, IProtoErrorStackFrame, Iterator, Metadata, + MPInterval, PushIterator, Response, Schema, @@ -16,4 +18,4 @@ TarantoolTuple, ) -__version__ = "2.2.0" +__version__ = "2.3.0" diff --git a/asynctnt/iproto/buffer.pxd b/asynctnt/iproto/buffer.pxd index 967be01..4723084 100644 --- a/asynctnt/iproto/buffer.pxd +++ b/asynctnt/iproto/buffer.pxd @@ -44,6 +44,7 @@ cdef class WriteBuffer: cdef char *mp_encode_decimal(self, char *p, object value) except NULL cdef char *mp_encode_uuid(self, char *p, object value) except NULL cdef char *mp_encode_datetime(self, char *p, object value) except NULL + cdef char *mp_encode_interval(self, char *p, MPInterval value) except NULL cdef char *mp_encode_array(self, char *p, uint32_t len) except NULL cdef char *mp_encode_map(self, char *p, uint32_t len) except NULL cdef char *mp_encode_list(self, char *p, list arr) except NULL diff --git a/asynctnt/iproto/buffer.pyx b/asynctnt/iproto/buffer.pyx index a88957b..d47f100 100644 --- a/asynctnt/iproto/buffer.pyx +++ b/asynctnt/iproto/buffer.pyx @@ -279,6 +279,20 @@ cdef class WriteBuffer: self._length += (p - begin) return p + cdef char *mp_encode_interval(self, char *p, MPInterval value) except NULL: + cdef: + char *begin + char *data_p + uint32_t length + + length = interval_len(value) + p = begin = self._ensure_allocated(p, mp_sizeof_ext(length)) + p = mp_encode_extl(p, tarantool.MP_INTERVAL, length) + p = interval_encode(p, value) + + self._length += (p - begin) + return p + cdef char *mp_encode_array(self, char *p, uint32_t len) except NULL: cdef char *begin p = begin = self._ensure_allocated(p, mp_sizeof_array(len)) @@ -406,6 +420,9 @@ cdef class WriteBuffer: elif isinstance(o, datetime): return self.mp_encode_datetime(p, o) + elif isinstance(o, MPInterval): + return self.mp_encode_interval(p, o) + elif isinstance(o, Decimal): return self.mp_encode_decimal(p, o) diff --git a/asynctnt/iproto/cmsgpuck.pxd b/asynctnt/iproto/cmsgpuck.pxd index 7aad993..f54c7a3 100644 --- a/asynctnt/iproto/cmsgpuck.pxd +++ b/asynctnt/iproto/cmsgpuck.pxd @@ -26,6 +26,9 @@ cdef extern from "../../third_party/msgpuck/msgpuck.h": cdef char *mp_store_u32(char *data, uint32_t val) cdef char *mp_store_u64(char *data, uint64_t val) + cdef ptrdiff_t mp_check_uint(const char *cur, const char *end) + cdef ptrdiff_t mp_check_int(const char *cur, const char *end) + cdef mp_type mp_typeof(const char c) cdef uint32_t mp_sizeof_array(uint32_t size) @@ -43,6 +46,7 @@ cdef extern from "../../third_party/msgpuck/msgpuck.h": cdef uint32_t mp_sizeof_int(int64_t num) cdef char *mp_encode_int(char *data, int64_t num) cdef int64_t mp_decode_int(const char **data) + cdef int mp_read_int64(const char **data, int64_t *ret) cdef uint32_t mp_sizeof_float(float num) cdef char *mp_encode_float(char *data, float num) diff --git a/asynctnt/iproto/ext.pxd b/asynctnt/iproto/ext.pxd deleted file mode 100644 index 1bdab68..0000000 --- a/asynctnt/iproto/ext.pxd +++ /dev/null @@ -1,32 +0,0 @@ -from cpython.datetime cimport datetime -from libc cimport math -from libc.stdint cimport int16_t, int32_t, int64_t, uint8_t, uint32_t - - -cdef inline uint32_t bcd_len(uint32_t digits_len): - return math.floor(digits_len / 2) + 1 - -cdef uint32_t decimal_len(int exponent, uint32_t digits_count) -cdef char *decimal_encode(char *p, - uint32_t digits_count, - uint8_t sign, - tuple digits, - int exponent) except NULL -cdef object decimal_decode(const char ** p, uint32_t length) - -cdef object uuid_decode(const char ** p, uint32_t length) - -cdef struct IProtoDateTime: - int64_t seconds - int32_t nsec - int16_t tzoffset - int16_t tzindex - -cdef void datetime_zero(IProtoDateTime *dt) -cdef uint32_t datetime_len(IProtoDateTime *dt) -cdef char *datetime_encode(char *p, IProtoDateTime *dt) except NULL -cdef int datetime_decode(const char ** p, - uint32_t length, - IProtoDateTime *dt) except -1 -cdef void datetime_from_py(datetime ob, IProtoDateTime *dt) -cdef object datetime_to_py(IProtoDateTime *dt) diff --git a/asynctnt/iproto/ext/datetime.pxd b/asynctnt/iproto/ext/datetime.pxd new file mode 100644 index 0000000..1189f4e --- /dev/null +++ b/asynctnt/iproto/ext/datetime.pxd @@ -0,0 +1,18 @@ +from cpython.datetime cimport datetime +from libc.stdint cimport int16_t, int32_t, int64_t, uint32_t + + +cdef struct IProtoDateTime: + int64_t seconds + int32_t nsec + int16_t tzoffset + int16_t tzindex + +cdef void datetime_zero(IProtoDateTime *dt) +cdef uint32_t datetime_len(IProtoDateTime *dt) +cdef char *datetime_encode(char *p, IProtoDateTime *dt) except NULL +cdef int datetime_decode(const char ** p, + uint32_t length, + IProtoDateTime *dt) except -1 +cdef void datetime_from_py(datetime ob, IProtoDateTime *dt) +cdef object datetime_to_py(IProtoDateTime *dt) diff --git a/asynctnt/iproto/ext/datetime.pyx b/asynctnt/iproto/ext/datetime.pyx new file mode 100644 index 0000000..0694490 --- /dev/null +++ b/asynctnt/iproto/ext/datetime.pyx @@ -0,0 +1,87 @@ +cimport cpython.datetime +from cpython.datetime cimport PyDateTimeAPI, datetime, datetime_tzinfo, timedelta_new +from libc.stdint cimport uint32_t +from libc.string cimport memcpy + + +cdef inline void datetime_zero(IProtoDateTime *dt): + dt.seconds = 0 + dt.nsec = 0 + dt.tzoffset = 0 + dt.tzindex = 0 + +cdef inline uint32_t datetime_len(IProtoDateTime *dt): + cdef uint32_t sz + sz = sizeof(int64_t) + if dt.nsec != 0 or dt.tzoffset != 0 or dt.tzindex != 0: + return sz + DATETIME_TAIL_SZ + return sz + +cdef char *datetime_encode(char *p, IProtoDateTime *dt) except NULL: + store_u64(p, dt.seconds) + p += sizeof(dt.seconds) + if dt.nsec != 0 or dt.tzoffset != 0 or dt.tzindex != 0: + memcpy(p, &dt.nsec, DATETIME_TAIL_SZ) + p += DATETIME_TAIL_SZ + return p + +cdef int datetime_decode( + const char ** p, + uint32_t length, + IProtoDateTime *dt +) except -1: + delta = None + tz = None + + dt.seconds = load_u64(p[0]) + p[0] += sizeof(dt.seconds) + length -= sizeof(dt.seconds) + + if length == 0: + return 0 + + if length != DATETIME_TAIL_SZ: + raise ValueError("invalid datetime size. got {} extra bytes".format( + length + )) + + dt.nsec = load_u32(p[0]) + p[0] += 4 + dt.tzoffset = load_u16(p[0]) + p[0] += 2 + dt.tzindex = load_u16(p[0]) + p[0] += 2 + +cdef void datetime_from_py(datetime ob, IProtoDateTime *dt): + cdef: + double ts + int offset + ts = ob.timestamp() + dt.seconds = ts + dt.nsec = ((ts - dt.seconds) * 1000000) * 1000 + if dt.nsec < 0: + # correction for negative dates + dt.seconds -= 1 + dt.nsec += 1000000000 + + if datetime_tzinfo(ob) is not None: + offset = ob.utcoffset().total_seconds() + dt.tzoffset = (offset / 60) + +cdef object datetime_to_py(IProtoDateTime *dt): + cdef: + double timestamp + object tz + + tz = None + + if dt.tzoffset != 0: + delta = timedelta_new(0, dt.tzoffset * 60, 0) + tz = timezone_new(delta) + + timestamp = dt.seconds + ( dt.nsec) / 1e9 + return PyDateTimeAPI.DateTime_FromTimestamp( + PyDateTimeAPI.DateTimeType, + (timestamp,) if tz is None else (timestamp, tz), + NULL, + ) diff --git a/asynctnt/iproto/ext/decimal.pxd b/asynctnt/iproto/ext/decimal.pxd new file mode 100644 index 0000000..66cccd9 --- /dev/null +++ b/asynctnt/iproto/ext/decimal.pxd @@ -0,0 +1,14 @@ +from libc cimport math +from libc.stdint cimport uint8_t, uint32_t + + +cdef inline uint32_t bcd_len(uint32_t digits_len): + return math.floor(digits_len / 2) + 1 + +cdef uint32_t decimal_len(int exponent, uint32_t digits_count) +cdef char *decimal_encode(char *p, + uint32_t digits_count, + uint8_t sign, + tuple digits, + int exponent) except NULL +cdef object decimal_decode(const char ** p, uint32_t length) diff --git a/asynctnt/iproto/ext.pyx b/asynctnt/iproto/ext/decimal.pyx similarity index 55% rename from asynctnt/iproto/ext.pyx rename to asynctnt/iproto/ext/decimal.pyx index de8e179..a1214f6 100644 --- a/asynctnt/iproto/ext.pyx +++ b/asynctnt/iproto/ext/decimal.pyx @@ -1,10 +1,6 @@ -cimport cpython.datetime -from cpython.datetime cimport PyDateTimeAPI, datetime, datetime_tzinfo, timedelta_new from libc.stdint cimport uint32_t -from libc.string cimport memcpy from decimal import Decimal -from uuid import UUID cdef uint32_t decimal_len(int exponent, uint32_t digits_count): @@ -127,90 +123,3 @@ cdef object decimal_decode(const char ** p, uint32_t length): p[0] += length return Decimal(( sign, digits, exponent)) - -cdef object uuid_decode(const char ** p, uint32_t length): - data = cpython.bytes.PyBytes_FromStringAndSize(p[0], length) - p[0] += length - return UUID(bytes=data) - -cdef inline void datetime_zero(IProtoDateTime *dt): - dt.seconds = 0 - dt.nsec = 0 - dt.tzoffset = 0 - dt.tzindex = 0 - -cdef inline uint32_t datetime_len(IProtoDateTime *dt): - cdef uint32_t sz - sz = sizeof(int64_t) - if dt.nsec != 0 or dt.tzoffset != 0 or dt.tzindex != 0: - return sz + DATETIME_TAIL_SZ - return sz - -cdef char *datetime_encode(char *p, IProtoDateTime *dt) except NULL: - store_u64(p, dt.seconds) - p += sizeof(dt.seconds) - if dt.nsec != 0 or dt.tzoffset != 0 or dt.tzindex != 0: - memcpy(p, &dt.nsec, DATETIME_TAIL_SZ) - p += DATETIME_TAIL_SZ - return p - -cdef int datetime_decode( - const char ** p, - uint32_t length, - IProtoDateTime *dt -) except -1: - delta = None - tz = None - - dt.seconds = load_u64(p[0]) - p[0] += sizeof(dt.seconds) - length -= sizeof(dt.seconds) - - if length == 0: - return 0 - - if length != DATETIME_TAIL_SZ: - raise ValueError("invalid datetime size. got {} extra bytes".format( - length - )) - - dt.nsec = load_u32(p[0]) - p[0] += 4 - dt.tzoffset = load_u16(p[0]) - p[0] += 2 - dt.tzindex = load_u16(p[0]) - p[0] += 2 - -cdef void datetime_from_py(datetime ob, IProtoDateTime *dt): - cdef: - double ts - int offset - ts = ob.timestamp() - dt.seconds = ts - dt.nsec = ((ts - dt.seconds) * 1000000) * 1000 - if dt.nsec < 0: - # correction for negative dates - dt.seconds -= 1 - dt.nsec += 1000000000 - - if datetime_tzinfo(ob) is not None: - offset = ob.utcoffset().total_seconds() - dt.tzoffset = (offset / 60) - -cdef object datetime_to_py(IProtoDateTime *dt): - cdef: - double timestamp - object tz - - tz = None - - if dt.tzoffset != 0: - delta = timedelta_new(0, dt.tzoffset * 60, 0) - tz = timezone_new(delta) - - timestamp = dt.seconds + ( dt.nsec) / 1e9 - return PyDateTimeAPI.DateTime_FromTimestamp( - PyDateTimeAPI.DateTimeType, - (timestamp,) if tz is None else (timestamp, tz), - NULL, - ) diff --git a/asynctnt/iproto/ext/error.pxd b/asynctnt/iproto/ext/error.pxd new file mode 100644 index 0000000..1bb06f8 --- /dev/null +++ b/asynctnt/iproto/ext/error.pxd @@ -0,0 +1,15 @@ +cdef class IProtoErrorStackFrame: + cdef: + readonly str error_type + readonly str file + readonly int line + readonly str message + readonly int err_no + readonly int code + readonly dict fields + +cdef class IProtoError: + cdef: + readonly list trace + +cdef IProtoError iproto_error_decode(const char ** b, bytes encoding) diff --git a/asynctnt/iproto/ext/error.pyx b/asynctnt/iproto/ext/error.pyx new file mode 100644 index 0000000..9ab3e43 --- /dev/null +++ b/asynctnt/iproto/ext/error.pyx @@ -0,0 +1,106 @@ +cimport cpython.list +cimport cython +from libc.stdint cimport uint32_t + + +@cython.final +cdef class IProtoErrorStackFrame: + def __repr__(self): + return "".format( + self.error_type, + self.code, + self.message, + ) + +@cython.final +cdef class IProtoError: + pass + +cdef inline IProtoErrorStackFrame parse_iproto_error_stack_frame(const char ** b, bytes encoding): + cdef: + uint32_t size + uint32_t key + const char * s + uint32_t s_len + IProtoErrorStackFrame frame + uint32_t unum + + size = 0 + key = 0 + + frame = IProtoErrorStackFrame.__new__(IProtoErrorStackFrame) + + size = mp_decode_map(b) + for _ in range(size): + key = mp_decode_uint(b) + + if key == tarantool.MP_ERROR_TYPE: + s = NULL + s_len = 0 + s = mp_decode_str(b, &s_len) + frame.error_type = decode_string(s[:s_len], encoding) + + elif key == tarantool.MP_ERROR_FILE: + s = NULL + s_len = 0 + s = mp_decode_str(b, &s_len) + frame.file = decode_string(s[:s_len], encoding) + + elif key == tarantool.MP_ERROR_LINE: + frame.line = mp_decode_uint(b) + + elif key == tarantool.MP_ERROR_MESSAGE: + s = NULL + s_len = 0 + s = mp_decode_str(b, &s_len) + frame.message = decode_string(s[:s_len], encoding) + + elif key == tarantool.MP_ERROR_ERRNO: + frame.err_no = mp_decode_uint(b) + + elif key == tarantool.MP_ERROR_ERRCODE: + frame.code = mp_decode_uint(b) + + elif key == tarantool.MP_ERROR_FIELDS: + if mp_typeof(b[0][0]) != MP_MAP: # pragma: nocover + raise TypeError(f'iproto_error stack frame fields must be a ' + f'map, but got {mp_typeof(b[0][0])}') + + frame.fields = _decode_obj(b, encoding) + + else: # pragma: nocover + logger.debug(f"unknown iproto_error stack element with key {key}") + mp_next(b) + + return frame + +cdef inline IProtoError iproto_error_decode(const char ** b, bytes encoding): + cdef: + uint32_t size + uint32_t arr_size + uint32_t key + uint32_t i + IProtoError error + + size = 0 + arr_size = 0 + key = 0 + + error = IProtoError.__new__(IProtoError) + + size = mp_decode_map(b) + for _ in range(size): + key = mp_decode_uint(b) + + if key == tarantool.MP_ERROR_STACK: + arr_size = mp_decode_array(b) + error.trace = cpython.list.PyList_New(arr_size) + for i in range(arr_size): + el = parse_iproto_error_stack_frame(b, encoding) + cpython.Py_INCREF(el) + cpython.list.PyList_SET_ITEM(error.trace, i, el) + else: # pragma: nocover + logger.debug(f"unknown iproto_error map field with key {key}") + mp_next(b) + + return error diff --git a/asynctnt/iproto/ext/interval.pxd b/asynctnt/iproto/ext/interval.pxd new file mode 100644 index 0000000..72384be --- /dev/null +++ b/asynctnt/iproto/ext/interval.pxd @@ -0,0 +1,29 @@ +from libc.stdint cimport uint32_t + + +cdef class MPInterval: + cdef: + public int year + public int month + public int week + public int day + public int hour + public int min + public int sec + public int nsec + public object adjust + +cdef enum mp_interval_fields: + MP_INTERVAL_FIELD_YEAR = 0 + MP_INTERVAL_FIELD_MONTH = 1 + MP_INTERVAL_FIELD_WEEK = 2 + MP_INTERVAL_FIELD_DAY = 3 + MP_INTERVAL_FIELD_HOUR = 4 + MP_INTERVAL_FIELD_MINUTE = 5 + MP_INTERVAL_FIELD_SECOND = 6 + MP_INTERVAL_FIELD_NANOSECOND = 7 + MP_INTERVAL_FIELD_ADJUST = 8 + +cdef uint32_t interval_len(MPInterval interval) +cdef char *interval_encode(char *p, MPInterval interval) except NULL +cdef MPInterval interval_decode(const char ** p, uint32_t length) except * diff --git a/asynctnt/iproto/ext/interval.pyx b/asynctnt/iproto/ext/interval.pyx new file mode 100644 index 0000000..0741114 --- /dev/null +++ b/asynctnt/iproto/ext/interval.pyx @@ -0,0 +1,189 @@ +import enum + +from libc.stdint cimport int64_t, uint8_t, uint32_t, uint64_t + + +class Adjust(enum.IntEnum): + """ + Interval adjustment mode for year and month arithmetic. + """ + EXCESS = 0 + NONE = 1 + LAST = 2 + + +cdef class MPInterval: + def __cinit__(self, + int year=0, + int month=0, + int week=0, + int day=0, + int hour=0, + int min=0, + int sec=0, + int nsec=0, + object adjust=Adjust.NONE): + self.year = year + self.month = month + self.week = week + self.day = day + self.hour = hour + self.min = min + self.sec = sec + self.nsec = nsec + self.adjust = adjust + + def __repr__(self): + return (f"asynctnt.Interval(" + f"year={self.year}, " + f"month={self.month}, " + f"week={self.week}, " + f"day={self.day}, " + f"hour={self.hour}, " + f"min={self.min}, " + f"sec={self.sec}, " + f"nsec={self.nsec}, " + f"adjust={self.adjust!r}" + f")") + + def __eq__(self, other): + cdef: + MPInterval other_interval + + if not isinstance(other, MPInterval): + return False + + other_interval = other + + return (self.year == other_interval.year + and self.month == other_interval.month + and self.week == other_interval.week + and self.day == other_interval.day + and self.hour == other_interval.hour + and self.min == other_interval.min + and self.sec == other_interval.sec + and self.nsec == other_interval.nsec + and self.adjust == other_interval.adjust + ) + +cdef uint32_t interval_value_len(int64_t value): + if value == 0: + return 0 + + if value > 0: + return 1 + mp_sizeof_uint( value) + + return 1 + mp_sizeof_int(value) + +cdef char *interval_value_pack(char *data, mp_interval_fields field, int64_t value): + if value == 0: + return data + + data = mp_encode_uint(data, field) + + if value > 0: + return mp_encode_uint(data, value) + + return mp_encode_int(data, value) + +cdef uint32_t interval_len(MPInterval interval): + return (1 + + interval_value_len(interval.year) + + interval_value_len(interval.month) + + interval_value_len(interval.week) + + interval_value_len(interval.day) + + interval_value_len(interval.hour) + + interval_value_len(interval.min) + + interval_value_len(interval.sec) + + interval_value_len(interval.nsec) + + interval_value_len( interval.adjust.value) + ) + +cdef char *interval_encode(char *data, MPInterval interval) except NULL: + cdef: + uint8_t fields_count + + fields_count = ((interval.year != 0) + + (interval.month != 0) + + (interval.week != 0) + + (interval.day != 0) + + (interval.hour != 0) + + (interval.min != 0) + + (interval.sec != 0) + + (interval.nsec != 0) + + (interval.adjust != 0) + ) + data = mp_store_u8(data, fields_count) + data = interval_value_pack(data, MP_INTERVAL_FIELD_YEAR, interval.year) + data = interval_value_pack(data, MP_INTERVAL_FIELD_MONTH, interval.month) + data = interval_value_pack(data, MP_INTERVAL_FIELD_WEEK, interval.week) + data = interval_value_pack(data, MP_INTERVAL_FIELD_DAY, interval.day) + data = interval_value_pack(data, MP_INTERVAL_FIELD_HOUR, interval.hour) + data = interval_value_pack(data, MP_INTERVAL_FIELD_MINUTE, interval.min) + data = interval_value_pack(data, MP_INTERVAL_FIELD_SECOND, interval.sec) + data = interval_value_pack(data, MP_INTERVAL_FIELD_NANOSECOND, interval.nsec) + data = interval_value_pack(data, MP_INTERVAL_FIELD_ADJUST, interval.adjust.value) + return data + +cdef MPInterval interval_decode(const char ** p, + uint32_t length) except*: + cdef: + char *end + MPInterval interval + uint8_t fields_count + int64_t value + uint8_t field_type + mp_type field_value_type + + end = p[0] + length + fields_count = mp_load_u8(p) + length -= sizeof(uint8_t) + if fields_count > 0 and length < 2: + raise ValueError("Invalid MPInterval length") + + interval = MPInterval.__new__(MPInterval) + + # NONE is default but it will be encoded, + # and because zeros are not encoded then we must set a zero value + interval.adjust = Adjust.EXCESS + + for i in range(fields_count): + field_type = mp_load_u8(p) + value = 0 + field_value_type = mp_typeof(p[0][0]) + if field_value_type == MP_UINT: + if mp_check_uint(p[0], end) > 0: + raise ValueError(f"invalid uint. field_type: {field_type}") + + elif field_value_type == MP_INT: + if mp_check_int(p[0], end) > 0: + raise ValueError(f"invalid int. field_type: {field_type}") + + else: + raise ValueError("Invalid MPInterval field value type") + + if mp_read_int64(p, &value) != 0: + raise ValueError("Invalid MPInterval value") + + if field_type == MP_INTERVAL_FIELD_YEAR: + interval.year = value + elif field_type == MP_INTERVAL_FIELD_MONTH: + interval.month = value + elif field_type == MP_INTERVAL_FIELD_WEEK: + interval.week = value + elif field_type == MP_INTERVAL_FIELD_DAY: + interval.day = value + elif field_type == MP_INTERVAL_FIELD_HOUR: + interval.hour = value + elif field_type == MP_INTERVAL_FIELD_MINUTE: + interval.min = value + elif field_type == MP_INTERVAL_FIELD_SECOND: + interval.sec = value + elif field_type == MP_INTERVAL_FIELD_NANOSECOND: + interval.nsec = value + elif field_type == MP_INTERVAL_FIELD_ADJUST: + interval.adjust = Adjust( value) + else: + raise ValueError(f"Invalid MPInterval field type {field_type}") + + return interval diff --git a/asynctnt/iproto/ext/uuid.pxd b/asynctnt/iproto/ext/uuid.pxd new file mode 100644 index 0000000..ffebf13 --- /dev/null +++ b/asynctnt/iproto/ext/uuid.pxd @@ -0,0 +1,4 @@ +from libc.stdint cimport uint32_t + + +cdef object uuid_decode(const char ** p, uint32_t length) diff --git a/asynctnt/iproto/ext/uuid.pyx b/asynctnt/iproto/ext/uuid.pyx new file mode 100644 index 0000000..e76ab86 --- /dev/null +++ b/asynctnt/iproto/ext/uuid.pyx @@ -0,0 +1,9 @@ +from libc.stdint cimport uint32_t + +from uuid import UUID + + +cdef object uuid_decode(const char ** p, uint32_t length): + data = cpython.bytes.PyBytes_FromStringAndSize(p[0], length) + p[0] += length + return UUID(bytes=data) diff --git a/asynctnt/iproto/protocol.pxd b/asynctnt/iproto/protocol.pxd index ce42e94..4fe93b1 100644 --- a/asynctnt/iproto/protocol.pxd +++ b/asynctnt/iproto/protocol.pxd @@ -10,7 +10,11 @@ include "bit.pxd" include "unicodeutil.pxd" include "schema.pxd" -include "ext.pxd" +include "ext/decimal.pxd" +include "ext/uuid.pxd" +include "ext/error.pxd" +include "ext/datetime.pxd" +include "ext/interval.pxd" include "buffer.pxd" include "rbuffer.pxd" diff --git a/asynctnt/iproto/protocol.pyi b/asynctnt/iproto/protocol.pyi index 10ee16a..462a589 100644 --- a/asynctnt/iproto/protocol.pyi +++ b/asynctnt/iproto/protocol.pyi @@ -1,6 +1,8 @@ import asyncio from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from asynctnt.iproto.protocol import Adjust + class Field: name: Optional[str] """ Field name """ @@ -176,3 +178,28 @@ class Protocol: def is_connected(self) -> bool: ... def is_fully_connected(self) -> bool: ... def get_version(self) -> tuple: ... + +class MPInterval: + year: int + month: int + week: int + day: int + hour: int + min: int + sec: int + nsec: int + adjust: Adjust + + def __init__( + self, + year: int = 0, + month: int = 0, + week: int = 0, + day: int = 0, + hour: int = 0, + min: int = 0, + sec: int = 0, + nsec: int = 0, + adjust: Adjust = Adjust.NONE, + ): ... + def __eq__(self, other) -> bool: ... diff --git a/asynctnt/iproto/protocol.pyx b/asynctnt/iproto/protocol.pyx index 8f479dd..ea6116a 100644 --- a/asynctnt/iproto/protocol.pyx +++ b/asynctnt/iproto/protocol.pyx @@ -14,7 +14,11 @@ include "const.pxi" include "unicodeutil.pyx" include "schema.pyx" -include "ext.pyx" +include "ext/decimal.pyx" +include "ext/uuid.pyx" +include "ext/error.pyx" +include "ext/datetime.pyx" +include "ext/interval.pyx" include "buffer.pyx" include "rbuffer.pyx" diff --git a/asynctnt/iproto/response.pxd b/asynctnt/iproto/response.pxd index 8bb59eb..afac6ae 100644 --- a/asynctnt/iproto/response.pxd +++ b/asynctnt/iproto/response.pxd @@ -8,20 +8,6 @@ cdef struct Header: uint64_t sync int64_t schema_id -cdef class IProtoErrorStackFrame: - cdef: - readonly str error_type - readonly str file - readonly int line - readonly str message - readonly int err_no - readonly int code - readonly dict fields - -cdef class IProtoError: - cdef: - readonly list trace - cdef class Response: cdef: int32_t code_ @@ -65,4 +51,3 @@ cdef ssize_t response_parse_header(const char *buf, uint32_t buf_len, cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, Response resp, BaseRequest req, bint is_chunk) except -1 -cdef IProtoError parse_iproto_error(const char ** b, bytes encoding) diff --git a/asynctnt/iproto/response.pyx b/asynctnt/iproto/response.pyx index 57ac06b..3a13bf3 100644 --- a/asynctnt/iproto/response.pyx +++ b/asynctnt/iproto/response.pyx @@ -2,8 +2,6 @@ import asyncio import collections from typing import Optional -cimport cpython -cimport cpython.dict cimport cpython.list cimport cython from libc cimport stdio @@ -12,19 +10,6 @@ from libc.stdint cimport uint32_t from asynctnt.log import logger -@cython.final -cdef class IProtoErrorStackFrame: - def __repr__(self): - return "".format( - self.error_type, - self.code, - self.message, - ) - -@cython.final -cdef class IProtoError: - pass - @cython.final @cython.freelist(REQUEST_FREELIST) cdef class Response: @@ -262,16 +247,24 @@ cdef object _decode_obj(const char ** p, bytes encoding): elif obj_type == MP_EXT: ext_type = 0 s_len = mp_decode_extl(p, &ext_type) + if ext_type == tarantool.MP_DECIMAL: return decimal_decode(p, s_len) + elif ext_type == tarantool.MP_UUID: return uuid_decode(p, s_len) + elif ext_type == tarantool.MP_ERROR: - return parse_iproto_error(p, encoding) + return iproto_error_decode(p, encoding) + elif ext_type == tarantool.MP_DATETIME: datetime_zero(&dt) datetime_decode(p, s_len, &dt) return datetime_to_py(&dt) + + elif ext_type == tarantool.MP_INTERVAL: + return interval_decode(p, s_len) + else: # pragma: nocover logger.warning('Unexpected ext type: %d', ext_type) p += s_len # skip unknown ext @@ -443,94 +436,6 @@ cdef Metadata response_parse_metadata(const char ** b, bytes encoding): metadata.add( field_id, field) return metadata -cdef inline IProtoErrorStackFrame parse_iproto_error_stack_frame(const char ** b, bytes encoding): - cdef: - uint32_t size - uint32_t key - const char * s - uint32_t s_len - IProtoErrorStackFrame frame - uint32_t unum - - size = 0 - key = 0 - - frame = IProtoErrorStackFrame.__new__(IProtoErrorStackFrame) - - size = mp_decode_map(b) - for _ in range(size): - key = mp_decode_uint(b) - - if key == tarantool.MP_ERROR_TYPE: - s = NULL - s_len = 0 - s = mp_decode_str(b, &s_len) - frame.error_type = decode_string(s[:s_len], encoding) - - elif key == tarantool.MP_ERROR_FILE: - s = NULL - s_len = 0 - s = mp_decode_str(b, &s_len) - frame.file = decode_string(s[:s_len], encoding) - - elif key == tarantool.MP_ERROR_LINE: - frame.line = mp_decode_uint(b) - - elif key == tarantool.MP_ERROR_MESSAGE: - s = NULL - s_len = 0 - s = mp_decode_str(b, &s_len) - frame.message = decode_string(s[:s_len], encoding) - - elif key == tarantool.MP_ERROR_ERRNO: - frame.err_no = mp_decode_uint(b) - - elif key == tarantool.MP_ERROR_ERRCODE: - frame.code = mp_decode_uint(b) - - elif key == tarantool.MP_ERROR_FIELDS: - if mp_typeof(b[0][0]) != MP_MAP: # pragma: nocover - raise TypeError(f'iproto_error stack frame fields must be a ' - f'map, but got {mp_typeof(b[0][0])}') - - frame.fields = _decode_obj(b, encoding) - - else: # pragma: nocover - logger.debug(f"unknown iproto_error stack element with key {key}") - mp_next(b) - - return frame - -cdef inline IProtoError parse_iproto_error(const char ** b, bytes encoding): - cdef: - uint32_t size - uint32_t arr_size - uint32_t key - uint32_t i - IProtoError error - - size = 0 - arr_size = 0 - key = 0 - - error = IProtoError.__new__(IProtoError) - - size = mp_decode_map(b) - for _ in range(size): - key = mp_decode_uint(b) - - if key == tarantool.MP_ERROR_STACK: - arr_size = mp_decode_array(b) - error.trace = cpython.list.PyList_New(arr_size) - for i in range(arr_size): - el = parse_iproto_error_stack_frame(b, encoding) - cpython.Py_INCREF(el) - cpython.list.PyList_SET_ITEM(error.trace, i, el) - else: # pragma: nocover - logger.debug(f"unknown iproto_error map field with key {key}") - mp_next(b) - - return error cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, Response resp, BaseRequest req, @@ -575,7 +480,7 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, if mp_typeof(b[0]) != MP_MAP: # pragma: nocover raise TypeError('IPROTO_ERROR type must be a MP_MAP') - resp.error = parse_iproto_error(&b, resp.encoding) + resp.error = iproto_error_decode(&b, resp.encoding) elif key == tarantool.IPROTO_STMT_ID: if mp_typeof(b[0]) != MP_UINT: # pragma: nocover diff --git a/asynctnt/iproto/tarantool.pxd b/asynctnt/iproto/tarantool.pxd index 3f22b2b..04e63ad 100644 --- a/asynctnt/iproto/tarantool.pxd +++ b/asynctnt/iproto/tarantool.pxd @@ -104,6 +104,7 @@ cdef enum mp_extension_type: MP_UUID = 2 MP_ERROR = 3 MP_DATETIME = 4 + MP_INTERVAL = 6 cdef enum iproto_features: IPROTO_FEATURE_STREAMS = 0 diff --git a/docs/mpext.md b/docs/mpext.md index 358fbf9..7a77bfe 100644 --- a/docs/mpext.md +++ b/docs/mpext.md @@ -1,6 +1,6 @@ # Type Extensions -Tarantool supports natively Decimal, uuid and Datetime types. `asynctnt` also supports +Tarantool supports natively Decimal, uuid, Datetime and Interval types. `asynctnt` also supports encoding/decoding of such types to Python native `Decimal`, `UUID` and `datetime` types respectively. Some examples: @@ -37,3 +37,39 @@ await conn.insert('wallets', { 'created_at': datetime.datetime.now(tz=Moscow) }) ``` + +## Interval types + +Tarantool has support for an interval type. `asynctnt` also has a support for this type which can be used as follows: + +```python +import asynctnt + +async with asynctnt.Connection() as conn: + resp = await conn.eval(""" + local datetime = require('datetime') + return datetime.interval.new({ + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + }) + """) + + assert resp[0] == asynctnt.MPInterval( + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + ) +``` + +You may use `asynctnt.MPInterval` type also as parameters to Tarantool methods (like call, insert, and others). diff --git a/setup.py b/setup.py index 8337d1c..1189531 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def initialize_options(self): self.debug = True self.gdb_debug = True else: - self.cython_always = False + self.cython_always = True self.cython_annotate = None self.cython_directives = None self.gdb_debug = False diff --git a/tests/test_mp_ext.py b/tests/test_mp_ext.py index fad9ed6..bb28552 100644 --- a/tests/test_mp_ext.py +++ b/tests/test_mp_ext.py @@ -7,6 +7,7 @@ import dateutil.parser import pytz +import asynctnt from asynctnt import IProtoError from asynctnt.exceptions import ErrorCode, TarantoolDatabaseError from tests import BaseTarantoolTestCase @@ -19,7 +20,7 @@ class DecimalTestCase: tarantool: str -class MpExtTestCase(BaseTarantoolTestCase): +class MpExtDecimalTestCase(BaseTarantoolTestCase): @ensure_version(min=(2, 2)) async def test__decimal(self): space = "tester_ext_dec" @@ -81,6 +82,8 @@ async def test__decimal(self): ) self.assertEqual(res[0], dec, "matches tarantool decimal") + +class MpExtUUIDTestCase(BaseTarantoolTestCase): @ensure_version(min=(2, 4, 1)) async def test__uuid(self): space = "tester_ext_uuid" @@ -109,6 +112,8 @@ async def test__uuid(self): res = await self.conn.replace(space, [1, val]) self.assertEqual(res[0][1], val) + +class MpExtErrorTestCase(BaseTarantoolTestCase): @ensure_version(min=(2, 4, 1)) async def test__ext_error(self): try: @@ -195,6 +200,8 @@ async def test__ext_error_custom_return_with_disabled_exterror(self): """ ) + +class MpExtDatetimeTestCase(BaseTarantoolTestCase): @ensure_version(min=(2, 10)) async def test__ext_datetime_read(self): resp = await self.conn.eval( @@ -294,6 +301,271 @@ async def test__ext_datetime_write_pytz_america(self): self.assertEqual(dt, res["dt"]) +class MpExtIntervalTestCase(BaseTarantoolTestCase): + @ensure_version(min=(2, 10)) + async def test__ext_interval_read(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({ + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + }) + """ + ) + self.assertEqual( + asynctnt.MPInterval( + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + ), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_read_adjust_last(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({ + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + adjust='last' + }) + """ + ) + self.assertEqual( + asynctnt.MPInterval( + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + adjust=asynctnt.Adjust.LAST, + ), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_read_adjust_excess(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({ + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + adjust='excess' + }) + """ + ) + self.assertEqual( + asynctnt.MPInterval( + year=1, + month=2, + week=3, + day=4, + hour=5, + min=6, + sec=7, + nsec=8, + adjust=asynctnt.Adjust.EXCESS, + ), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_read_all_negative(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({ + year=-1, + month=-2, + week=-3, + day=-4, + hour=-5, + min=-6, + sec=-7, + nsec=-8, + adjust='excess' + }) + """ + ) + self.assertEqual( + asynctnt.MPInterval( + year=-1, + month=-2, + week=-3, + day=-4, + hour=-5, + min=-6, + sec=-7, + nsec=-8, + adjust=asynctnt.Adjust.EXCESS, + ), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_read_all_mixed(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({ + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + adjust='excess' + }) + """ + ) + self.assertEqual( + asynctnt.MPInterval( + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + adjust=asynctnt.Adjust.EXCESS, + ), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_read_zeros(self): + resp = await self.conn.eval( + """ + local datetime = require('datetime') + return datetime.interval.new({}) + """ + ) + self.assertEqual( + asynctnt.MPInterval(), + resp[0], + ) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_send(self): + resp = await self.conn.eval( + """ + local args = {...} + local val = args[1] + local datetime = require('datetime') + return val == datetime.interval.new({ + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + }) + """, + [ + asynctnt.MPInterval( + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + ) + ], + ) + self.assertTrue(resp[0]) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_send_excess(self): + resp = await self.conn.eval( + """ + local args = {...} + local val = args[1] + local datetime = require('datetime') + return val == datetime.interval.new({ + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + adjust='excess' + }) + """, + [ + asynctnt.MPInterval( + year=1, + month=-2, + week=3, + day=-4, + hour=5, + min=-6, + sec=7, + nsec=-8, + adjust=asynctnt.Adjust.EXCESS, + ) + ], + ) + self.assertTrue(resp[0]) + + @ensure_version(min=(2, 10)) + async def test__ext_interval_send_with_zeros(self): + resp = await self.conn.eval( + """ + local args = {...} + local val = args[1] + local datetime = require('datetime') + return val == datetime.interval.new({ + year=100, + }) + """, + [ + asynctnt.MPInterval( + year=100, + ) + ], + ) + self.assertTrue(resp[0]) + + def datetime_fromisoformat(s): if sys.version_info < (3, 7, 0): return dateutil.parser.isoparse(s) From 8994fccf6705204a2a82c7f7aeb74e50a8539c36 Mon Sep 17 00:00:00 2001 From: igorcoding Date: Sat, 30 Dec 2023 21:37:38 +0300 Subject: [PATCH 2/2] Exporting features to Connection class --- .gitignore | 2 ++ CHANGELOG.md | 1 + asynctnt/connection.py | 12 +++++++-- asynctnt/iproto/protocol.pxd | 1 + asynctnt/iproto/protocol.pyi | 14 +++++++++++ asynctnt/iproto/protocol.pyx | 9 ++++--- asynctnt/iproto/response.pxd | 14 +++++++++++ asynctnt/iproto/response.pyx | 48 ++++++++++++++++++++++++++++++++++- tests/test_connect.py | 49 ++++++++++++++++++++++++++++++++++++ 9 files changed, 144 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 0f272e6..0033cf3 100644 --- a/.gitignore +++ b/.gitignore @@ -111,3 +111,5 @@ deploy_key* !.ci/deploy_key.enc /core cython_debug + +temp diff --git a/CHANGELOG.md b/CHANGELOG.md index c463f0a..a29ce89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ ## v2.3.0 **New features:** * Added support for [interval types](https://www.tarantool.io/en/doc/latest/reference/reference_lua/datetime/interval_object/) [#30](https://github.com/igorcoding/asynctnt/issues/30) +* Added ability to retrieve IProto features [available](https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_iproto/feature/) in Tarantool using `conn.features` property ## v2.2.0 diff --git a/asynctnt/connection.py b/asynctnt/connection.py index 12e4571..4fd8e3c 100644 --- a/asynctnt/connection.py +++ b/asynctnt/connection.py @@ -444,7 +444,7 @@ async def reconnect(self): await self.disconnect() await self.connect() - async def __aenter__(self): + async def __aenter__(self) -> "Connection": """ Executed on entering the async with section. Connects to Tarantool instance. @@ -606,7 +606,7 @@ def _normalize_api(self): Api.call = Api.call16 Connection.call = Connection.call16 - if self.version < (2, 10): # pragma: nocover + if not self.features.streams: # pragma: nocover def stream_stub(_): raise TarantoolError("streams are available only in Tarantool 2.10+") @@ -627,6 +627,14 @@ def stream(self) -> Stream: stream._set_db(db) return stream + @property + def features(self) -> protocol.IProtoFeatures: + """ + Lookup available Tarantool features - https://www.tarantool.io/en/doc/latest/reference/reference_lua/box_iproto/feature/ + :return: + """ + return self._protocol.features + async def connect(**kwargs) -> Connection: """ diff --git a/asynctnt/iproto/protocol.pxd b/asynctnt/iproto/protocol.pxd index 4fe93b1..f7e5045 100644 --- a/asynctnt/iproto/protocol.pxd +++ b/asynctnt/iproto/protocol.pxd @@ -73,6 +73,7 @@ cdef class BaseProtocol(CoreProtocol): bint _schema_fetch_in_progress object _refetch_schema_future Db _db + IProtoFeatures _features req_execute_func execute object create_future diff --git a/asynctnt/iproto/protocol.pyi b/asynctnt/iproto/protocol.pyi index 462a589..8b3eb2f 100644 --- a/asynctnt/iproto/protocol.pyi +++ b/asynctnt/iproto/protocol.pyi @@ -172,6 +172,8 @@ class Protocol: def schema_id(self) -> int: ... @property def schema(self) -> Schema: ... + @property + def features(self) -> IProtoFeatures: ... def create_db(self, gen_stream_id: bool = False) -> Db: ... def get_common_db(self) -> Db: ... def refetch_schema(self) -> asyncio.Future: ... @@ -203,3 +205,15 @@ class MPInterval: adjust: Adjust = Adjust.NONE, ): ... def __eq__(self, other) -> bool: ... + +class IProtoFeatures: + streams: bool + transactions: bool + error_extension: bool + watchers: bool + pagination: bool + space_and_index_names: bool + watch_once: bool + dml_tuple_extension: bool + call_ret_tuple_extension: bool + call_arg_tuple_extension: bool diff --git a/asynctnt/iproto/protocol.pyx b/asynctnt/iproto/protocol.pyx index ea6116a..0e8601c 100644 --- a/asynctnt/iproto/protocol.pyx +++ b/asynctnt/iproto/protocol.pyx @@ -102,6 +102,7 @@ cdef class BaseProtocol(CoreProtocol): self._schema_fetch_in_progress = False self._refetch_schema_future = None self._db = self._create_db( False) + self._features = IProtoFeatures.__new__(IProtoFeatures) self.execute = self._execute_bad try: @@ -257,9 +258,7 @@ cdef class BaseProtocol(CoreProtocol): return e = f.exception() if not e: - logger.debug('Tarantool[%s:%s] identified successfully', - self.host, self.port) - + self._features = ( f.result()).result_ self.post_con_state = POST_CONNECTION_AUTH self._post_con_state_machine() else: @@ -519,6 +518,10 @@ cdef class BaseProtocol(CoreProtocol): def refetch_schema(self): return self._refetch_schema() + @property + def features(self) -> IProtoFeatures: + return self._features + class Protocol(BaseProtocol, asyncio.Protocol): pass diff --git a/asynctnt/iproto/response.pxd b/asynctnt/iproto/response.pxd index afac6ae..ff01067 100644 --- a/asynctnt/iproto/response.pxd +++ b/asynctnt/iproto/response.pxd @@ -27,6 +27,7 @@ cdef class Response: bint _push_subscribe BaseRequest request_ object _exception + object result_ readonly object _q readonly object _push_event @@ -51,3 +52,16 @@ cdef ssize_t response_parse_header(const char *buf, uint32_t buf_len, cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, Response resp, BaseRequest req, bint is_chunk) except -1 + +cdef class IProtoFeatures: + cdef: + readonly bint streams + readonly bint transactions + readonly bint error_extension + readonly bint watchers + readonly bint pagination + readonly bint space_and_index_names + readonly bint watch_once + readonly bint dml_tuple_extension + readonly bint call_ret_tuple_extension + readonly bint call_arg_tuple_extension diff --git a/asynctnt/iproto/response.pyx b/asynctnt/iproto/response.pyx index 3a13bf3..a915716 100644 --- a/asynctnt/iproto/response.pyx +++ b/asynctnt/iproto/response.pyx @@ -10,6 +10,24 @@ from libc.stdint cimport uint32_t from asynctnt.log import logger +@cython.final +cdef class IProtoFeatures: + def __repr__(self): + return (f"" + ) + + @cython.final @cython.freelist(REQUEST_FREELIST) cdef class Response: @@ -26,6 +44,7 @@ cdef class Response: self.errmsg = None self.error = None self._rowcount = 0 + self.result_ = None self.body = None self.encoding = None self.metadata = None @@ -451,6 +470,7 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, const char *s list data Field field + IProtoFeatures features b = buf # mp_fprint(stdio.stdout, b) @@ -540,7 +560,33 @@ cdef ssize_t response_parse_body(const char *buf, uint32_t buf_len, logger.debug("IProto version: %s", _decode_obj(&b, resp.encoding)) elif key == tarantool.IPROTO_FEATURES: - logger.debug("IProto features available: %s", _decode_obj(&b, resp.encoding)) + features = IProtoFeatures.__new__(IProtoFeatures) + + for item in _decode_obj(&b, resp.encoding): + if item == 0: + features.streams = 1 + elif item == 1: + features.transactions = 1 + elif item == 2: + features.error_extension = 1 + elif item == 3: + features.watchers = 1 + elif item == 4: + features.pagination = 1 + elif item == 5: + features.space_and_index_names = 1 + elif item == 6: + features.watch_once = 1 + elif item == 7: + features.dml_tuple_extension = 1 + elif item == 8: + features.call_ret_tuple_extension = 1 + elif item == 9: + features.call_arg_tuple_extension = 1 + else: + logger.debug("unknown iproto feature available: %d", item) + + resp.result_ = features elif key == tarantool.IPROTO_AUTH_TYPE: logger.debug("IProto auth type: %s", _decode_obj(&b, resp.encoding)) diff --git a/tests/test_connect.py b/tests/test_connect.py index 6bf62b4..97c21e3 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -802,3 +802,52 @@ async def state_checker(): await conn.call("box.info") finally: await conn.disconnect() + + async def test__features(self): + async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn: + if not check_version( + self, + conn.version, + min=(2, 10), + max=(3, 0), + min_included=True, + max_included=False, + ): + return + + self.assertIsNotNone(conn.features) + self.assertTrue(conn.features.streams) + self.assertTrue(conn.features.watchers) + self.assertTrue(conn.features.error_extension) + self.assertTrue(conn.features.transactions) + self.assertTrue(conn.features.pagination) + + self.assertFalse(conn.features.space_and_index_names) + self.assertFalse(conn.features.watch_once) + self.assertFalse(conn.features.dml_tuple_extension) + self.assertFalse(conn.features.call_ret_tuple_extension) + self.assertFalse(conn.features.call_arg_tuple_extension) + + async def test__features_3_0(self): + async with asynctnt.Connection(host=self.tnt.host, port=self.tnt.port) as conn: + if not check_version( + self, + conn.version, + min=(3, 0), + min_included=True, + max_included=False, + ): + return + + self.assertIsNotNone(conn.features) + self.assertTrue(conn.features.streams) + self.assertTrue(conn.features.watchers) + self.assertTrue(conn.features.error_extension) + self.assertTrue(conn.features.transactions) + self.assertTrue(conn.features.pagination) + + self.assertTrue(conn.features.space_and_index_names) + self.assertTrue(conn.features.watch_once) + self.assertTrue(conn.features.dml_tuple_extension) + self.assertTrue(conn.features.call_ret_tuple_extension) + self.assertTrue(conn.features.call_arg_tuple_extension)