Skip to content

Commit

Permalink
BUG: Fix crash when pickling dynamic class cycles.
Browse files Browse the repository at this point in the history
Fixes a bug where we would fail to pickle a class created inside a
function if that class participated in a cycle with its own __dict__.

Such cycles occur, for example, when a class defines a method that makes
a Python 2-style super call, because we have a cycle from class ->
__dict__ -> function -> __closure__ -> class.

The fix for this is to use the same technique we use for
dynamically-created functions: we first pickle an empty "skeleton
class", which we memoize before pickling the rest of the class's
__dict__. We then invoke a reduce function that re-attaches the class's
attributes from the __dict__.
  • Loading branch information
Scott Sanderson committed Jun 17, 2017
1 parent c89dc9d commit 9fa3692
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 12 deletions.
94 changes: 82 additions & 12 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,71 @@ def _save_subimports(self, code, top_level_dependencies):
# then discards the reference to it
self.write(pickle.POP)

def save_dynamic_class(self, obj):
"""
Save a class that can't be stored as a module global.

This method is used to serialize classes that are defined inside
functions, or that otherwise can't be serialized as attribute lookups
from global modules.

The class is written in two stages so that a class which takes part in
a cycle with its own __dict__ can still be pickled: an empty "skeleton"
built from ``type(obj)(obj.__name__, obj.__bases__, {'__doc__': ...})``
is written and memoized first, then the remaining attributes are saved
and re-attached at load time by ``_rehydrate_skeleton_class``.

The pickle instructions below are order-sensitive; do not reorder them.
"""
clsdict = dict(obj.__dict__) # copy dict proxy to a dict
if not isinstance(clsdict.get('__dict__', None), property):
# don't extract dict that are properties
clsdict.pop('__dict__', None)
clsdict.pop('__weakref__', None)

# hack as __new__ is stored differently in the __dict__: when the class
# overrides __new__, replace the raw __dict__ entry with the attribute as
# looked up on the class itself.
new_override = clsdict.get('__new__', None)
if new_override:
clsdict['__new__'] = obj.__new__

save = self.save
write = self.write

# We write pickle instructions explicitly here to handle the
# possibility that the type object participates in a cycle with its own
# __dict__. We first write an empty "skeleton" version of the class and
# memoize it before writing the class' __dict__ itself. We then write
# instructions to "rehydrate" the skeleton class by restoring the
# attributes from the __dict__.
#
# A type can appear in a cycle with its __dict__ if an instance of the
# type appears in the type's __dict__ (which happens for the stdlib
# Enum class), or if the type defines methods that close over the name
# of the type, (which is common for Python 2-style super() calls).

# Push the rehydration function.
save(_rehydrate_skeleton_class)

# Mark the start of the args for the rehydration function.
write(pickle.MARK)

# On PyPy, __doc__ is a readonly attribute, so we need to include it in
# the initial skeleton class. This is safe because we know that the
# doc can't participate in a cycle with the original class.
# Note that this also removes '__doc__' from clsdict, so it is not set a
# second time during rehydration.
doc_dict = {'__doc__': clsdict.pop('__doc__', None)}

# Create and memoize an empty class with obj's name and bases.
# The callable is obj's own type, so the skeleton is built by the same
# metaclass as obj.
save(type(obj))
save((
obj.__name__,
obj.__bases__,
doc_dict,
))
write(pickle.REDUCE)
# Memoize obj right after the skeleton has been produced, before any of
# its attributes are written.
self.memoize(obj)

# Now save the rest of obj's __dict__. Any references to obj
# encountered while saving will point to the skeleton class.
save(clsdict)

# Write a tuple of (skeleton_class, clsdict).
# TUPLE collects everything pushed since the MARK written above.
write(pickle.TUPLE)

# Call _rehydrate_skeleton_class(skeleton_class, clsdict)
write(pickle.REDUCE)

def save_function_tuple(self, func):
""" Pickles an actual func object.
Expand Down Expand Up @@ -513,6 +578,12 @@ def save_builtin_function(self, obj):
dispatch[types.BuiltinFunctionType] = save_builtin_function

def save_global(self, obj, name=None, pack=struct.pack):
"""
Save a "global".
The name of this method is somewhat misleading: all types get
dispatched here.
"""
if obj.__module__ == "__builtin__" or obj.__module__ == "builtins":
if obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
Expand All @@ -536,18 +607,7 @@ def save_global(self, obj, name=None, pack=struct.pack):

typ = type(obj)
if typ is not obj and isinstance(obj, (type, types.ClassType)):
d = dict(obj.__dict__) # copy dict proxy to a dict
if not isinstance(d.get('__dict__', None), property):
# don't extract dict that are properties
d.pop('__dict__', None)
d.pop('__weakref__', None)

# hack as __new__ is stored differently in the __dict__
new_override = d.get('__new__', None)
if new_override:
d['__new__'] = obj.__new__

self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj)
self.save_dynamic_class(obj)
else:
raise pickle.PicklingError("Can't pickle %r" % obj)

Expand Down Expand Up @@ -986,6 +1046,16 @@ def _make_skel_func(code, cell_count, base_globals=None):
return types.FunctionType(code, base_globals, None, None, closure)


def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.
See CloudPickler.save_dynamic_class for more info.
"""
for attrname, attr in class_dict.items():
setattr(skeleton_class, attrname, attr)
return skeleton_class


def _find_module(mod_name):
"""
Iterate over each part instead of calling imp.find_module directly.
Expand Down
42 changes: 42 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,48 @@ def g():
g = pickle_depickle(f())
self.assertEqual(g(), 2)

def test_dynamically_generated_class_that_uses_super(self):
# Round-trips a class defined inside this function whose method refers
# to the class by name through a Python 2-style super() call.

class Base(object):
def method(self):
return 1

class Derived(Base):
"Derived Docstring"
def method(self):
# Explicit two-argument super(): the method refers to Derived by name.
return super(Derived, self).method() + 1

# Sanity check on the original class before any pickling.
self.assertEqual(Derived().method(), 2)

# Pickle and unpickle the class.
UnpickledDerived = pickle_depickle(Derived)
self.assertEqual(UnpickledDerived().method(), 2)
# The docstring must come back on the unpickled class as well.
self.assertEqual(UnpickledDerived.__doc__, "Derived Docstring")

# Pickle and unpickle an instance.
orig_d = Derived()
d = pickle_depickle(orig_d)
self.assertEqual(d.method(), 2)

def test_cycle_in_classdict_globals(self):
# Round-trips a locally defined class whose attributes refer back to the
# class itself, both directly and through an instance.

class C(object):

def it_works(self):
return "woohoo!"

# Create the cycles: C -> C (direct) and C -> instance of C -> C.
C.C_again = C
C.instance_of_C = C()

# Round-trip both the class and a fresh instance of it.
depickled_C = pickle_depickle(C)
depickled_instance = pickle_depickle(C())

# Test instance of depickled class.
self.assertEqual(depickled_C().it_works(), "woohoo!")
self.assertEqual(depickled_C.C_again().it_works(), "woohoo!")
self.assertEqual(depickled_C.instance_of_C.it_works(), "woohoo!")
self.assertEqual(depickled_instance.it_works(), "woohoo!")

@pytest.mark.skipif(sys.version_info >= (3, 4)
and sys.version_info < (3, 4, 3),
reason="subprocess has a bug in 3.4.0 to 3.4.2")
Expand Down

0 comments on commit 9fa3692

Please sign in to comment.