Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use reconstructor function for unpickling #207

Merged
merged 8 commits into from
Aug 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGE_LOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
2023-XX-XX 2.8.1:
-------------------
* use reconstructor function for pickling, see #207


2023-07-22 2.8.0:
Expand Down
1 change: 1 addition & 0 deletions bitarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import absolute_import

from bitarray._bitarray import (bitarray, decodetree, _sysinfo,
_bitarray_reconstructor,
get_default_endian, _set_default_endian,
__version__)

Expand Down
86 changes: 68 additions & 18 deletions bitarray/_bitarray.c
Original file line number Diff line number Diff line change
Expand Up @@ -1162,9 +1162,20 @@ When the optional `index` is given, only invert the single bit at index.");
static PyObject *
bitarray_reduce(bitarrayobject *self)
{
const Py_ssize_t nbytes = Py_SIZE(self);
PyObject *dict, *repr = NULL, *result = NULL;
char *str;
static PyObject *reconstructor = NULL;
PyObject *dict, *bytes, *result;

if (reconstructor == NULL) {
PyObject *bitarray_module;

if ((bitarray_module = PyImport_ImportModule("bitarray")) == NULL)
return NULL;
reconstructor = PyObject_GetAttrString(bitarray_module,
"_bitarray_reconstructor");
Py_DECREF(bitarray_module);
if (reconstructor == NULL)
return NULL;
}

dict = PyObject_GetAttrString((PyObject *) self, "__dict__");
if (dict == NULL) {
Expand All @@ -1173,25 +1184,22 @@ bitarray_reduce(bitarrayobject *self)
Py_INCREF(dict);
}

repr = PyBytes_FromStringAndSize(NULL, nbytes + 1);
if (repr == NULL)
goto error;

str = PyBytes_AsString(repr);
/* first byte contains the number of pad bits */
*str = (char) set_padbits(self);
/* remaining bytes contain buffer */
memcpy(str + 1, self->ob_item, (size_t) nbytes);
set_padbits(self);
bytes = PyBytes_FromStringAndSize(self->ob_item, Py_SIZE(self));
if (bytes == NULL) {
Py_DECREF(dict);
return NULL;
}

result = Py_BuildValue("O(Os)O", Py_TYPE(self),
repr, ENDIAN_STR(self->endian), dict);
error:
result = Py_BuildValue("O(OOsii)O", reconstructor, Py_TYPE(self), bytes,
ENDIAN_STR(self->endian), PADBITS(self),
self->readonly, dict);
Py_DECREF(dict);
Py_XDECREF(repr);
Py_DECREF(bytes);
return result;
}

PyDoc_STRVAR(reduce_doc, "state information for pickling");
PyDoc_STRVAR(reduce_doc, "Internal. Used for pickling support.");


static PyObject *
Expand Down Expand Up @@ -3638,7 +3646,7 @@ bitarray_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
if (PyIndex_Check(initial))
return newbitarray_from_index(type, initial, endian);

/* bytes (for pickling) - must have head byte (0x00 .. 0x07) */
/* bytes (for pickling) - to be removed, see #206 */
if (PyBytes_Check(initial) && PyBytes_GET_SIZE(initial) > 0) {
char head = *PyBytes_AS_STRING(initial);
if ((head & 0xf8) == 0)
Expand Down Expand Up @@ -4001,6 +4009,45 @@ static PyTypeObject Bitarray_Type = {

/***************************** Module functions ***************************/

static PyObject *
reconstructor(PyObject *module, PyObject *args)
{
PyTypeObject *type;
Py_ssize_t nbytes;
PyObject *res, *bytes;
char *endian_str;
int endian, padbits, readonly;

if (!PyArg_ParseTuple(args, "OOsii:_bitarray_reconstructor",
&type, &bytes, &endian_str, &padbits, &readonly))
return NULL;

if (!PyBytes_Check(bytes))
return PyErr_Format(PyExc_TypeError, "bytes expected, got '%s'",
Py_TYPE(bytes)->tp_name);

if ((endian = endian_from_string(endian_str)) < 0)
return NULL;

if (padbits < 0 || padbits >= 8)
return PyErr_Format(PyExc_ValueError,
"padbits not in range(0, 8), got %d", padbits);

nbytes = PyBytes_GET_SIZE(bytes);
res = newbitarrayobject(type, 8 * nbytes - padbits, endian);
if (res == NULL)
return NULL;
#define rr ((bitarrayobject *) res)
memcpy(rr->ob_item, PyBytes_AS_STRING(bytes), (size_t) nbytes);
if (readonly) {
set_padbits(rr);
rr->readonly = 1;
}
#undef rr
return res;
}


static PyObject *
get_default_endian(PyObject *module)
{
Expand Down Expand Up @@ -4079,6 +4126,9 @@ Return tuple containing:\n\


static PyMethodDef module_functions[] = {
{"_bitarray_reconstructor",
(PyCFunction) reconstructor, METH_VARARGS,
reduce_doc},
{"get_default_endian", (PyCFunction) get_default_endian, METH_NOARGS,
get_default_endian_doc},
{"_set_default_endian", (PyCFunction) set_default_endian, METH_VARARGS,
Expand Down
File renamed without changes.
Binary file added bitarray/test_281.pickle
Binary file not shown.
13 changes: 9 additions & 4 deletions bitarray/test_bitarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3709,10 +3709,8 @@ def test_pickle(self):
for key in d1.keys():
self.assertEQUAL(d1[key], d2[key])

@skipIf(sys.version_info[0] == 2)
def test_pickle_load(self):
# the test data file was created using bitarray 1.5.0 / Python 3.5.5
path = os.path.join(os.path.dirname(__file__), 'test_data.pickle')
def check_file(self, fn):
path = os.path.join(os.path.dirname(__file__), fn)
with open(path, 'rb') as fi:
d = pickle.load(fi)

Expand Down Expand Up @@ -3740,6 +3738,13 @@ def test_pickle_load(self):
self.assertTrue(f.readonly)
self.check_obj(f)

@skipIf(sys.version_info[0] == 2)
def test_pickle_load(self):
# test data file was created using bitarray 1.5.0 / Python 3.5.5
self.check_file('test_150.pickle')
# using bitarray 2.8.1 / Python 3.5.5 (_bitarray_reconstructor)
self.check_file('test_281.pickle')

@skipIf(pyodide) # pyodide has no dbm module
def test_shelve(self):
if hasattr(sys, 'gettotalrefcount'):
Expand Down