Skip to content

Commit

Permalink
Add custom logic for pickling dynamic imports.
Browse files Browse the repository at this point in the history
Add test cases, special case Ellipsis and NotImplemented.
Use custom logic in lieu of imp.find_module to properly follow subimports. For example sklearn.tree was spuriously treated as a dynamic module.
  • Loading branch information
rodrigofarnhamsc committed Feb 1, 2016
1 parent e47f29f commit 340d175
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
48 changes: 46 additions & 2 deletions cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

import operator
import io
import imp
import pickle
import struct
import sys
Expand Down Expand Up @@ -134,8 +135,19 @@ def save_module(self, obj):
"""
Save a module as an import
"""
mod_name = obj.__name__
# If module is successfully found then it is not a dynamically created module
try:
_find_module(mod_name)
is_dynamic = False
except ImportError:
is_dynamic = True

self.modules.add(obj)
self.save_reduce(subimport, (obj.__name__,), obj=obj)
if is_dynamic:
self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)), obj=obj)
else:
self.save_reduce(subimport, (obj.__name__,), obj=obj)
dispatch[types.ModuleType] = save_module

def save_codeobject(self, obj):
Expand Down Expand Up @@ -313,7 +325,7 @@ def extract_func_data(self, func):
return (code, f_globals, defaults, closure, dct, base_globals)

def save_builtin_function(self, obj):
if obj.__module__ is "__builtin__":
if obj.__module__ == "__builtin__":
return self.save_global(obj)
return self.save_function(obj)
dispatch[types.BuiltinFunctionType] = save_builtin_function
Expand Down Expand Up @@ -584,11 +596,20 @@ def save_file(self, obj):
self.save(retval)
self.memoize(obj)

def save_ellipsis(self, obj):
self.save_reduce(_gen_ellipsis, ())

def save_not_implemented(self, obj):
self.save_reduce(_gen_not_implemented, ())

if PY3:
dispatch[io.TextIOWrapper] = save_file
else:
dispatch[file] = save_file

dispatch[type(Ellipsis)] = save_ellipsis
dispatch[type(NotImplemented)] = save_not_implemented

"""Special functions for Add-on libraries"""
def inject_addons(self):
"""Plug in system. Register additional pickling functions if modules already loaded"""
Expand Down Expand Up @@ -620,6 +641,12 @@ def subimport(name):
return sys.modules[name]


def dynamic_subimport(name, vars):
mod = imp.new_module(name)
mod.__dict__.update(vars)
sys.modules[name] = mod
return mod

# restores function attributes
def _restore_attr(obj, attr):
for key, val in attr.items():
Expand Down Expand Up @@ -663,6 +690,11 @@ def _genpartial(func, args, kwds):
kwds = {}
return partial(func, *args, **kwds)

def _gen_ellipsis():
return Ellipsis

def _gen_not_implemented():
return NotImplemented

def _fill_function(func, globals, defaults, dict):
""" Fills in the rest of function data into the skeleton function object
Expand Down Expand Up @@ -698,6 +730,18 @@ def _make_skel_func(code, closures, base_globals = None):
None, None, closure)


def _find_module(mod_name):
"""
Iterate over each part instead of calling imp.find_module directly.
This function is able to find submodules (e.g. sickit.tree)
"""
path = None
for part in mod_name.split('.'):
if path is not None:
path = [path]
file, path, description = imp.find_module(part, path)
return file, path, description

"""Constructors for 3rd party libraries
Note: These can never be renamed due to client compatibility issues"""

Expand Down
26 changes: 26 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import division
import imp
import unittest
import pytest
import pickle
import sys
import functools
import platform
import textwrap

try:
# try importing numpy and scipy. These are not hard dependencies and
Expand Down Expand Up @@ -252,6 +254,30 @@ def f(self, x):
self.assertEqual(g.im_class.__name__, F.f.im_class.__name__)
# self.assertEqual(g(F(), 1), 2) # still fails

def test_module(self):
self.assertEqual(pickle, pickle_depickle(pickle))

def test_dynamic_module(self):
mod = imp.new_module('mod')
code = '''
x = 1
def f(y):
return x + y
'''
exec(textwrap.dedent(code), mod.__dict__)
mod2 = pickle_depickle(mod)
self.assertEqual(mod.x, mod2.x)
self.assertEqual(mod.f(5), mod2.f(5))

# Test dynamic modules when imported back are singletons
mod1, mod2 = pickle_depickle([mod, mod])
self.assertEqual(id(mod1), id(mod2))

def test_Ellipsis(self):
self.assertEqual(Ellipsis, pickle_depickle(Ellipsis))

def test_NotImplemented(self):
self.assertEqual(NotImplemented, pickle_depickle(NotImplemented))

if __name__ == '__main__':
unittest.main()

0 comments on commit 340d175

Please sign in to comment.