diff --git a/asyncpg/protocol/codecs/numeric.pyx b/asyncpg/protocol/codecs/numeric.pyx index f7490951..71a97917 100644 --- a/asyncpg/protocol/codecs/numeric.pyx +++ b/asyncpg/protocol/codecs/numeric.pyx @@ -5,24 +5,319 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +from libc.math cimport abs, log10 +from libc.stdio cimport snprintf + import decimal +from asyncpg.protocol cimport python + +# defined in postgresql/src/backend/utils/adt/numeric.c +DEF DEC_DIGITS = 4 +DEF MAX_DSCALE = 0x3FFF +DEF NUMERIC_POS = 0x0000 +DEF NUMERIC_NEG = 0x4000 +DEF NUMERIC_NAN = 0xC000 _Dec = decimal.Decimal -cdef numeric_encode(ConnectionSettings settings, WriteBuffer buf, obj): +cdef numeric_encode_text(ConnectionSettings settings, WriteBuffer buf, obj): text_encode(settings, buf, str(obj)) -cdef numeric_decode(ConnectionSettings settings, FastReadBuffer buf): +cdef numeric_decode_text(ConnectionSettings settings, FastReadBuffer buf): return _Dec(text_decode(settings, buf)) +cdef numeric_encode_binary(ConnectionSettings settings, WriteBuffer buf, obj): + cdef: + object dec + object dt + int64_t exponent + int64_t i + int64_t j + tuple pydigits + int64_t num_pydigits + int16_t pgdigit + int64_t num_pgdigits + int16_t dscale + int64_t dweight + int64_t weight + uint16_t sign + int64_t padding_size = 0 + + if isinstance(obj, _Dec): + dec = obj + else: + dec = _Dec(obj) + + dt = dec.as_tuple() + if dt.exponent == 'F': + raise ValueError('numeric type does not support infinite values') + + if dt.exponent == 'n' or dt.exponent == 'N': + # NaN + sign = NUMERIC_NAN + num_pgdigits = 0 + weight = 0 + dscale = 0 + else: + exponent = dt.exponent + if exponent < 0 and -exponent > MAX_DSCALE: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too small') + + if dt.sign: + sign = NUMERIC_NEG + else: + sign = NUMERIC_POS + + pydigits = dt.digits + num_pydigits = len(pydigits) + + dweight = num_pydigits + exponent - 1 + if dweight >= 0: + weight = (dweight + DEC_DIGITS) // DEC_DIGITS - 1 + else: + weight = -((-dweight - 1) // DEC_DIGITS + 1) + + if weight > 2 ** 16 - 1: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too large') + + padding_size = \ + (weight + 1) * DEC_DIGITS - (dweight + 1) + num_pgdigits = \ + (num_pydigits + padding_size + DEC_DIGITS - 1) // DEC_DIGITS + + if num_pgdigits > 2 ** 16 - 1: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'number of digits is too large') + + # Pad decimal digits to provide room for correct Postgres + # digit alignment in the digit computation loop. + pydigits = (0,) * DEC_DIGITS + pydigits + (0,) * DEC_DIGITS + + if exponent < 0: + if -exponent > MAX_DSCALE: + raise ValueError( + 'cannot encode Decimal value into numeric: ' + 'exponent is too small') + dscale = -exponent + else: + dscale = 0 + + buf.write_int32(2 + 2 + 2 + 2 + 2 * num_pgdigits) + buf.write_int16(num_pgdigits) + buf.write_int16(weight) + buf.write_int16(sign) + buf.write_int16(dscale) + + j = DEC_DIGITS - padding_size + + for i in range(num_pgdigits): + pgdigit = (pydigits[j] * 1000 + pydigits[j + 1] * 100 + + pydigits[j + 2] * 10 + pydigits[j + 3]) + j += DEC_DIGITS + buf.write_int16(pgdigit) + + +# The decoding strategy here is to form a string representation of +# the numeric var, as it is faster than passing an iterable of digits. +# For this reason the below code is pure overhead and is ~25% slower +# than the simple text decoder above. That said, we need the binary +# decoder to support binary COPY with numeric values. +cdef numeric_decode_binary(ConnectionSettings settings, FastReadBuffer buf): + cdef: + uint16_t num_pgdigits = hton.unpack_int16(buf.read(2)) + int16_t weight = hton.unpack_int16(buf.read(2)) + uint16_t sign = hton.unpack_int16(buf.read(2)) + uint16_t dscale = hton.unpack_int16(buf.read(2)) + int16_t pgdigit0 + ssize_t i + int16_t pgdigit + object pydigits + ssize_t num_pydigits + ssize_t buf_size + int64_t exponent + int64_t abs_exponent + ssize_t exponent_chars + ssize_t front_padding = 0 + ssize_t trailing_padding = 0 + ssize_t num_fract_digits + ssize_t dscale_left + char smallbuf[_NUMERIC_DECODER_SMALLBUF_SIZE] + char *charbuf + char *bufptr + bint buf_allocated = False + + if sign == NUMERIC_NAN: + # Not-a-number + return _Dec('NaN') + + if num_pgdigits == 0: + # Zero + return _Dec() + + pgdigit0 = hton.unpack_int16(buf.read(2)) + if weight >= 0: + if pgdigit0 < 10: + front_padding = 3 + elif pgdigit0 < 100: + front_padding = 2 + elif pgdigit0 < 1000: + front_padding = 1 + + # Maximum possible number of decimal digits in base 10. + num_pydigits = num_pgdigits * DEC_DIGITS + dscale + # Exponent. + exponent = (weight + 1) * DEC_DIGITS - front_padding + abs_exponent = abs(exponent) + # Number of characters required to render absolute exponent value. + exponent_chars = log10(abs_exponent) + 1 + + buf_size = ( + 1 + # sign + 1 + # leading zero + 1 + # decimal dot + num_pydigits + # digits + 2 + # exponent indicator (E-,E+) + exponent_chars + # exponent + 1 # null terminator char + ) + + if buf_size > _NUMERIC_DECODER_SMALLBUF_SIZE: + charbuf = PyMem_Malloc(buf_size) + buf_allocated = True + else: + charbuf = smallbuf + + try: + bufptr = charbuf + + if sign == NUMERIC_NEG: + bufptr[0] = b'-' + bufptr += 1 + + bufptr[0] = b'0' + bufptr[1] = b'.' + bufptr += 2 + + if weight >= 0: + bufptr = _unpack_digit_stripping_lzeros(bufptr, pgdigit0) + else: + bufptr = _unpack_digit(bufptr, pgdigit0) + + for i in range(1, num_pgdigits): + pgdigit = hton.unpack_int16(buf.read(2)) + bufptr = _unpack_digit(bufptr, pgdigit) + + if dscale: + if weight >= 0: + num_fract_digits = num_pgdigits - weight - 1 + else: + num_fract_digits = num_pgdigits + + # Check how much dscale is left to render (trailing zeros). + dscale_left = dscale - num_fract_digits * DEC_DIGITS + if dscale_left > 0: + for i in range(dscale_left): + bufptr[i] = b'0' + + # If display scale is _less_ than the number of rendered digits, + # dscale_left will be negative and this will strip the excess + # trailing zeros. + bufptr += dscale_left + + if exponent != 0: + bufptr[0] = b'E' + if exponent < 0: + bufptr[1] = b'-' + else: + bufptr[1] = b'+' + bufptr += 2 + snprintf(bufptr, exponent_chars + 1, '%d', + abs_exponent) + bufptr += exponent_chars + + bufptr[0] = 0 + + pydigits = python.PyUnicode_FromString(charbuf) + + return _Dec(pydigits) + + finally: + if buf_allocated: + PyMem_Free(charbuf) + + +cdef inline char *_unpack_digit_stripping_lzeros(char *buf, int64_t pgdigit): + cdef: + int64_t d + bint significant + + d = pgdigit // 1000 + significant = (d > 0) + if significant: + pgdigit -= d * 1000 + buf[0] = (d + b'0') + buf += 1 + + d = pgdigit // 100 + significant |= (d > 0) + if significant: + pgdigit -= d * 100 + buf[0] = (d + b'0') + buf += 1 + + d = pgdigit // 10 + significant |= (d > 0) + if significant: + pgdigit -= d * 10 + buf[0] = (d + b'0') + buf += 1 + + buf[0] = (pgdigit + b'0') + buf += 1 + + return buf + + +cdef inline char *_unpack_digit(char *buf, int64_t pgdigit): + cdef: + int64_t d + + d = pgdigit // 1000 + pgdigit -= d * 1000 + buf[0] = (d + b'0') + + d = pgdigit // 100 + pgdigit -= d * 100 + buf[1] = (d + b'0') + + d = pgdigit // 10 + pgdigit -= d * 10 + buf[2] = (d + b'0') + + buf[3] = (pgdigit + b'0') + buf += 4 + + return buf + + cdef init_numeric_codecs(): register_core_codec(NUMERICOID, - &numeric_encode, - &numeric_decode, + &numeric_encode_text, + &numeric_decode_text, PG_FORMAT_TEXT) + register_core_codec(NUMERICOID, + &numeric_encode_binary, + &numeric_decode_binary, + PG_FORMAT_BINARY) + init_numeric_codecs() diff --git a/asyncpg/protocol/consts.pxi b/asyncpg/protocol/consts.pxi index f880940a..66958227 100644 --- a/asyncpg/protocol/consts.pxi +++ b/asyncpg/protocol/consts.pxi @@ -13,3 +13,4 @@ DEF _MEMORY_FREELIST_SIZE = 1024 DEF _MAXINT32 = 2**31 - 1 DEF _COPY_BUFFER_SIZE = 524288 DEF _COPY_SIGNATURE = b"PGCOPY\n\377\r\n\0" +DEF _NUMERIC_DECODER_SMALLBUF_SIZE = 256 diff --git a/asyncpg/protocol/python.pxd b/asyncpg/protocol/python.pxd index adce0073..3ae2d14b 100644 --- a/asyncpg/protocol/python.pxd +++ b/asyncpg/protocol/python.pxd @@ -23,6 +23,7 @@ cdef extern from "Python.h": object PyMemoryView_FromMemory(char *mem, ssize_t size, int flags) object PyMemoryView_GetContiguous(object, int buffertype, char order) + object PyUnicode_FromString(const char *u) char* PyUnicode_AsUTF8AndSize(object unicode, ssize_t *size) except NULL char* PyByteArray_AsString(object) Py_UCS4* PyUnicode_AsUCS4Copy(object) diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 72f49b38..fc7001ef 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -452,6 +452,43 @@ async def test_interval(self): res = await self.con.fetchval("SELECT '-5 years -1 month'::interval") self.assertEqual(res, datetime.timedelta(days=-1855)) + async def test_numeric(self): + # Test that we handle dscale correctly. + cases = [ + '0.001', + '0.001000', + '1', + '1.00000' + ] + + for case in cases: + res = await self.con.fetchval( + "SELECT $1::numeric", case) + + self.assertEqual(str(res), case) + + res = await self.con.fetchval( + "SELECT $1::numeric", decimal.Decimal('NaN')) + self.assertTrue(res.is_nan()) + + res = await self.con.fetchval( + "SELECT $1::numeric", decimal.Decimal('sNaN')) + self.assertTrue(res.is_nan()) + + with self.assertRaisesRegex(ValueError, 'numeric type does not ' + 'support infinite values'): + await self.con.fetchval( + "SELECT $1::numeric", decimal.Decimal('-Inf')) + + with self.assertRaisesRegex(ValueError, 'numeric type does not ' + 'support infinite values'): + await self.con.fetchval( + "SELECT $1::numeric", decimal.Decimal('+Inf')) + + with self.assertRaises(decimal.InvalidOperation): + await self.con.fetchval( + "SELECT $1::numeric", 'invalid') + async def test_unhandled_type_fallback(self): await self.con.execute(''' CREATE EXTENSION IF NOT EXISTS isn diff --git a/tests/test_copy.py b/tests/test_copy.py index 45270b35..fb866014 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -600,11 +600,22 @@ async def test_copy_records_to_table_1(self): async def test_copy_records_to_table_no_binary_codec(self): await self.con.execute(''' - CREATE TABLE copytab(a numeric); + CREATE TABLE copytab(a uuid); ''') try: - records = [(123.123,)] + def _encoder(value): + return value + + def _decoder(value): + return value + + await self.con.set_type_codec( + 'uuid', encoder=_encoder, decoder=_decoder, + schema='pg_catalog', format='text' + ) + + records = [('2975ab9a-f79c-4ab4-9be5-7bc134d952f0',)] with self.assertRaisesRegex( RuntimeError, 'no binary format encoder'): @@ -612,4 +623,7 @@ async def test_copy_records_to_table_no_binary_codec(self): 'copytab', records=records) finally: + await self.con.reset_type_codec( + 'uuid', schema='pg_catalog' + ) await self.con.execute('DROP TABLE copytab')