-
Notifications
You must be signed in to change notification settings - Fork 167
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Import submodules accessed by pickled functions #80
Changes from 5 commits
b31d000
0657c6d
0922273
11709cb
6854eaa
ee40673
3319f50
4b9bfbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -238,7 +238,7 @@ def save_function(self, obj, name=None): | |
# a builtin_function_or_method which comes in as an attribute of some | ||
# object (e.g., object.__new__, itertools.chain.from_iterable) will end | ||
# up with modname "__main__" and so end up here. But these functions | ||
# have no __code__ attribute in CPython, so the handling for | ||
# have no __code__ attribute in CPython, so the handling for | ||
# user-defined functions below will fail. | ||
# So we pickle them here using save_reduce; have to do it differently | ||
# for different python versions. | ||
|
@@ -282,6 +282,21 @@ def save_function(self, obj, name=None): | |
self.memoize(obj) | ||
dispatch[types.FunctionType] = save_function | ||
|
||
def _save_subimports(self, code, top_level_dependencies): | ||
"""Ensure de-pickler imports any package sub-modules needed by the function""" | ||
# check if any known dependency is an imported package | ||
for x in top_level_dependencies: | ||
if isinstance(x, types.ModuleType) and x.__package__: | ||
# check if the package has any currently loaded sub-imports | ||
prefix = x.__name__ + '.' | ||
for name, module in sys.modules.items(): | ||
if name.startswith(prefix): | ||
# check whether the function can address the sub-module | ||
tokens = set(name[len(prefix):].split('.')) | ||
if not tokens - set(code.co_names): | ||
self.save(module) # ensure the unpickler executes import of this submodule | ||
self.write(pickle.POP) # then discard the reference to it | ||
|
||
def save_function_tuple(self, func): | ||
""" Pickles an actual func object. | ||
|
||
|
@@ -307,6 +322,8 @@ def save_function_tuple(self, func): | |
save(_fill_function) # skeleton function updater | ||
write(pickle.MARK) # beginning of tuple that _fill_function expects | ||
|
||
self._save_subimports(code, set(f_globals.values()) | set(closure)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line creates a regression: #86 |
||
|
||
# create a skeleton function object and memoize it | ||
save(_make_skel_func) | ||
save((code, closure, base_globals)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,8 @@ | |
import itertools | ||
import platform | ||
import textwrap | ||
import base64 | ||
import subprocess | ||
|
||
try: | ||
# try importing numpy and scipy. These are not hard dependencies and | ||
|
@@ -360,6 +362,70 @@ def f(): | |
self.assertTrue(f2 is f3) | ||
self.assertEqual(f2(), res) | ||
|
||
def test_submodule(self): | ||
# Function that refers (by attribute) to a sub-module of a package. | ||
|
||
# Choose any module NOT imported by __init__ of its parent package | ||
# examples in standard library include: | ||
# - http.cookies, unittest.mock, curses.textpad, xml.etree.ElementTree | ||
|
||
global xml # imitate performing this import at top of file | ||
import xml.etree.ElementTree | ||
def example(): | ||
x = xml.etree.ElementTree.Comment # potential AttributeError | ||
|
||
s = cloudpickle.dumps(example) | ||
|
||
# refresh the environment, i.e., unimport the dependency | ||
del xml | ||
for item in list(sys.modules): | ||
if item.split('.')[0] == 'xml': | ||
del sys.modules[item] | ||
|
||
# deserialise | ||
f = pickle.loads(s) | ||
f() # perform test for error | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume that your fix should also handle a case such as follows: global etree
import xml.etree.ElementTree as etree
def example():
x = etree.Comment
... Maybe it would still be worth adding a test to make sure that this import pattern is also supported. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note, such a test would have passed even prior to this pull request. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. alright but thanks for having added the test as a non-regression test anyway :) |
||
|
||
def test_submodule_closure(self): | ||
# Same as test_submodule except the package is not a global | ||
def scope(): | ||
import xml.etree.ElementTree | ||
def example(): | ||
x = xml.etree.ElementTree.Comment # potential AttributeError | ||
return example | ||
example = scope() | ||
|
||
s = cloudpickle.dumps(example) | ||
|
||
# refresh the environment (unimport dependency) | ||
for item in list(sys.modules): | ||
if item.split('.')[0] == 'xml': | ||
del sys.modules[item] | ||
|
||
f = cloudpickle.loads(s) | ||
f() # test | ||
|
||
def test_multiprocess(self): | ||
# running a function pickled by another process (a la dask.distributed) | ||
def scope(): | ||
import curses.textpad | ||
def example(): | ||
x = xml.etree.ElementTree.Comment | ||
x = curses.textpad.Textbox | ||
return example | ||
global xml | ||
import xml.etree.ElementTree | ||
example = scope() | ||
|
||
s = cloudpickle.dumps(example) | ||
|
||
# choose "subprocess" rather than "multiprocessing" because the latter | ||
# library uses fork to preserve the parent environment. | ||
command = ("import pickle, base64; " | ||
"pickle.loads(base64.b32decode('" + | ||
base64.b32encode(s).decode('ascii') + | ||
"'))()") | ||
assert not subprocess.call([sys.executable, '-c', command]) | ||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cosmetic: could you please move the comments to be on their own lines (before the matching statements) so as to avoid long (80+ columns lines).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note, this file does not adopt that convention elsewhere.