diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 389bee7eee6e9..dc6eef32ce927 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -9,10 +9,10 @@ It does not include an unpickler, as standard python unpickling suffices. This module was extracted from the `cloud` package, developed by `PiCloud, Inc. -`_. +`_. Copyright (c) 2012, Regents of the University of California. -Copyright (c) 2009 `PiCloud, Inc. `_. +Copyright (c) 2009 `PiCloud, Inc. `_. All rights reserved. Redistribution and use in source and binary forms, with or without @@ -42,21 +42,21 @@ """ from __future__ import print_function -import operator -import opcode -import os +import dis +from functools import partial +import imp import io +import itertools +import logging +import opcode +import operator import pickle import struct import sys -import types -from functools import partial -import itertools -import dis import traceback +import types import weakref -from pyspark.util import _exception_message if sys.version < '3': from pickle import Pickler @@ -71,6 +71,92 @@ from io import BytesIO as StringIO PY3 = True + +def _make_cell_set_template_code(): + """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF + + Notes + ----- + In Python 3, we could use an easier function: + + .. code-block:: python + + def f(): + cell = None + + def _stub(value): + nonlocal cell + cell = value + + return _stub + + _cell_set_template_code = f() + + This function is _only_ a LOAD_FAST(arg); STORE_DEREF, but that is + invalid syntax on Python 2. If we use this function we also don't need + to do the weird freevars/cellvars swap below + """ + def inner(value): + lambda: cell # make ``cell`` a closure so that we get a STORE_DEREF + cell = value + + co = inner.__code__ + + # NOTE: we are marking the cell variable as a free variable intentionally + # so that we simulate an inner function instead of the outer function. This + # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. + if not PY3: + return types.CodeType( + co.co_argcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + else: + return types.CodeType( + co.co_argcount, + co.co_kwonlyargcount, + co.co_nlocals, + co.co_stacksize, + co.co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_cellvars, # this is the trickery + (), + ) + + +_cell_set_template_code = _make_cell_set_template_code() + + +def cell_set(cell, value): + """Set the value of a closure cell. + """ + return types.FunctionType( + _cell_set_template_code, + {}, + '_cell_set_inner', + (), + (cell,), + )(value) + + #relevant opcodes STORE_GLOBAL = opcode.opmap['STORE_GLOBAL'] DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL'] @@ -150,16 +236,6 @@ def dump(self, obj): if 'recursion' in e.args[0]: msg = """Could not pickle object as excessively deep recursion required.""" raise pickle.PicklingError(msg) - except pickle.PickleError: - raise - except Exception as e: - emsg = _exception_message(e) - if "'i' format requires" in emsg: - msg = "Object too large to serialize: %s" % emsg - else: - msg = "Could not serialize object: %s: %s" % (e.__class__.__name__, emsg) - print_exec(sys.stderr) - raise pickle.PicklingError(msg) def save_memoryview(self, obj): """Fallback to save_string""" @@ -186,8 +262,22 @@ def save_module(self, obj): """ Save a module as an import """ + mod_name = obj.__name__ + # If module is successfully found then it is not a dynamically created module + if hasattr(obj, '__file__'): + is_dynamic = False + else: + try: + _find_module(mod_name) + is_dynamic = False + except ImportError: + is_dynamic = True + self.modules.add(obj) - self.save_reduce(subimport, (obj.__name__,), obj=obj) + if is_dynamic: + self.save_reduce(dynamic_subimport, (obj.__name__, vars(obj)), obj=obj) + else: + self.save_reduce(subimport, (obj.__name__,), obj=obj) dispatch[types.ModuleType] = save_module def save_codeobject(self, obj): @@ -220,12 +310,7 @@ def save_function(self, obj, name=None): if name is None: name = obj.__name__ - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = None + modname = pickle.whichmodule(obj, name) # print('which gives %s %s %s' % (modname, obj, name)) try: themodule = sys.modules[modname] @@ -241,11 +326,32 @@ def save_function(self, obj, name=None): if getattr(themodule, name, None) is obj: return self.save_global(obj, name) + # a builtin_function_or_method which comes in as an attribute of some + # object (e.g., object.__new__, itertools.chain.from_iterable) will end + # up with modname "__main__" and so end up here. But these functions + # have no __code__ attribute in CPython, so the handling for + # user-defined functions below will fail. + # So we pickle them here using save_reduce; have to do it differently + # for different python versions. + if not hasattr(obj, '__code__'): + if PY3: + if sys.version_info < (3, 4): + raise pickle.PicklingError("Can't pickle %r" % obj) + else: + rv = obj.__reduce_ex__(self.proto) + else: + if hasattr(obj, '__self__'): + rv = (getattr, (obj.__self__, name)) + else: + raise pickle.PicklingError("Can't pickle %r" % obj) + return Pickler.save_reduce(self, obj=obj, *rv) + # if func is lambda, def'ed at prompt, is in main, or is nested, then # we'll pickle the actual function object rather than simply saving a # reference (as is done in default pickler), via save_function_tuple. - if islambda(obj) or obj.__code__.co_filename == '' or themodule is None: - #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) + if (islambda(obj) + or getattr(obj.__code__, 'co_filename', None) == '' + or themodule is None): self.save_function_tuple(obj) return else: @@ -267,6 +373,26 @@ def save_function(self, obj, name=None): self.memoize(obj) dispatch[types.FunctionType] = save_function + def _save_subimports(self, code, top_level_dependencies): + """ + Ensure de-pickler imports any package child-modules that + are needed by the function + """ + # check if any known dependency is an imported package + for x in top_level_dependencies: + if isinstance(x, types.ModuleType) and x.__package__: + # check if the package has any currently loaded sub-imports + prefix = x.__name__ + '.' + for name, module in sys.modules.items(): + if name.startswith(prefix): + # check whether the function can address the sub-module + tokens = set(name[len(prefix):].split('.')) + if not tokens - set(code.co_names): + # ensure unpickler executes this import + self.save(module) + # then discards the reference to it + self.write(pickle.POP) + def save_function_tuple(self, func): """ Pickles an actual func object. @@ -279,17 +405,31 @@ def save_function_tuple(self, func): safe, since this won't contain a ref to the func), and memoize it as soon as it's created. The other stuff can then be filled in later. """ + if is_tornado_coroutine(func): + self.save_reduce(_rebuild_tornado_coroutine, (func.__wrapped__,), + obj=func) + return + save = self.save write = self.write - code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func) + code, f_globals, defaults, closure_values, dct, base_globals = self.extract_func_data(func) save(_fill_function) # skeleton function updater write(pickle.MARK) # beginning of tuple that _fill_function expects + self._save_subimports( + code, + itertools.chain(f_globals.values(), closure_values or ()), + ) + # create a skeleton function object and memoize it save(_make_skel_func) - save((code, closure, base_globals)) + save(( + code, + len(closure_values) if closure_values is not None else -1, + base_globals, + )) write(pickle.REDUCE) self.memoize(func) @@ -297,7 +437,7 @@ def save_function_tuple(self, func): save(f_globals) save(defaults) save(dct) - save(func.__module__) + save(closure_values) write(pickle.TUPLE) write(pickle.REDUCE) # applies _fill_function on the tuple @@ -335,7 +475,7 @@ def extract_code_globals(cls, co): def extract_func_data(self, func): """ Turn the function into a tuple of data necessary to recreate it: - code, globals, defaults, closure, dict + code, globals, defaults, closure_values, dict """ code = func.__code__ @@ -352,7 +492,11 @@ def extract_func_data(self, func): defaults = func.__defaults__ # process closure - closure = [c.cell_contents for c in func.__closure__] if func.__closure__ else [] + closure = ( + list(map(_get_cell_contents, func.__closure__)) + if func.__closure__ is not None + else None + ) # save the dict dct = func.__dict__ @@ -363,7 +507,7 @@ def extract_func_data(self, func): return (code, f_globals, defaults, closure, dct, base_globals) def save_builtin_function(self, obj): - if obj.__module__ is "__builtin__": + if obj.__module__ == "__builtin__": return self.save_global(obj) return self.save_function(obj) dispatch[types.BuiltinFunctionType] = save_builtin_function @@ -378,12 +522,7 @@ def save_global(self, obj, name=None, pack=struct.pack): modname = getattr(obj, "__module__", None) if modname is None: - try: - # whichmodule() could fail, see - # https://bitbucket.org/gutworth/six/issues/63/importing-six-breaks-pickling - modname = pickle.whichmodule(obj, name) - except Exception: - modname = '__main__' + modname = pickle.whichmodule(obj, name) if modname == '__main__': themodule = None @@ -408,31 +547,7 @@ def save_global(self, obj, name=None, pack=struct.pack): if new_override: d['__new__'] = obj.__new__ - # workaround for namedtuple (hijacked by PySpark) - if getattr(obj, '_is_namedtuple_', False): - self.save_reduce(_load_namedtuple, (obj.__name__, obj._fields)) - return - - self.save(_load_class) - self.save_reduce(typ, (obj.__name__, obj.__bases__, {"__doc__": obj.__doc__}), obj=obj) - d.pop('__doc__', None) - # handle property and staticmethod - dd = {} - for k, v in d.items(): - if isinstance(v, property): - k = ('property', k) - v = (v.fget, v.fset, v.fdel, v.__doc__) - elif isinstance(v, staticmethod) and hasattr(v, '__func__'): - k = ('staticmethod', k) - v = v.__func__ - elif isinstance(v, classmethod) and hasattr(v, '__func__'): - k = ('classmethod', k) - v = v.__func__ - dd[k] = v - self.save(dd) - self.write(pickle.TUPLE2) - self.write(pickle.REDUCE) - + self.save_reduce(typ, (obj.__name__, obj.__bases__, d), obj=obj) else: raise pickle.PicklingError("Can't pickle %r" % obj) @@ -441,11 +556,14 @@ def save_global(self, obj, name=None, pack=struct.pack): def save_instancemethod(self, obj): # Memoization rarely is ever useful due to python bounding - if PY3: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + if obj.__self__ is None: + self.save_reduce(getattr, (obj.im_class, obj.__name__)) else: - self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), - obj=obj) + if PY3: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) + else: + self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), + obj=obj) dispatch[types.MethodType] = save_instancemethod def save_inst(self, obj): @@ -453,6 +571,12 @@ def save_inst(self, obj): Supports __transient__""" cls = obj.__class__ + # Try the dispatch table (pickle module doesn't do it) + f = self.dispatch.get(cls) + if f: + f(self, obj) # Call unbound method with explicit self + return + memo = self.memo write = self.write save = self.save @@ -503,6 +627,17 @@ def save_property(self, obj): self.save_reduce(property, (obj.fget, obj.fset, obj.fdel, obj.__doc__), obj=obj) dispatch[property] = save_property + def save_classmethod(self, obj): + try: + orig_func = obj.__func__ + except AttributeError: # Python 2.6 + orig_func = obj.__get__(None, object) + if isinstance(obj, classmethod): + orig_func = orig_func.__func__ # Unbind + self.save_reduce(type(obj), (orig_func,), obj=obj) + dispatch[classmethod] = save_classmethod + dispatch[staticmethod] = save_classmethod + def save_itemgetter(self, obj): """itemgetter serializer (needed for namedtuple support)""" class Dummy: @@ -623,72 +758,75 @@ def save_file(self, obj): return self.save_reduce(getattr, (sys,'stderr'), obj=obj) if obj is sys.stdin: raise pickle.PicklingError("Cannot pickle standard input") - if hasattr(obj, 'isatty') and obj.isatty(): + if obj.closed: + raise pickle.PicklingError("Cannot pickle closed files") + if hasattr(obj, 'isatty') and obj.isatty(): raise pickle.PicklingError("Cannot pickle files that map to tty objects") - if 'r' not in obj.mode: - raise pickle.PicklingError("Cannot pickle files that are not opened for reading") + if 'r' not in obj.mode and '+' not in obj.mode: + raise pickle.PicklingError("Cannot pickle files that are not opened for reading: %s" % obj.mode) + name = obj.name - try: - fsize = os.stat(name).st_size - except OSError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be stat" % name) - if obj.closed: - #create an empty closed string io - retval = pystringIO.StringIO("") - retval.close() - elif not fsize: #empty file - retval = pystringIO.StringIO("") - try: - tmpfile = file(name) - tst = tmpfile.read(1) - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - tmpfile.close() - if tst != '': - raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) - else: - try: - tmpfile = file(name) - contents = tmpfile.read() - tmpfile.close() - except IOError: - raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) - retval = pystringIO.StringIO(contents) + retval = pystringIO.StringIO() + + try: + # Read the whole file curloc = obj.tell() - retval.seek(curloc) + obj.seek(0) + contents = obj.read() + obj.seek(curloc) + except IOError: + raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) + retval.write(contents) + retval.seek(curloc) retval.name = name self.save(retval) self.memoize(obj) + def save_ellipsis(self, obj): + self.save_reduce(_gen_ellipsis, ()) + + def save_not_implemented(self, obj): + self.save_reduce(_gen_not_implemented, ()) + if PY3: dispatch[io.TextIOWrapper] = save_file else: dispatch[file] = save_file - """Special functions for Add-on libraries""" - - def inject_numpy(self): - numpy = sys.modules.get('numpy') - if not numpy or not hasattr(numpy, 'ufunc'): - return - self.dispatch[numpy.ufunc] = self.__class__.save_ufunc - - def save_ufunc(self, obj): - """Hack function for saving numpy ufunc objects""" - name = obj.__name__ - numpy_tst_mods = ['numpy', 'scipy.special'] - for tst_mod_name in numpy_tst_mods: - tst_mod = sys.modules.get(tst_mod_name, None) - if tst_mod and name in tst_mod.__dict__: - return self.save_reduce(_getobject, (tst_mod_name, name)) - raise pickle.PicklingError('cannot save %s. Cannot resolve what module it is defined in' - % str(obj)) + dispatch[type(Ellipsis)] = save_ellipsis + dispatch[type(NotImplemented)] = save_not_implemented + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" - self.inject_numpy() + pass + + def save_logger(self, obj): + self.save_reduce(logging.getLogger, (obj.name,), obj=obj) + + dispatch[logging.Logger] = save_logger + + +# Tornado support + +def is_tornado_coroutine(func): + """ + Return whether *func* is a Tornado coroutine function. + Running coroutines are not supported. + """ + if 'tornado.gen' not in sys.modules: + return False + gen = sys.modules['tornado.gen'] + if not hasattr(gen, "is_coroutine_function"): + # Tornado version is too old + return False + return gen.is_coroutine_function(func) + +def _rebuild_tornado_coroutine(func): + from tornado import gen + return gen.coroutine(func) # Shorthands for legacy support @@ -705,6 +843,10 @@ def dumps(obj, protocol=2): return file.getvalue() +# including pickles unloading functions in this namespace +load = pickle.load +loads = pickle.loads + #hack for __import__ not working as desired def subimport(name): @@ -712,6 +854,12 @@ def subimport(name): return sys.modules[name] +def dynamic_subimport(name, vars): + mod = imp.new_module(name) + mod.__dict__.update(vars) + sys.modules[name] = mod + return mod + # restores function attributes def _restore_attr(obj, attr): for key, val in attr.items(): @@ -755,66 +903,102 @@ def _genpartial(func, args, kwds): kwds = {} return partial(func, *args, **kwds) +def _gen_ellipsis(): + return Ellipsis + +def _gen_not_implemented(): + return NotImplemented + -def _fill_function(func, globals, defaults, dict, module): +def _get_cell_contents(cell): + try: + return cell.cell_contents + except ValueError: + # sentinel used by ``_fill_function`` which will leave the cell empty + return _empty_cell_value + + +def instance(cls): + """Create a new instance of a class. + + Parameters + ---------- + cls : type + The class to create an instance of. + + Returns + ------- + instance : cls + A new instance of ``cls``. + """ + return cls() + + +@instance +class _empty_cell_value(object): + """sentinel for empty closures + """ + @classmethod + def __reduce__(cls): + return cls.__name__ + + +def _fill_function(func, globals, defaults, dict, closure_values): """ Fills in the rest of function data into the skeleton function object that were created via _make_skel_func(). - """ + """ func.__globals__.update(globals) func.__defaults__ = defaults func.__dict__ = dict - func.__module__ = module - return func + cells = func.__closure__ + if cells is not None: + for cell, value in zip(cells, closure_values): + if value is not _empty_cell_value: + cell_set(cell, value) + return func -def _make_cell(value): - return (lambda: value).__closure__[0] +def _make_empty_cell(): + if False: + # trick the compiler into creating an empty cell in our lambda + cell = None + raise AssertionError('this route should not be executed') -def _reconstruct_closure(values): - return tuple([_make_cell(v) for v in values]) + return (lambda: cell).__closure__[0] -def _make_skel_func(code, closures, base_globals = None): +def _make_skel_func(code, cell_count, base_globals=None): """ Creates a skeleton function object that contains just the provided code and the correct number of cells in func_closure. All other func attributes (e.g. func_globals) are empty. """ - closure = _reconstruct_closure(closures) if closures else None - if base_globals is None: base_globals = {} base_globals['__builtins__'] = __builtins__ - return types.FunctionType(code, base_globals, - None, None, closure) + closure = ( + tuple(_make_empty_cell() for _ in range(cell_count)) + if cell_count >= 0 else + None + ) + return types.FunctionType(code, base_globals, None, None, closure) -def _load_class(cls, d): - """ - Loads additional properties into class `cls`. - """ - for k, v in d.items(): - if isinstance(k, tuple): - typ, k = k - if typ == 'property': - v = property(*v) - elif typ == 'staticmethod': - v = staticmethod(v) - elif typ == 'classmethod': - v = classmethod(v) - setattr(cls, k, v) - return cls - - -def _load_namedtuple(name, fields): +def _find_module(mod_name): """ - Loads a class generated by namedtuple + Iterate over each part instead of calling imp.find_module directly. + This function is able to find submodules (e.g. sickit.tree) """ - from collections import namedtuple - return namedtuple(name, fields) - + path = None + for part in mod_name.split('.'): + if path is not None: + path = [path] + file, path, description = imp.find_module(part, path) + if file is not None: + file.close() + return path, description """Constructors for 3rd party libraries Note: These can never be renamed due to client compatibility issues""" @@ -822,3 +1006,18 @@ def _load_namedtuple(name, fields): def _getobject(modname, attribute): mod = __import__(modname, fromlist=[attribute]) return mod.__dict__[attribute] + + +""" Use copy_reg to extend global pickle definitions """ + +if sys.version_info < (3, 4): + method_descriptor = type(str.upper) + + def _reduce_method_descriptor(obj): + return (getattr, (obj.__objclass__, obj.__name__)) + + try: + import copy_reg as copyreg + except ImportError: + import copyreg + copyreg.pickle(method_descriptor, _reduce_method_descriptor)