Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import submodules accessed by pickled functions #80

Merged
merged 8 commits into from
Feb 24, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion cloudpickle/cloudpickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def save_function(self, obj, name=None):
# a builtin_function_or_method which comes in as an attribute of some
# object (e.g., object.__new__, itertools.chain.from_iterable) will end
# up with modname "__main__" and so end up here. But these functions
# have no __code__ attribute in CPython, so the handling for
# have no __code__ attribute in CPython, so the handling for
# user-defined functions below will fail.
# So we pickle them here using save_reduce; have to do it differently
# for different python versions.
Expand Down Expand Up @@ -282,6 +282,21 @@ def save_function(self, obj, name=None):
self.memoize(obj)
dispatch[types.FunctionType] = save_function

def _save_subimports(self, code, top_level_dependencies):
    """Ensure the de-pickler imports package sub-modules the function needs.

    For every imported package found in ``top_level_dependencies``, each
    currently loaded sub-module whose dotted path (relative to the package)
    consists only of names appearing in ``code.co_names`` is written to the
    pickle stream and immediately popped again.  Saving the module is what
    makes the unpickler perform the import; the POP discards the resulting
    reference so the stream's stack is left unchanged.

    Parameters:
        code: code object of the function being pickled; only its
            ``co_names`` attribute is read.
        top_level_dependencies: iterable of arbitrary objects (globals and
            closure contents); anything that is not a package is ignored.
    """
    # Names the bytecode may use for attribute access; loop-invariant.
    referenced = set(code.co_names)

    for dep in top_level_dependencies:
        # Only imported packages are of interest.  getattr() guards against
        # module objects that do not define __package__ at all.
        if not isinstance(dep, types.ModuleType):
            continue
        if not getattr(dep, '__package__', None):
            continue

        prefix = dep.__name__ + '.'
        # Walk a snapshot of the keys: sys.modules may be mutated while we
        # iterate (by another thread, or possibly as a side effect of
        # saving a module), which would otherwise raise RuntimeError.
        for name in list(sys.modules):
            # sys.modules is an ordinary dict, so tolerate a stray None key.
            if name is None or not name.startswith(prefix):
                continue
            # Can the function address this sub-module by attribute access?
            tokens = set(name[len(prefix):].split('.'))
            if tokens - referenced:
                continue
            module = sys.modules.get(name)
            if module is None:
                # Entry vanished since the snapshot, or is a None
                # placeholder: there is nothing to import.
                continue
            # ensure the unpickler executes the import of this sub-module
            self.save(module)
            # then discard the reference to it
            self.write(pickle.POP)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cosmetic: could you please move the comments onto their own lines (before the matching statements) so as to avoid long (80+ column) lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, this file does not adopt that convention elsewhere.


def save_function_tuple(self, func):
""" Pickles an actual func object.

Expand All @@ -307,6 +322,8 @@ def save_function_tuple(self, func):
save(_fill_function) # skeleton function updater
write(pickle.MARK) # beginning of tuple that _fill_function expects

self._save_subimports(code, set(f_globals.values()) | set(closure))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line creates a regression: #86


# create a skeleton function object and memoize it
save(_make_skel_func)
save((code, closure, base_globals))
Expand Down
66 changes: 66 additions & 0 deletions tests/cloudpickle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import itertools
import platform
import textwrap
import base64
import subprocess

try:
# try importing numpy and scipy. These are not hard dependencies and
Expand Down Expand Up @@ -360,6 +362,70 @@ def f():
self.assertTrue(f2 is f3)
self.assertEqual(f2(), res)

def test_submodule(self):
    # A function that reaches a sub-module of a package through attribute
    # access on the package must still work after unpickling.
    #
    # Any module that is NOT imported by its parent package's __init__
    # will do; standard-library candidates include:
    # - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree

    # imitate performing this import at the top of the file
    global xml
    import xml.etree.ElementTree

    def example():
        x = xml.etree.ElementTree.Comment # potential AttributeError

    payload = cloudpickle.dumps(example)

    # Reset the environment: drop the global binding and forget every
    # already-imported module belonging to the xml package.
    del xml
    stale = [name for name in list(sys.modules)
             if name.split('.')[0] == 'xml']
    for name in stale:
        del sys.modules[name]

    # Deserialise, then call: the call is the actual check for an error.
    restored = pickle.loads(payload)
    restored()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume that your fix should also handle a case such as the following:

global etree
import xml.etree.ElementTree as etree
def example():
    x = etree.Comment
...

Maybe it would still be worth adding a test to make sure that this import pattern is also supported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note, such a test would have passed even prior to this pull request.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, but thanks for having added the test as a non-regression test anyway :)


def test_submodule_closure(self):
    # Variant of test_submodule in which the package is captured by a
    # closure rather than being a module-level global.
    def scope():
        import xml.etree.ElementTree

        def example():
            x = xml.etree.ElementTree.Comment # potential AttributeError
        return example

    func = scope()
    payload = cloudpickle.dumps(func)

    # Reset the environment: forget every imported module of the xml package.
    stale = [name for name in list(sys.modules)
             if name.split('.')[0] == 'xml']
    for name in stale:
        del sys.modules[name]

    # Deserialise, then call: the call itself is the test.
    restored = cloudpickle.loads(payload)
    restored()

def test_multiprocess(self):
    # Run, in a separate interpreter, a function pickled by this process
    # (a la dask.distributed).  The function reaches one package through
    # a closure (curses) and another through a global (xml).
    def scope():
        import curses.textpad
        def example():
            x = xml.etree.ElementTree.Comment
            x = curses.textpad.Textbox
        return example
    global xml
    import xml.etree.ElementTree
    example = scope()

    s = cloudpickle.dumps(example)

    # choose "subprocess" rather than "multiprocessing" because the latter
    # library uses fork to preserve the parent environment.
    # The pickle is base32-encoded so that it can be embedded in the
    # command-line string as plain ASCII.
    command = ("import pickle, base64; "
               "pickle.loads(base64.b32decode('" +
               base64.b32encode(s).decode('ascii') +
               "'))()")

    # Use a TestCase assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and it reports the actual exit status
    # when the child interpreter fails.
    exit_status = subprocess.call([sys.executable, '-c', command])
    self.assertEqual(exit_status, 0)

# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()