Skip to content

Commit

Permalink
Use Python's array module instead of creating elements individually
Browse files Browse the repository at this point in the history
Updates for PR 111
  • Loading branch information
tgockel committed May 31, 2021
1 parent cad2aaf commit 435b674
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 57 deletions.
43 changes: 35 additions & 8 deletions cbor2/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,8 @@ def decode_uuid(self):
from uuid import UUID
return self.set_shareable(UUID(bytes=self._decode()))

def _decode_typed_array_impl(self, name, tag, element_size, format):
"""Helper function for decoding typed arrays described by RFC 8746"""
def _decode_typed_array_half_float_impl(self, name, tag, element_size, format):
"""Helper function for decoding typed arrays of half-precision floats"""
buf = self.decode()
if not isinstance(buf, bytes):
raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf))
Expand All @@ -461,15 +461,42 @@ def _decode_typed_array_impl(self, name, tag, element_size, format):
else:
return self.set_shareable(list(out))

def _decode_typed_array_half_fload_func(*args):
    """
    Build a decoder method for one half-precision typed array flavour.

    The positional ``args`` are captured once and forwarded unchanged to
    ``_decode_typed_array_half_float_impl`` each time the returned method is
    called on a decoder instance.
    """
    impl_args = args
    return lambda self: self._decode_typed_array_half_float_impl(*impl_args)

decode_array_float16_be = _decode_typed_array_half_fload_func('float16', 80, 2, '>%ie')
decode_array_float16_le = _decode_typed_array_half_fload_func('float16', 84, 2, '<%ie')

def _decode_typed_array_impl(self, name, tag, element_size, typecode, endianness):
    """
    Helper function for decoding typed arrays described by RFC 8746.

    The next data item is decoded and must be a byte string; its content is
    loaded in one step into an ``array.array`` and byte-swapped when the
    encoded byte order differs from the byte order of this machine.

    :param name: element type name, only used in error messages
    :param tag: semantic tag number of the typed array (not used by the
        decoding itself)
    :param element_size: size in bytes of one encoded element
    :param typecode: ``array.array`` type code used to interpret the payload
    :param endianness: byte order of the payload, ``'big'`` or ``'little'``
        (compared against ``sys.byteorder``)
    :return: an ``array.array`` of the decoded values, or a ``tuple`` of them
        when the decoder is in immutable mode
    :raises CBORDecodeValueError: if the payload is not a byte string or its
        length is not a multiple of ``element_size``
    """
    import array
    import sys

    buf = self.decode()
    if not isinstance(buf, bytes):
        raise CBORDecodeValueError("invalid %s typed array %r" % (name, buf))
    elif len(buf) % element_size != 0:
        raise CBORDecodeValueError(
            "invalid length for %s typed array -- must be multiple of %i, but is %i"
            % (name, element_size, len(buf)))

    out = array.array(typecode, buf)
    if sys.byteorder != endianness:
        out.byteswap()

    if self._immutable:
        # array.array is mutable and unhashable, so it cannot be handed out
        # in immutable mode; a tuple keeps the values and can be hashed.
        out = tuple(out)

    return self.set_shareable(out)

def _decode_typed_array_func(*args):
    """
    Build a decoder method for one typed array flavour.

    The positional ``args`` are captured once and forwarded unchanged to
    ``_decode_typed_array_impl`` each time the returned method is called on a
    decoder instance.
    """
    impl_args = args
    return lambda self: self._decode_typed_array_impl(*impl_args)

decode_array_float16_be = _decode_typed_array_func('float16', 80, 2, '>%ie')
decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, '>%if')
decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, '>%id')
decode_array_float16_le = _decode_typed_array_func('float16', 84, 2, '<%ie')
decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, '<%if')
decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, '<%id')
decode_array_float32_be = _decode_typed_array_func('float32', 81, 4, 'f', 'big')
decode_array_float64_be = _decode_typed_array_func('float64', 82, 8, 'd', 'big')
decode_array_float32_le = _decode_typed_array_func('float32', 85, 4, 'f', 'little')
decode_array_float64_le = _decode_typed_array_func('float64', 86, 8, 'd', 'little')

def decode_set(self):
# Semantic tag 258
Expand Down
151 changes: 103 additions & 48 deletions source/decoder.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,11 @@
#define be16toh(x) OSSwapBigToHostInt16(x)
#define be32toh(x) OSSwapBigToHostInt32(x)
#define be64toh(x) OSSwapBigToHostInt64(x)
#define le16toh(x) OSSwapLittleToHostInt16(x)
#define le32toh(x) OSSwapLittleToHostInt32(x)
#define le64toh(x) OSSwapLittleToHostInt64(x)
#elif _WIN32
// All windows platforms are (currently) little-endian so byteswap is required
#define be16toh(x) _byteswap_ushort(x)
#define be32toh(x) _byteswap_ulong(x)
#define be64toh(x) _byteswap_uint64(x)
#define le16toh(x) (x)
#define le32toh(x) (x)
#define le64toh(x) (x)
#endif

enum DecodeOption {
Expand Down Expand Up @@ -1343,18 +1337,18 @@ CBORDecoder_decode_uuid(CBORDecoderObject *self)
}

/**
* Implementation of the decoder for all typed arrays.
* Implementation of the decoder for typed arrays of half-precision floats.
*
* \param[in] self The decoder object
* \param[in] type_name The name of the type to use in error messages
* \param[in] element_size The size of the individual elements (e.g.: sizeof(uint64_t))
* \param[in] create_value Create a Python object from its byte representation
*/
static PyObject *
CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject * (*create_value)(const char *byte_repr))
CBORDecoder_decode_array_typed_half_float_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject * (*create_value)(const char *byte_repr))
{
PyObject *bytes, *list, *ret = NULL;
Py_ssize_t bytes_size, element_count, element_idx;
Expand Down Expand Up @@ -1415,6 +1409,15 @@ create_float16_be_from_buffer(const char *src)
}


// CBORDecoder.decode_array_float16_be
//
// Decodes a typed array of big-endian half-precision floats. This is a thin
// wrapper: all the work is delegated to the shared half-float implementation,
// which is told that every element is 2 bytes wide and that
// create_float16_be_from_buffer builds the Python object for one element from
// its byte representation.
static PyObject *
CBORDecoder_decode_array_float16_be(CBORDecoderObject *self)
{
// semantic type 80
return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_be_from_buffer);
}


static PyObject *
create_float16_le_from_buffer(const char *src)
{
Expand All @@ -1423,34 +1426,95 @@ create_float16_le_from_buffer(const char *src)
}


#define CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(fn_name_, float_type_, irepr_type_, endian_flip_) \
static PyObject * \
fn_name_(const char *src) \
{ \
irepr_type_ i_repr; \
float_type_ value; \
\
memcpy(&i_repr, src, sizeof i_repr); \
i_repr = endian_flip_(i_repr); \
memcpy(&value, &i_repr, sizeof value); \
\
return PyFloat_FromDouble(value); \
}

CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_be_from_buffer, float, uint32_t, be32toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_be_from_buffer, double, uint64_t, be64toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float32_le_from_buffer, float, uint32_t, le32toh)
CBOR_DECODER_GEN_FLOAT_EXTRACT_FN(create_float64_le_from_buffer, double, uint64_t, le64toh)

#undef CBOR_DECODER_GEN_FLOAT_EXTRACT_FN
// CBORDecoder.decode_array_float16_le
//
// Decodes a typed array of little-endian half-precision floats. This is a thin
// wrapper: all the work is delegated to the shared half-float implementation,
// which is told that every element is 2 bytes wide and that
// create_float16_le_from_buffer builds the Python object for one element from
// its byte representation.
static PyObject *
CBORDecoder_decode_array_float16_le(CBORDecoderObject *self)
{
// semantic type 84
return CBORDecoder_decode_array_typed_half_float_impl(self, "float16", 2, create_float16_le_from_buffer);
}


/**
 * Tell whether the platform stores multi-byte values in big-endian order.
 *
 * The answer is fixed at compile-time when the compiler defines one of the
 * well-known endianness macros; otherwise the first byte of an in-memory
 * integer with the value 1 is inspected at run-time.
 *
 * \returns true on a big-endian platform, false on a little-endian one
 */
static bool
plaform_is_big_endian(void)
{
#if defined __BIG_ENDIAN__ \
    || defined __ARMEB__ \
    || defined __THUMBEB__ \
    || defined __AARCH64EB__ \
    || defined _MIPSEB \
    || defined __MIPSEB \
    || defined __MIPSEB__ \
    || (defined __BYTE_ORDER__ && defined __ORDER_BIG_ENDIAN__ \
        && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    // Big endian known at compile-time
    return true;
#elif defined __LITTLE_ENDIAN__ \
    || defined __ARMEL__ \
    || defined __THUMBEL__ \
    || defined __AARCH64EL__ \
    || defined _MIPSEL \
    || defined __MIPSEL \
    || defined __MIPSEL__ \
    || (defined __BYTE_ORDER__ && defined __ORDER_LITTLE_ENDIAN__ \
        && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)
    // Little endian known at compile-time
    return false;
#else
    // Fall back to checking at run-time: on a big-endian platform the byte at
    // the lowest address of the value 1 is its most significant byte, i.e. 0
    char c;
    size_t num = 1;
    memcpy(&c, &num, 1);
    return c == '\0';
#endif
}


// CBORDecoder.decode_array_float16_be
static PyObject *
CBORDecoder_decode_array_float16_be(CBORDecoderObject *self)
CBORDecoder_decode_array_typed_impl(CBORDecoderObject *self,
const char *type_name,
size_t element_size,
PyObject *array_typecode,
bool big_endian)
{
// semantic type 80
return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_be_from_buffer);
PyObject *bytes, *array, *byteswap_result, *ret = NULL;

if (!_CBOR2_array && _CBOR2_init_array() == -1)
return NULL;

bytes = decode(self, DECODE_UNSHARED);
if (bytes) {
if (PyBytes_CheckExact(bytes)) {
array = PyObject_CallFunctionObjArgs(_CBOR2_array, array_typecode, bytes, NULL);
if (array) {
set_shareable(self, array);

if (plaform_is_big_endian() == big_endian) {
ret = array;
} else {
byteswap_result = PyObject_CallMethodObjArgs(array, _CBOR2_str_byteswap, NULL);
if (byteswap_result) {
Py_DECREF(byteswap_result);
ret = array;
} else {
// byteswap failed -- error is set
Py_DECREF(array);
array = NULL;
}
}

if (array && self->immutable) {
// TODO(tgockel/111): something here to return an immutable object
}
} else {
// array creation failed -- error is set
}
} else {
PyErr_Format(
_CBOR2_CBORDecodeValueError,
"invalid %s typed array %R", type_name, bytes);
}

Py_DECREF(bytes);
}

return ret;
}


Expand All @@ -1459,7 +1523,7 @@ static PyObject *
CBORDecoder_decode_array_float32_be(CBORDecoderObject *self)
{
// semantic type 81
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_be_from_buffer);
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, true);
}


Expand All @@ -1468,16 +1532,7 @@ static PyObject *
CBORDecoder_decode_array_float64_be(CBORDecoderObject *self)
{
// semantic type 82
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_be_from_buffer);
}


// CBORDecoder.decode_array_float16_le
static PyObject *
CBORDecoder_decode_array_float16_le(CBORDecoderObject *self)
{
// semantic type 84
return CBORDecoder_decode_array_typed_impl(self, "float16", 2, create_float16_le_from_buffer);
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, _CBOR2_str_d, true);
}


Expand All @@ -1486,7 +1541,7 @@ static PyObject *
CBORDecoder_decode_array_float32_le(CBORDecoderObject *self)
{
// semantic type 85
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, create_float32_le_from_buffer);
return CBORDecoder_decode_array_typed_impl(self, "float32", 4, _CBOR2_str_f, false);
}


Expand All @@ -1495,7 +1550,7 @@ static PyObject *
CBORDecoder_decode_array_float64_le(CBORDecoderObject *self)
{
// semantic type 86
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, create_float64_le_from_buffer);
return CBORDecoder_decode_array_typed_impl(self, "float64", 8, _CBOR2_str_d, false);
}


Expand Down
30 changes: 30 additions & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,26 @@ CBOR2_loads(PyObject *module, PyObject *args, PyObject *kwargs)

// Cache-init functions //////////////////////////////////////////////////////

/**
 * Cache-init function for the ``array`` class of the ``array`` module.
 *
 * Imports the module, stores its ``array`` attribute in _CBOR2_array and
 * drops the module reference again.
 *
 * \returns 0 on success; -1 with ImportError set when either the import or
 *          the attribute lookup failed
 */
int
_CBOR2_init_array(void)
{
    PyObject *module = PyImport_ImportModule("array");

    if (module) {
        _CBOR2_array = PyObject_GetAttr(module, _CBOR2_str_array);
        // The module is only needed for the lookup, whatever its outcome
        Py_DECREF(module);
        if (_CBOR2_array)
            return 0;
    }

    PyErr_SetString(PyExc_ImportError,
            "unable to import array from array");
    return -1;
}

int
_CBOR2_init_BytesIO(void)
{
Expand Down Expand Up @@ -580,19 +600,23 @@ _CBOR2_init_ip_address(void)

PyObject *_CBOR2_empty_bytes = NULL;
PyObject *_CBOR2_empty_str = NULL;
PyObject *_CBOR2_str_array = NULL;
PyObject *_CBOR2_str_as_string = NULL;
PyObject *_CBOR2_str_as_tuple = NULL;
PyObject *_CBOR2_str_bit_length = NULL;
PyObject *_CBOR2_str_bytes = NULL;
PyObject *_CBOR2_str_byteswap = NULL;
PyObject *_CBOR2_str_BytesIO = NULL;
PyObject *_CBOR2_str_canonical_encoders = NULL;
PyObject *_CBOR2_str_compile = NULL;
PyObject *_CBOR2_str_copy = NULL;
PyObject *_CBOR2_str_d = NULL;
PyObject *_CBOR2_str_datestr_re = NULL;
PyObject *_CBOR2_str_Decimal = NULL;
PyObject *_CBOR2_str_default_encoders = NULL;
PyObject *_CBOR2_str_denominator = NULL;
PyObject *_CBOR2_str_encode_date = NULL;
PyObject *_CBOR2_str_f = NULL;
PyObject *_CBOR2_str_Fraction = NULL;
PyObject *_CBOR2_str_fromtimestamp = NULL;
PyObject *_CBOR2_str_FrozenDict = NULL;
Expand Down Expand Up @@ -633,6 +657,7 @@ PyObject *_CBOR2_CBORDecodeEOF = NULL;

PyObject *_CBOR2_timezone = NULL;
PyObject *_CBOR2_timezone_utc = NULL;
PyObject *_CBOR2_array = NULL;
PyObject *_CBOR2_BytesIO = NULL;
PyObject *_CBOR2_Decimal = NULL;
PyObject *_CBOR2_Fraction = NULL;
Expand All @@ -652,6 +677,7 @@ cbor2_free(PyObject *m)
{
Py_CLEAR(_CBOR2_timezone_utc);
Py_CLEAR(_CBOR2_timezone);
Py_CLEAR(_CBOR2_array);
Py_CLEAR(_CBOR2_BytesIO);
Py_CLEAR(_CBOR2_Decimal);
Py_CLEAR(_CBOR2_Fraction);
Expand Down Expand Up @@ -908,18 +934,22 @@ PyInit__cbor2(void)
!(_CBOR2_str_##name = PyUnicode_InternFromString(#name))) \
goto error;

INTERN_STRING(array);
INTERN_STRING(as_string);
INTERN_STRING(as_tuple);
INTERN_STRING(bit_length);
INTERN_STRING(bytes);
INTERN_STRING(byteswap);
INTERN_STRING(BytesIO);
INTERN_STRING(canonical_encoders);
INTERN_STRING(compile);
INTERN_STRING(copy);
INTERN_STRING(d);
INTERN_STRING(Decimal);
INTERN_STRING(default_encoders);
INTERN_STRING(denominator);
INTERN_STRING(encode_date);
INTERN_STRING(f);
INTERN_STRING(Fraction);
INTERN_STRING(fromtimestamp);
INTERN_STRING(FrozenDict);
Expand Down
Loading

0 comments on commit 435b674

Please sign in to comment.