Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ close to the data.
Among other things, `cloudpickle` supports pickling for lambda expressions,
functions and classes defined interactively in the `__main__` module.

`cloudpickle` uses `pickle.HIGHEST_PROTOCOL` by default: it is meant to
send objects between processes running the same version of Python. Using
`cloudpickle` for long-term storage is therefore discouraged.

Installation
------------
Expand Down
41 changes: 35 additions & 6 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
import weakref


# cloudpickle is meant for inter-process communication: we expect all
# communicating processes to run the same Python version, hence we favor
# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL


if sys.version < '3':
from pickle import Pickler
try:
Expand Down Expand Up @@ -248,7 +254,9 @@ class CloudPickler(Pickler):
dispatch = Pickler.dispatch.copy()

def __init__(self, file, protocol=None):
Pickler.__init__(self, file, protocol)
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
# set of modules to unpickle
self.modules = set()
# map ids to dictionary. used to ensure that functions can share global env
Expand Down Expand Up @@ -828,32 +836,52 @@ def is_tornado_coroutine(func):
return False
return gen.is_coroutine_function(func)


def _rebuild_tornado_coroutine(func):
    """Wrap ``func`` with tornado's ``gen.coroutine`` and return the result.

    tornado is imported at call time, so it is only required when this
    function actually runs.
    """
    from tornado import gen

    wrapped = gen.coroutine(func)
    return wrapped


# Shorthands for legacy support

def dump(obj, file, protocol=None):
    """Serialize obj as bytes streamed into file.

    protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
    pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
    between processes running the same Python version.

    Pass a lower protocol explicitly if you need to ensure compatibility
    with older versions of Python: pickle.DEFAULT_PROTOCOL on Python 3, or
    a fixed number such as 2 (pickle.DEFAULT_PROTOCOL does not exist on
    Python 2).
    """
    # protocol=None is resolved to DEFAULT_PROTOCOL by CloudPickler.__init__.
    CloudPickler(file, protocol=protocol).dump(obj)


def dumps(obj, protocol=None):
    """Serialize obj as a string of bytes allocated in memory.

    protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
    pickle.HIGHEST_PROTOCOL. This setting favors maximum communication speed
    between processes running the same Python version.

    Pass a lower protocol explicitly if you need to ensure compatibility
    with older versions of Python: pickle.DEFAULT_PROTOCOL on Python 3, or
    a fixed number such as 2 (pickle.DEFAULT_PROTOCOL does not exist on
    Python 2).
    """
    # Named ``stream`` rather than ``file`` to avoid shadowing the Python 2
    # builtin of the same name.
    stream = StringIO()
    try:
        # protocol=None is resolved to DEFAULT_PROTOCOL by CloudPickler.
        CloudPickler(stream, protocol=protocol).dump(obj)
        return stream.getvalue()
    finally:
        # Always release the in-memory buffer, even if pickling fails.
        stream.close()


# including pickles unloading functions in this namespace
load = pickle.load
loads = pickle.loads


#hack for __import__ not working as desired
# hack for __import__ not working as desired
def subimport(name):
    """Import the module called ``name`` and return it.

    ``__import__`` returns the top-level package for dotted names, so the
    requested module itself is fetched from ``sys.modules`` afterwards.
    """
    __import__(name)
    module = sys.modules[name]
    return module
Expand All @@ -865,6 +893,7 @@ def dynamic_subimport(name, vars):
sys.modules[name] = mod
return mod


# restores function attributes
def _restore_attr(obj, attr):
for key, val in attr.items():
Expand Down
95 changes: 56 additions & 39 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,11 @@
HAVE_WEAKSET = hasattr(weakref, 'WeakSet')


def pickle_depickle(obj, protocol=cloudpickle.DEFAULT_PROTOCOL):
    """Round-trip ``obj``: pickle with cloudpickle, depickle with pickle.

    Helper to check that an object pickled with cloudpickle (using the given
    ``protocol``) can be loaded back with the standard ``pickle`` module.
    Returns the depickled object.
    """
    return pickle.loads(cloudpickle.dumps(obj, protocol=protocol))


class CloudPicklerTest(unittest.TestCase):
Expand All @@ -64,15 +64,17 @@ def setUp(self):

class CloudPickleTest(unittest.TestCase):

protocol = cloudpickle.DEFAULT_PROTOCOL

def test_itemgetter(self):
    """itemgetter objects return the same items after a round trip."""
    d = range(10)
    # Cover both the single-index and the multi-index form.
    for getter in (itemgetter(1), itemgetter(0, 3)):
        getter2 = pickle_depickle(getter, protocol=self.protocol)
        self.assertEqual(getter(d), getter2(d))

def test_attrgetter(self):
Expand All @@ -81,18 +83,18 @@ def __getattr__(self, item):
return item
d = C()
getter = attrgetter("a")
getter2 = pickle_depickle(getter)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
getter = attrgetter("a", "b")
getter2 = pickle_depickle(getter)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))

d.e = C()
getter = attrgetter("e.a")
getter2 = pickle_depickle(getter)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))
getter = attrgetter("e.a", "e.b")
getter2 = pickle_depickle(getter)
getter2 = pickle_depickle(getter, protocol=self.protocol)
self.assertEqual(getter(d), getter2(d))

# Regression test for SPARK-3415
Expand Down Expand Up @@ -124,15 +126,18 @@ def foo():
def test_buffer(self):
    """Python 2 buffer objects depickle to their string content."""
    try:
        buffer_type = buffer  # noqa: F821 - Python 2 only builtin
    except NameError:  # Python 3 no longer supports buffers
        return

    # Only the builtin lookup is guarded above, so a NameError raised while
    # pickling or depickling is reported instead of silently passing.
    for args in (("Hello",), ("Hello", 2, 3)):
        buffer_obj = buffer_type(*args)
        buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
        self.assertEqual(buffer_clone, str(buffer_obj))

def test_memoryview(self):
    """A memoryview depickles to the bytes it exposes."""
    buffer_obj = memoryview(b"Hello")
    buffer_clone = pickle_depickle(buffer_obj, protocol=self.protocol)
    self.assertEqual(buffer_clone, buffer_obj.tobytes())

def test_lambda(self):
    """A lambda is still callable after a round trip."""
    # Pass self.protocol so subclasses overriding ``protocol`` actually
    # exercise their protocol instead of the module default.
    clone = pickle_depickle(lambda: 1, protocol=self.protocol)
    self.assertEqual(clone(), 1)
Expand All @@ -141,7 +146,7 @@ def test_nested_lambdas(self):
a, b = 1, 2
f1 = lambda x: x + a
f2 = lambda x: f1(x) // b
self.assertEqual(pickle_depickle(f2)(1), 1)
self.assertEqual(pickle_depickle(f2, protocol=self.protocol)(1), 1)

def test_recursive_closure(self):
def f1():
Expand Down Expand Up @@ -170,7 +175,7 @@ def f():
msg='f actually has closure cells!',
)

g = pickle_depickle(f)
g = pickle_depickle(f, protocol=self.protocol)

self.assertTrue(
g.__closure__ is None,
Expand All @@ -191,7 +196,7 @@ def g():
with pytest.raises(NameError):
g1()

g2 = pickle_depickle(g1)
g2 = pickle_depickle(g1, protocol=self.protocol)
with pytest.raises(NameError):
g2()

Expand Down Expand Up @@ -221,7 +226,7 @@ def method(self):
self.assertEqual(Derived().method(), 2)

# Pickle and unpickle the class.
UnpickledDerived = pickle_depickle(Derived)
UnpickledDerived = pickle_depickle(Derived, protocol=self.protocol)
self.assertEqual(UnpickledDerived().method(), 2)

# We have special logic for handling __doc__ because it's a readonly
Expand All @@ -230,7 +235,7 @@ def method(self):

# Pickle and unpickle an instance.
orig_d = Derived()
d = pickle_depickle(orig_d)
d = pickle_depickle(orig_d, protocol=self.protocol)
self.assertEqual(d.method(), 2)

def test_cycle_in_classdict_globals(self):
Expand All @@ -243,7 +248,7 @@ def it_works(self):
C.C_again = C
C.instance_of_C = C()

depickled_C = pickle_depickle(C)
depickled_C = pickle_depickle(C, protocol=self.protocol)
depickled_instance = pickle_depickle(C())

# Test instance of depickled class.
Expand All @@ -262,8 +267,8 @@ def some_function(x, y):
return (x + y) / LOCAL_CONSTANT

# pickle the function definition
self.assertEqual(pickle_depickle(some_function)(41, 1), 1)
self.assertEqual(pickle_depickle(some_function)(81, 3), 2)
self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(41, 1), 1)
self.assertEqual(pickle_depickle(some_function, protocol=self.protocol)(81, 3), 2)

hidden_constant = lambda: LOCAL_CONSTANT

Expand All @@ -279,10 +284,11 @@ def some_method(self, x):
return self.one() + some_function(x, 1) + self.value

# pickle the class definition
self.assertEqual(pickle_depickle(SomeClass)(1).one(), 1)
self.assertEqual(pickle_depickle(SomeClass)(5).some_method(41), 7)
new_class = subprocess_pickle_echo(SomeClass)
self.assertEqual(new_class(5).some_method(41), 7)
clone_class = pickle_depickle(SomeClass, protocol=self.protocol)
self.assertEqual(clone_class(1).one(), 1)
self.assertEqual(clone_class(5).some_method(41), 7)
clone_class = subprocess_pickle_echo(SomeClass)
self.assertEqual(clone_class(5).some_method(41), 7)

# pickle the class instances
self.assertEqual(pickle_depickle(SomeClass(1)).one(), 1)
Expand All @@ -298,7 +304,8 @@ def some_method(self, x):

def test_partial(self):
    """functools.partial objects keep their bound arguments."""
    partial_obj = functools.partial(min, 1)
    partial_clone = pickle_depickle(partial_obj, protocol=self.protocol)
    self.assertEqual(partial_clone(4), 1)

@pytest.mark.skipif(platform.python_implementation() == 'PyPy',
reason="Skip numpy and scipy tests on PyPy")
Expand Down Expand Up @@ -346,7 +353,7 @@ def some_generator(cnt):
for i in range(cnt):
yield i

gen2 = pickle_depickle(some_generator)
gen2 = pickle_depickle(some_generator, protocol=self.protocol)

assert type(gen2(3)) == type(some_generator(3))
assert list(gen2(3)) == list(range(3))
Expand All @@ -363,8 +370,8 @@ def test_cm(cls):
sm = A.__dict__["test_sm"]
cm = A.__dict__["test_cm"]

A.test_sm = pickle_depickle(sm)
A.test_cm = pickle_depickle(cm)
A.test_sm = pickle_depickle(sm, protocol=self.protocol)
A.test_cm = pickle_depickle(cm, protocol=self.protocol)

self.assertEqual(A.test_sm(), "sm")
self.assertEqual(A.test_cm(), "cm")
Expand All @@ -385,7 +392,8 @@ def f(self, x):
# self.assertEqual(g(F(), 1), 2) # still fails

def test_module(self):
    """The depickled ``pickle`` module compares equal to the original."""
    pickle_clone = pickle_depickle(pickle, protocol=self.protocol)
    self.assertEqual(pickle, pickle_clone)

def test_dynamic_module(self):
mod = imp.new_module('mod')
Expand All @@ -395,7 +403,7 @@ def f(y):
return x + y
'''
exec(textwrap.dedent(code), mod.__dict__)
mod2 = pickle_depickle(mod)
mod2 = pickle_depickle(mod, protocol=self.protocol)
self.assertEqual(mod.x, mod2.x)
self.assertEqual(mod.f(5), mod2.f(5))

Expand All @@ -415,18 +423,20 @@ def test_find_module(self):
_find_module('valid_module')

def test_Ellipsis(self):
    """Ellipsis compares equal to its depickled copy."""
    ellipsis_clone = pickle_depickle(Ellipsis, protocol=self.protocol)
    self.assertEqual(Ellipsis, ellipsis_clone)

def test_NotImplemented(self):
    """NotImplemented compares equal to its depickled copy."""
    # NotImplemented is a constant, not an exception class, hence the
    # neutral local name.
    not_implemented_clone = pickle_depickle(NotImplemented,
                                            protocol=self.protocol)
    self.assertEqual(NotImplemented, not_implemented_clone)

def test_builtin_function_without_module(self):
    """Round-trip ``object.__new__`` and ``chain.from_iterable``."""
    on = object.__new__
    on_depickled = pickle_depickle(on, protocol=self.protocol)
    self.assertEqual(type(on_depickled(object)), type(object()))

    fi = itertools.chain.from_iterable
    fi_depickled = pickle_depickle(fi, protocol=self.protocol)
    # Call the depickled callable: asserting on ``fi`` itself would pass
    # even if depickling returned something broken.
    self.assertEqual(list(fi_depickled([[1, 2], [3, 4]])), [1, 2, 3, 4])

@pytest.mark.skipif(tornado is None,
Expand Down Expand Up @@ -580,7 +590,7 @@ def test_cell_manipulation(self):

def test_logger(self):
logger = logging.getLogger('cloudpickle.dummy_test_logger')
pickled = pickle_depickle(logger)
pickled = pickle_depickle(logger, protocol=self.protocol)
self.assertTrue(pickled is logger, (pickled, logger))

dumped = cloudpickle.dumps(logger)
Expand Down Expand Up @@ -614,8 +624,9 @@ class ConcreteClass(AbstractClass):
def foo(self):
return 'it works!'

depickled_base = pickle_depickle(AbstractClass)
depickled_class = pickle_depickle(ConcreteClass)
depickled_base = pickle_depickle(AbstractClass, protocol=self.protocol)
depickled_class = pickle_depickle(ConcreteClass,
protocol=self.protocol)
depickled_instance = pickle_depickle(ConcreteClass())

self.assertEqual(depickled_class().foo(), 'it works!')
Expand Down Expand Up @@ -643,7 +654,7 @@ def __init__(self, x):
obj1, obj2, obj3 = SomeClass(1), SomeClass(2), SomeClass(3)

things = [weakref.WeakSet([obj1, obj2]), obj1, obj2, obj3]
result = pickle_depickle(things)
result = pickle_depickle(things, protocol=self.protocol)

weakset, depickled1, depickled2, depickled3 = result

Expand Down Expand Up @@ -678,7 +689,8 @@ def foo():
self.assertEqual(pickle_depickle(Foo()).foo(), "it works!")

# Test whichmodule in save_function.
self.assertEqual(pickle_depickle(foo)(), "it works!")
self.assertEqual(pickle_depickle(foo, protocol=self.protocol)(),
"it works!")
finally:
sys.modules.pop("_fake_module", None)

Expand All @@ -701,7 +713,7 @@ def f():

def test_function_module_name(self):
    """A depickled function keeps the original ``__module__``."""
    func = lambda x: x
    cloned = pickle_depickle(func, protocol=self.protocol)
    self.assertEqual(cloned.__module__, func.__module__)

def test_function_qualname(self):
def func(x):
Expand Down Expand Up @@ -764,5 +776,10 @@ def test_function_pickle_compat_0_4_1(self):
self.assertEquals(42, cloudpickle.loads(pickled)(42))


class Protocol2CloudPickleTest(CloudPickleTest):
    """Re-run the inherited CloudPickleTest tests with pickle protocol 2."""

    # Read as ``self.protocol`` by the inherited tests that forward it to
    # pickle_depickle.
    protocol = 2


if __name__ == '__main__':
unittest.main()
Loading