From 435b6746edcbd58ef487dae437d3ef4f235c3083 Mon Sep 17 00:00:00 2001 From: Travis Gockel Date: Mon, 31 May 2021 00:43:29 -0600 Subject: [PATCH] Use Python's array module instead of creating elements individually Updates for PR 111 --- cbor2/decoder.py | 43 +++++++++--- source/decoder.c | 151 ++++++++++++++++++++++++++++-------------- source/module.c | 30 +++++++++ source/module.h | 6 ++ tests/test_decoder.py | 6 +- 5 files changed, 179 insertions(+), 57 deletions(-) diff --git a/cbor2/decoder.py b/cbor2/decoder.py index 6d920cd9..53b51006 100644 --- a/cbor2/decoder.py +++ b/cbor2/decoder.py @@ -444,8 +444,8 @@ def decode_uuid(self): from uuid import UUID return self.set_shareable(UUID(bytes=self._decode())) - def _decode_typed_array_impl(self, name, tag, element_size, format): - """Helper function for decoding typed arrays described by RFC 8746""" + def _decode_typed_array_half_float_impl(self, name, tag, element_size, format): + """Helper function for decoding typed arrays of half-precision floats""" buf = self.decode() if not isinstance(buf, bytes): raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf)) @@ -461,15 +461,42 @@ def _decode_typed_array_impl(self, name, tag, element_size, format): else: return self.set_shareable(list(out)) + def _decode_typed_array_half_float_func(*args): + return lambda self: self._decode_typed_array_half_float_impl(*args) + + decode_array_float16_be = _decode_typed_array_half_float_func('float16', 80, 2, '>%ie') + decode_array_float16_le = _decode_typed_array_half_float_func('float16', 84, 2, '<%ie') + + def _decode_typed_array_impl(self, name, tag, element_size, typecode, endianness): + """Helper function for decoding typed arrays described by RFC 8746""" + import array + import sys + + buf = self.decode() + if not isinstance(buf, bytes): + raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf)) + elif len(buf) % element_size != 0: + raise CBORDecodeValueError( + "invalid length for %s typed array -- 
must be multiple of %i, but is %i" + % (name, element_size, len(buf))) + + out = array.array(typecode, buf) + if sys.byteorder != endianness: + out.byteswap() + + if self._immutable: + # TODO(tgockel/111): The returned array is not immutable + return self.set_shareable(out) + else: + return self.set_shareable(out) + def _decode_typed_array_func(*args): return lambda self: self._decode_typed_array_impl(*args) - decode_array_float16_be = _decode_typed_array_func('float16', 80, 2, '>%ie') - decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, '>%if') - decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, '>%id') - decode_array_float16_le = _decode_typed_array_func('float16', 84, 2, '<%ie') - decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, '<%if') - decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, '<%id') + decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, 'f', 'big') + decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, 'd', 'big') + decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, 'f', 'little') + decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, 'd', 'little') def decode_set(self): # Semantic tag 258 diff --git a/source/decoder.c b/source/decoder.c index 07b7725b..4b1d4a10 100644 --- a/source/decoder.c +++ b/source/decoder.c @@ -23,17 +23,11 @@ #define be16toh(x) OSSwapBigToHostInt16(x) #define be32toh(x) OSSwapBigToHostInt32(x) #define be64toh(x) OSSwapBigToHostInt64(x) -#define le16toh(x) OSSwapLittleToHostInt16(x) -#define le32toh(x) OSSwapLittleToHostInt32(x) -#define le64toh(x) OSSwapLittleToHostInt64(x) #elif _WIN32 // All windows platforms are (currently) little-endian so byteswap is required #define be16toh(x) _byteswap_ushort(x) #define be32toh(x) _byteswap_ulong(x) #define be64toh(x) _byteswap_uint64(x) -#define le16toh(x) (x) -#define le32toh(x) (x) -#define le64toh(x) (x) #endif enum DecodeOption { @@ -1343,7 
+1337,7 @@ CBORDecoder_decode_uuid(CBORDecoderObject *self) } /** - * Implementation of the decoder for all typed arrays. + * Implementation of the decoder for typed arrays of half-precision floats. * * \param[in] self The decoder object * \param[in] type_name The name of the type to use in error messages @@ -1351,10 +1345,10 @@ CBORDecoder_decode_uuid(CBORDecoderObject *self) * \param[in] create_value Create a Python object from its byte representation */ static PyObject * -CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self, - const char *type_name, - size_t element_size, - PyObject * (*create_value)(const char *byte_repr)) +CBORDecoder_decode_array_typed_half_float_impl(CBORDecoderObject *self, + const char *type_name, + size_t element_size, + PyObject * (*create_value)(const char *byte_repr)) { PyObject *bytes, *list, *ret = NULL; Py_ssize_t bytes_size, element_count, element_idx; @@ -1415,6 +1409,15 @@ create_float16_be_from_buffer(const char *src) } +// CBORDecoder.decode_array_float16_be +static PyObject * +CBORDecoder_decode_array_float16_be(CBORDecoderObject *self) +{ + // semantic type 80 + return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_be_from_buffer); +} + + static PyObject * create_float16_le_from_buffer(const char *src) { @@ -1423,34 +1426,95 @@ create_float16_le_from_buffer(const char *src) } -#define CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(fn_name_, float_type_, irepr_type_, endian_flip_) \ - static PyObject * \ - fn_name_(const char *src) \ - { \ - irepr_type_ i_repr; \ - float_type_ value; \ - \ - memcpy(&i_repr, src, sizeof i_repr); \ - i_repr = endian_flip_(i_repr); \ - memcpy(&value, &i_repr, sizeof value); \ - \ - return PyFloat_FromDouble(value); \ - } - -CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_be_from_buffer, float, uint32_t, be32toh) -CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_be_from_buffer, double, uint64_t, be64toh) 
-CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_le_from_buffer, float, uint32_t, le32toh) -CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_le_from_buffer, double, uint64_t, le64toh) - -#undef CBOR_DECODER_GEN_FLOAT_EXTRACT_FN +// CBORDecoder.decode_array_float16_le +static PyObject * +CBORDecoder_decode_array_float16_le(CBORDecoderObject *self) +{ + // semantic type 84 + return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_le_from_buffer); +} + + +static bool +platform_is_big_endian(void) +{ + #if defined __BIG_ENDIAN__ \ + || defined __ARMEB__ \ + || defined __THUMBEB__ \ + || defined __AARCH64EB__ \ + || defined _MIPSEB \ + || defined __MIPSEB \ + || defined __MIPSEB__ + // Big endian known at compile-time + return true; + #elif defined __LITTLE_ENDIAN__ \ + || defined __ARMEL__ \ + || defined __THUMBEL__ \ + || defined __AARCH64EL__ \ + || defined _MIPSEL \ + || defined __MIPSEL \ + || defined __MIPSEL__ + // Little endian known at compile-time + return false; + #else + // Fall back to checking at run-time + char c; + size_t num = 1; + memcpy(&c, &num, 1); + return c == '\0'; + #endif +} -// CBORDecoder.decode_array_float16_be static PyObject * -CBORDecoder_decode_array_float16_be(CBORDecoderObject *self) +CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self, + const char *type_name, + size_t element_size, + PyObject *array_typecode, + bool big_endian) { - // semantic type 80 - return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_be_from_buffer); + PyObject *bytes, *array, *byteswap_result, *ret = NULL; + + if (!_CBOR2_array && _CBOR2_init_array() == -1) + return NULL; + + bytes = decode(self, DECODE_UNSHARED); + if (bytes) { + if (PyBytes_CheckExact(bytes)) { + array = PyObject_CallFunctionObjArgs(_CBOR2_array, array_typecode, bytes, NULL); + if (array) { + set_shareable(self, array); + + if (platform_is_big_endian() == big_endian) { + ret = array; + } else { + byteswap_result = 
PyObject_CallMethodObjArgs(array, _CBOR2_str_byteswap, NULL); + if (byteswap_result) { + Py_DECREF(byteswap_result); + ret = array; + } else { + // byteswap failed -- error is set + Py_DECREF(array); + array = NULL; + } + } + + if (array && self->immutable) { + // TODO(tgockel/111): something here to return an immutable object + } + } else { + // array creation failed -- error is set + } + } else { + PyErr_Format( + _CBOR2_CBORDecodeValueError, + "invalid %s typed array %R", type_name, bytes); + } + + Py_DECREF(bytes); + } + + return ret; } @@ -1459,7 +1523,7 @@ static PyObject * CBORDecoder_decode_array_float32_be(CBORDecoderObject *self) { // semantic type 81 - return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_be_from_buffer); + return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, true); } @@ -1468,16 +1532,7 @@ static PyObject * CBORDecoder_decode_array_float64_be(CBORDecoderObject *self) { // semantic type 82 - return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_be_from_buffer); -} - - -// CBORDecoder.decode_array_float16_le -static PyObject * -CBORDecoder_decode_array_float16_le(CBORDecoderObject *self) -{ - // semantic type 84 - return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_le_from_buffer); + return CBORDecoder_decode_array_typed_impl(self, "float64", 8, _CBOR2_str_d, true); } @@ -1486,7 +1541,7 @@ static PyObject * CBORDecoder_decode_array_float32_le(CBORDecoderObject *self) { // semantic type 85 - return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_le_from_buffer); + return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, false); } @@ -1495,7 +1550,7 @@ static PyObject * CBORDecoder_decode_array_float64_le(CBORDecoderObject *self) { // semantic type 86 - return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_le_from_buffer); + return CBORDecoder_decode_array_typed_impl(self, 
"float64", 8, _CBOR2_str_d, false); } diff --git a/source/module.c b/source/module.c index 99ee3af2..965e5bcb 100644 --- a/source/module.c +++ b/source/module.c @@ -376,6 +376,26 @@ CBOR2_loads(PyObject *module, PyObject *args, PyObject *kwargs) // Cache-init functions ////////////////////////////////////////////////////// +int +_CBOR2_init_array(void) +{ + PyObject *array; + + array = PyImport_ImportModule("array"); + if (!array) + goto error; + _CBOR2_array = PyObject_GetAttr(array, _CBOR2_str_array); + Py_DECREF(array); + if (!_CBOR2_array) + goto error; + return 0; + +error: + PyErr_SetString(PyExc_ImportError, + "unable to import array from array"); + return -1; +} + int _CBOR2_init_BytesIO(void) { @@ -580,19 +600,23 @@ _CBOR2_init_ip_address(void) PyObject *_CBOR2_empty_bytes = NULL; PyObject *_CBOR2_empty_str = NULL; +PyObject *_CBOR2_str_array = NULL; PyObject *_CBOR2_str_as_string = NULL; PyObject *_CBOR2_str_as_tuple = NULL; PyObject *_CBOR2_str_bit_length = NULL; PyObject *_CBOR2_str_bytes = NULL; +PyObject *_CBOR2_str_byteswap = NULL; PyObject *_CBOR2_str_BytesIO = NULL; PyObject *_CBOR2_str_canonical_encoders = NULL; PyObject *_CBOR2_str_compile = NULL; PyObject *_CBOR2_str_copy = NULL; +PyObject *_CBOR2_str_d = NULL; PyObject *_CBOR2_str_datestr_re = NULL; PyObject *_CBOR2_str_Decimal = NULL; PyObject *_CBOR2_str_default_encoders = NULL; PyObject *_CBOR2_str_denominator = NULL; PyObject *_CBOR2_str_encode_date = NULL; +PyObject *_CBOR2_str_f = NULL; PyObject *_CBOR2_str_Fraction = NULL; PyObject *_CBOR2_str_fromtimestamp = NULL; PyObject *_CBOR2_str_FrozenDict = NULL; @@ -633,6 +657,7 @@ PyObject *_CBOR2_CBORDecodeEOF = NULL; PyObject *_CBOR2_timezone = NULL; PyObject *_CBOR2_timezone_utc = NULL; +PyObject *_CBOR2_array = NULL; PyObject *_CBOR2_BytesIO = NULL; PyObject *_CBOR2_Decimal = NULL; PyObject *_CBOR2_Fraction = NULL; @@ -652,6 +677,7 @@ cbor2_free(PyObject *m) { Py_CLEAR(_CBOR2_timezone_utc); Py_CLEAR(_CBOR2_timezone); + 
Py_CLEAR(_CBOR2_array); Py_CLEAR(_CBOR2_BytesIO); Py_CLEAR(_CBOR2_Decimal); Py_CLEAR(_CBOR2_Fraction); @@ -908,18 +934,22 @@ PyInit__cbor2(void) !(_CBOR2_str_##name = PyUnicode_InternFromString(#name))) \ goto error; + INTERN_STRING(array); INTERN_STRING(as_string); INTERN_STRING(as_tuple); INTERN_STRING(bit_length); INTERN_STRING(bytes); + INTERN_STRING(byteswap); INTERN_STRING(BytesIO); INTERN_STRING(canonical_encoders); INTERN_STRING(compile); INTERN_STRING(copy); + INTERN_STRING(d); INTERN_STRING(Decimal); INTERN_STRING(default_encoders); INTERN_STRING(denominator); INTERN_STRING(encode_date); + INTERN_STRING(f); INTERN_STRING(Fraction); INTERN_STRING(fromtimestamp); INTERN_STRING(FrozenDict); diff --git a/source/module.h b/source/module.h index 3335a2e2..4e3c8b24 100644 --- a/source/module.h +++ b/source/module.h @@ -32,19 +32,23 @@ extern PyTypeObject CBORSimpleValueType; // Various interned strings extern PyObject *_CBOR2_empty_bytes; extern PyObject *_CBOR2_empty_str; +extern PyObject *_CBOR2_str_array; extern PyObject *_CBOR2_str_as_string; extern PyObject *_CBOR2_str_as_tuple; extern PyObject *_CBOR2_str_bit_length; extern PyObject *_CBOR2_str_bytes; +extern PyObject *_CBOR2_str_byteswap; extern PyObject *_CBOR2_str_BytesIO; extern PyObject *_CBOR2_str_canonical_encoders; extern PyObject *_CBOR2_str_compile; extern PyObject *_CBOR2_str_copy; +extern PyObject *_CBOR2_str_d; extern PyObject *_CBOR2_str_datestr_re; extern PyObject *_CBOR2_str_Decimal; extern PyObject *_CBOR2_str_default_encoders; extern PyObject *_CBOR2_str_denominator; extern PyObject *_CBOR2_str_encode_date; +extern PyObject *_CBOR2_str_f; extern PyObject *_CBOR2_str_Fraction; extern PyObject *_CBOR2_str_fromtimestamp; extern PyObject *_CBOR2_str_FrozenDict; @@ -87,6 +91,7 @@ extern PyObject *_CBOR2_CBORDecodeEOF; // Global references (initialized by functions declared below) extern PyObject *_CBOR2_timezone; extern PyObject *_CBOR2_timezone_utc; +extern PyObject *_CBOR2_array; extern 
PyObject *_CBOR2_BytesIO; extern PyObject *_CBOR2_Decimal; extern PyObject *_CBOR2_Fraction; @@ -100,6 +105,7 @@ extern PyObject *_CBOR2_ip_network; // Initializers for the cached references above int _CBOR2_init_timezone_utc(void); // also handles timezone +int _CBOR2_init_array(void); int _CBOR2_init_BytesIO(void); int _CBOR2_init_Decimal(void); int _CBOR2_init_Fraction(void); diff --git a/tests/test_decoder.py b/tests/test_decoder.py index 5598fd35..2f8fdd98 100644 --- a/tests/test_decoder.py +++ b/tests/test_decoder.py @@ -645,4 +645,8 @@ def test_huge_truncated_string(impl): def test_typed_array_floats(impl, payload, expected): src_bytes = unhexlify(payload) decoded = impl.loads(src_bytes) - assert decoded == expected + + # TODO(tgockel/111): There is probably a cleaner way to write this check + assert len(decoded) == len(expected) + for decoded_elem, expected_elem in zip(decoded, expected): + assert decoded_elem == expected_elem