Skip to content

Commit 289f1f8

Browse files
pierreglaserpitrou
authored andcommitted
bpo-35900: Enable custom reduction callback registration in _pickle (GH-12499)
Enable custom reduction callback registration for functions and classes in _pickle.c, using the new Pickler's attribute ``reducer_override``.
1 parent 9a4135e commit 289f1f8

File tree

6 files changed

+227
-24
lines changed

6 files changed

+227
-24
lines changed

Doc/library/pickle.rst

+71
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,18 @@ The :mod:`pickle` module exports two classes, :class:`Pickler` and
356356

357357
.. versionadded:: 3.3
358358

359+
.. method:: reducer_override(self, obj)
360+
361+
Special reducer that can be defined in :class:`Pickler` subclasses. This
362+
method has priority over any reducer in the :attr:`dispatch_table`. It
363+
should conform to the same interface as a :meth:`__reduce__` method, and
364+
can optionally return ``NotImplemented`` to fallback on
365+
:attr:`dispatch_table`-registered reducers to pickle ``obj``.
366+
367+
For a detailed example, see :ref:`reducer_override`.
368+
369+
.. versionadded:: 3.8
370+
359371
.. attribute:: fast
360372

361373
Deprecated. Enable fast mode if set to a true value. The fast mode
@@ -791,6 +803,65 @@ A sample usage might be something like this::
791803
>>> new_reader.readline()
792804
'3: Goodbye!'
793805

806+
.. _reducer_override:
807+
808+
Custom Reduction for Types, Functions, and Other Objects
809+
--------------------------------------------------------
810+
811+
.. versionadded:: 3.8
812+
813+
Sometimes, :attr:`~Pickler.dispatch_table` may not be flexible enough.
814+
In particular we may want to customize pickling based on another criterion
815+
than the object's type, or we may want to customize the pickling of
816+
functions and classes.
817+
818+
For those cases, it is possible to subclass from the :class:`Pickler` class and
819+
implement a :meth:`~Pickler.reducer_override` method. This method can return an
820+
arbitrary reduction tuple (see :meth:`__reduce__`). It can alternatively return
821+
``NotImplemented`` to fallback to the traditional behavior.
822+
823+
If both the :attr:`~Pickler.dispatch_table` and
824+
:meth:`~Pickler.reducer_override` are defined, then
825+
:meth:`~Pickler.reducer_override` method takes priority.
826+
827+
.. Note::
828+
For performance reasons, :meth:`~Pickler.reducer_override` may not be
829+
called for the following objects: ``None``, ``True``, ``False``, and
830+
exact instances of :class:`int`, :class:`float`, :class:`bytes`,
831+
:class:`str`, :class:`dict`, :class:`set`, :class:`frozenset`, :class:`list`
832+
and :class:`tuple`.
833+
834+
Here is a simple example where we allow pickling and reconstructing
835+
a given class::
836+
837+
import io
838+
import pickle
839+
840+
class MyClass:
841+
my_attribute = 1
842+
843+
class MyPickler(pickle.Pickler):
844+
def reducer_override(self, obj):
845+
"""Custom reducer for MyClass."""
846+
if getattr(obj, "__name__", None) == "MyClass":
847+
return type, (obj.__name__, obj.__bases__,
848+
{'my_attribute': obj.my_attribute})
849+
else:
850+
# For any other object, fallback to usual reduction
851+
return NotImplemented
852+
853+
f = io.BytesIO()
854+
p = MyPickler(f)
855+
p.dump(MyClass)
856+
857+
del MyClass
858+
859+
unpickled_class = pickle.loads(f.getvalue())
860+
861+
assert isinstance(unpickled_class, type)
862+
assert unpickled_class.__name__ == "MyClass"
863+
assert unpickled_class.my_attribute == 1
864+
794865

795866
.. _pickle-restrict:
796867

Lib/pickle.py

+28-20
Original file line numberDiff line numberDiff line change
@@ -497,34 +497,42 @@ def save(self, obj, save_persistent_id=True):
497497
self.write(self.get(x[0]))
498498
return
499499

500-
# Check the type dispatch table
501-
t = type(obj)
502-
f = self.dispatch.get(t)
503-
if f is not None:
504-
f(self, obj) # Call unbound method with explicit self
505-
return
506-
507-
# Check private dispatch table if any, or else copyreg.dispatch_table
508-
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
500+
rv = NotImplemented
501+
reduce = getattr(self, "reducer_override", None)
509502
if reduce is not None:
510503
rv = reduce(obj)
511-
else:
512-
# Check for a class with a custom metaclass; treat as regular class
513-
if issubclass(t, type):
514-
self.save_global(obj)
504+
505+
if rv is NotImplemented:
506+
# Check the type dispatch table
507+
t = type(obj)
508+
f = self.dispatch.get(t)
509+
if f is not None:
510+
f(self, obj) # Call unbound method with explicit self
515511
return
516512

517-
# Check for a __reduce_ex__ method, fall back to __reduce__
518-
reduce = getattr(obj, "__reduce_ex__", None)
513+
# Check private dispatch table if any, or else
514+
# copyreg.dispatch_table
515+
reduce = getattr(self, 'dispatch_table', dispatch_table).get(t)
519516
if reduce is not None:
520-
rv = reduce(self.proto)
517+
rv = reduce(obj)
521518
else:
522-
reduce = getattr(obj, "__reduce__", None)
519+
# Check for a class with a custom metaclass; treat as regular
520+
# class
521+
if issubclass(t, type):
522+
self.save_global(obj)
523+
return
524+
525+
# Check for a __reduce_ex__ method, fall back to __reduce__
526+
reduce = getattr(obj, "__reduce_ex__", None)
523527
if reduce is not None:
524-
rv = reduce()
528+
rv = reduce(self.proto)
525529
else:
526-
raise PicklingError("Can't pickle %r object: %r" %
527-
(t.__name__, obj))
530+
reduce = getattr(obj, "__reduce__", None)
531+
if reduce is not None:
532+
rv = reduce()
533+
else:
534+
raise PicklingError("Can't pickle %r object: %r" %
535+
(t.__name__, obj))
528536

529537
# Check for string returned by reduce(), meaning "save as global"
530538
if isinstance(rv, str):

Lib/test/pickletester.py

+68
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import io
55
import functools
66
import os
7+
import math
78
import pickle
89
import pickletools
910
import shutil
@@ -3013,6 +3014,73 @@ def setstate_bbb(obj, state):
30133014
obj.a = "custom state_setter"
30143015

30153016

3017+
3018+
class AbstractCustomPicklerClass:
3019+
"""Pickler implementing a reducing hook using reducer_override."""
3020+
def reducer_override(self, obj):
3021+
obj_name = getattr(obj, "__name__", None)
3022+
3023+
if obj_name == 'f':
3024+
# asking the pickler to save f as 5
3025+
return int, (5, )
3026+
3027+
if obj_name == 'MyClass':
3028+
return str, ('some str',)
3029+
3030+
elif obj_name == 'g':
3031+
# in this case, the callback returns an invalid result (not a 2-5
3032+
# tuple or a string), the pickler should raise a proper error.
3033+
return False
3034+
3035+
elif obj_name == 'h':
3036+
# Simulate a case when the reducer fails. The error should
3037+
# be propagated to the original ``dump`` call.
3038+
raise ValueError('The reducer just failed')
3039+
3040+
return NotImplemented
3041+
3042+
class AbstractHookTests(unittest.TestCase):
3043+
def test_pickler_hook(self):
3044+
# test the ability of a custom, user-defined CPickler subclass to
3045+
# override the default reducing routines of any type using the method
3046+
# reducer_override
3047+
3048+
def f():
3049+
pass
3050+
3051+
def g():
3052+
pass
3053+
3054+
def h():
3055+
pass
3056+
3057+
class MyClass:
3058+
pass
3059+
3060+
for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
3061+
with self.subTest(proto=proto):
3062+
bio = io.BytesIO()
3063+
p = self.pickler_class(bio, proto)
3064+
3065+
p.dump([f, MyClass, math.log])
3066+
new_f, some_str, math_log = pickle.loads(bio.getvalue())
3067+
3068+
self.assertEqual(new_f, 5)
3069+
self.assertEqual(some_str, 'some str')
3070+
# math.log does not have its usual reducer overriden, so the
3071+
# custom reduction callback should silently direct the pickler
3072+
# to the default pickling by attribute, by returning
3073+
# NotImplemented
3074+
self.assertIs(math_log, math.log)
3075+
3076+
with self.assertRaises(pickle.PicklingError):
3077+
p.dump(g)
3078+
3079+
with self.assertRaisesRegex(
3080+
ValueError, 'The reducer just failed'):
3081+
p.dump(h)
3082+
3083+
30163084
class AbstractDispatchTableTests(unittest.TestCase):
30173085

30183086
def test_default_dispatch_table(self):

Lib/test/test_pickle.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
import unittest
1212
from test import support
1313

14+
from test.pickletester import AbstractHookTests
1415
from test.pickletester import AbstractUnpickleTests
1516
from test.pickletester import AbstractPickleTests
1617
from test.pickletester import AbstractPickleModuleTests
1718
from test.pickletester import AbstractPersistentPicklerTests
1819
from test.pickletester import AbstractIdentityPersistentPicklerTests
1920
from test.pickletester import AbstractPicklerUnpicklerObjectTests
2021
from test.pickletester import AbstractDispatchTableTests
22+
from test.pickletester import AbstractCustomPicklerClass
2123
from test.pickletester import BigmemPickleTests
2224

2325
try:
@@ -253,12 +255,23 @@ class CChainDispatchTableTests(AbstractDispatchTableTests):
253255
def get_dispatch_table(self):
254256
return collections.ChainMap({}, pickle.dispatch_table)
255257

258+
class PyPicklerHookTests(AbstractHookTests):
259+
class CustomPyPicklerClass(pickle._Pickler,
260+
AbstractCustomPicklerClass):
261+
pass
262+
pickler_class = CustomPyPicklerClass
263+
264+
class CPicklerHookTests(AbstractHookTests):
265+
class CustomCPicklerClass(_pickle.Pickler, AbstractCustomPicklerClass):
266+
pass
267+
pickler_class = CustomCPicklerClass
268+
256269
@support.cpython_only
257270
class SizeofTests(unittest.TestCase):
258271
check_sizeof = support.check_sizeof
259272

260273
def test_pickler(self):
261-
basesize = support.calcobjsize('6P2n3i2n3iP')
274+
basesize = support.calcobjsize('6P2n3i2n3i2P')
262275
p = _pickle.Pickler(io.BytesIO())
263276
self.assertEqual(object.__sizeof__(p), basesize)
264277
MT_size = struct.calcsize('3nP0n')
@@ -498,14 +511,15 @@ def test_main():
498511
tests = [PyPickleTests, PyUnpicklerTests, PyPicklerTests,
499512
PyPersPicklerTests, PyIdPersPicklerTests,
500513
PyDispatchTableTests, PyChainDispatchTableTests,
501-
CompatPickleTests]
514+
CompatPickleTests, PyPicklerHookTests]
502515
if has_c_implementation:
503516
tests.extend([CPickleTests, CUnpicklerTests, CPicklerTests,
504517
CPersPicklerTests, CIdPersPicklerTests,
505518
CDumpPickle_LoadPickle, DumpPickle_CLoadPickle,
506519
PyPicklerUnpicklerObjectTests,
507520
CPicklerUnpicklerObjectTests,
508521
CDispatchTableTests, CChainDispatchTableTests,
522+
CPicklerHookTests,
509523
InMemoryPickleTests, SizeofTests])
510524
support.run_unittest(*tests)
511525
support.run_doctest(pickle)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
enable custom reduction callback registration for functions and classes in
2+
_pickle.c, using the new Pickler's attribute ``reducer_override``

Modules/_pickle.c

+42-2
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,9 @@ typedef struct PicklerObject {
616616
PyObject *pers_func_self; /* borrowed reference to self if pers_func
617617
is an unbound method, NULL otherwise */
618618
PyObject *dispatch_table; /* private dispatch_table, can be NULL */
619+
PyObject *reducer_override; /* hook for invoking user-defined callbacks
620+
instead of save_global when pickling
621+
functions and classes*/
619622

620623
PyObject *write; /* write() method of the output stream. */
621624
PyObject *output_buffer; /* Write into a local bytearray buffer before
@@ -1110,6 +1113,7 @@ _Pickler_New(void)
11101113
self->fast_memo = NULL;
11111114
self->max_output_len = WRITE_BUF_SIZE;
11121115
self->output_len = 0;
1116+
self->reducer_override = NULL;
11131117

11141118
self->memo = PyMemoTable_New();
11151119
self->output_buffer = PyBytes_FromStringAndSize(NULL,
@@ -2220,7 +2224,7 @@ save_bytes(PicklerObject *self, PyObject *obj)
22202224
Python 2 *and* the appropriate 'bytes' object when unpickled
22212225
using Python 3. Again this is a hack and we don't need to do this
22222226
with newer protocols. */
2223-
PyObject *reduce_value = NULL;
2227+
PyObject *reduce_value;
22242228
int status;
22252229

22262230
if (PyBytes_GET_SIZE(obj) == 0) {
@@ -4058,7 +4062,25 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
40584062
status = save_tuple(self, obj);
40594063
goto done;
40604064
}
4061-
else if (type == &PyType_Type) {
4065+
4066+
/* Now, check reducer_override. If it returns NotImplemented,
4067+
* fallback to save_type or save_global, and then perhaps to the
4068+
* regular reduction mechanism.
4069+
*/
4070+
if (self->reducer_override != NULL) {
4071+
reduce_value = PyObject_CallFunctionObjArgs(self->reducer_override,
4072+
obj, NULL);
4073+
if (reduce_value == NULL) {
4074+
goto error;
4075+
}
4076+
if (reduce_value != Py_NotImplemented) {
4077+
goto reduce;
4078+
}
4079+
Py_DECREF(reduce_value);
4080+
reduce_value = NULL;
4081+
}
4082+
4083+
if (type == &PyType_Type) {
40624084
status = save_type(self, obj);
40634085
goto done;
40644086
}
@@ -4149,6 +4171,7 @@ save(PicklerObject *self, PyObject *obj, int pers_save)
41494171
if (reduce_value == NULL)
41504172
goto error;
41514173

4174+
reduce:
41524175
if (PyUnicode_Check(reduce_value)) {
41534176
status = save_global(self, obj, reduce_value);
41544177
goto done;
@@ -4180,6 +4203,20 @@ static int
41804203
dump(PicklerObject *self, PyObject *obj)
41814204
{
41824205
const char stop_op = STOP;
4206+
PyObject *tmp;
4207+
_Py_IDENTIFIER(reducer_override);
4208+
4209+
if (_PyObject_LookupAttrId((PyObject *)self, &PyId_reducer_override,
4210+
&tmp) < 0) {
4211+
return -1;
4212+
}
4213+
/* Cache the reducer_override method, if it exists. */
4214+
if (tmp != NULL) {
4215+
Py_XSETREF(self->reducer_override, tmp);
4216+
}
4217+
else {
4218+
Py_CLEAR(self->reducer_override);
4219+
}
41834220

41844221
if (self->proto >= 2) {
41854222
char header[2];
@@ -4304,6 +4341,7 @@ Pickler_dealloc(PicklerObject *self)
43044341
Py_XDECREF(self->pers_func);
43054342
Py_XDECREF(self->dispatch_table);
43064343
Py_XDECREF(self->fast_memo);
4344+
Py_XDECREF(self->reducer_override);
43074345

43084346
PyMemoTable_Del(self->memo);
43094347

@@ -4317,6 +4355,7 @@ Pickler_traverse(PicklerObject *self, visitproc visit, void *arg)
43174355
Py_VISIT(self->pers_func);
43184356
Py_VISIT(self->dispatch_table);
43194357
Py_VISIT(self->fast_memo);
4358+
Py_VISIT(self->reducer_override);
43204359
return 0;
43214360
}
43224361

@@ -4328,6 +4367,7 @@ Pickler_clear(PicklerObject *self)
43284367
Py_CLEAR(self->pers_func);
43294368
Py_CLEAR(self->dispatch_table);
43304369
Py_CLEAR(self->fast_memo);
4370+
Py_CLEAR(self->reducer_override);
43314371

43324372
if (self->memo != NULL) {
43334373
PyMemoTable *memo = self->memo;

0 commit comments

Comments
 (0)