Skip to content

Commit

Permalink
session: deal with modules with unpickleable objects
Browse files Browse the repository at this point in the history
  • Loading branch information
leogama committed Jul 19, 2022
1 parent 2fdd31d commit 6b55755
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 29 deletions.
104 changes: 87 additions & 17 deletions dill/_dill.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
#XXX: get types from .objtypes ?
import builtins as __builtin__
from pickle import _Pickler as StockPickler, Unpickler as StockUnpickler
from pickle import BINPUT, DICT, EMPTY_DICT, LONG_BINPUT, MARK, PUT, SETITEM
from struct import pack
from _thread import LockType
from _thread import RLock as RLockType
#from io import IOBase
Expand Down Expand Up @@ -234,6 +236,9 @@ def __reduce_ex__(self, protocol):
#: Pickles the entire file (handle and contents), preserving mode and position.
FILE_FMODE = 2

# Exceptions commonly raised by unpickleable objects.
UNPICKLEABLE_ERRORS = (PicklingError, TypeError, NotImplementedError)

### Shorthands (modified from python2.5/lib/pickle.py)
def copy(obj, *args, **kwds):
"""
Expand Down Expand Up @@ -349,16 +354,18 @@ class Pickler(StockPickler):
def __init__(self, file, *args, **kwds):
settings = Pickler.settings
_byref = kwds.pop('byref', None)
#_strictio = kwds.pop('strictio', None)
_fmode = kwds.pop('fmode', None)
_recurse = kwds.pop('recurse', None)
#_refonfail = kwds.pop('refonfail', None)
#_strictio = kwds.pop('strictio', None)
StockPickler.__init__(self, file, *args, **kwds)
self._main = _main_module
self._diff_cache = {}
self._byref = settings['byref'] if _byref is None else _byref
self._strictio = False #_strictio
self._fmode = settings['fmode'] if _fmode is None else _fmode
self._recurse = settings['recurse'] if _recurse is None else _recurse
self._refonfail = False #settings['dump_module']['refonfail'] if _refonfail is None else _refonfail
self._strictio = False #_strictio
self._postproc = OrderedDict()
self._file = file # for the logger

Expand Down Expand Up @@ -395,7 +402,7 @@ def save_numpy_dtype(pickler, obj):
if NumpyArrayType and ndarraysubclassinstance(obj):
@register(type(obj))
def save_numpy_array(pickler, obj):
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype)
logger.trace(pickler, "Nu: (%s, %s)", obj.shape, obj.dtype, obj=obj)
npdict = getattr(obj, '__dict__', None)
f, args, state = obj.__reduce__()
pickler.save_reduce(_create_array, (f,args,state,npdict), obj=obj)
Expand All @@ -407,9 +414,68 @@ def save_numpy_array(pickler, obj):
raise PicklingError(msg)
logger.trace_setup(self)
StockPickler.dump(self, obj)

dump.__doc__ = StockPickler.dump.__doc__

def save(self, obj, save_persistent_id=True, *, name=None):
    """Save ``obj``, falling back to saving it by reference if pickling fails.

    When ``self._refonfail`` is false this simply defers to the parent
    class's ``save()``.  Otherwise the stream position and the memo size are
    recorded first; if pickling raises one of ``UNPICKLEABLE_ERRORS`` (or an
    ``AttributeError`` about a missing ``__name__``), everything written and
    memoized since that point is discarded and ``save_global(obj, name)`` is
    attempted instead.

    Args:
        obj: the object to pickle.
        save_persistent_id: forwarded unchanged to the parent ``save()``.
        name: forwarded to ``save_global()`` when the by-reference fallback
            is used.

    Raises:
        AttributeError: re-raised untouched when the failure is exactly an
            ``AttributeError`` whose message does not mention a missing
            ``__name__`` attribute.
        AttributeError, PicklingError: raised by the ``save_global()``
            fallback, chained to the error that triggered the fallback.
    """
    if not self._refonfail:
        super().save(obj, save_persistent_id)
        return
    if self.framer.current_frame:
        # protocol >= 4
        self.framer.commit_frame()
        stream = self.framer.current_frame
    else:
        stream = self._file
    # Remember where we are so that a failed attempt can be undone.
    position = stream.tell()
    memo_size = len(self.memo)
    try:
        super().save(obj, save_persistent_id)
    except UNPICKLEABLE_ERRORS + (AttributeError,) as error_stack:
        # AttributeError may happen in save_global() call for child object.
        # Match on str(error) rather than args[0]: args may be empty or hold
        # a non-string value, and indexing/searching it would then raise a
        # new exception that hides the original failure.
        if (type(error_stack) is AttributeError
                and "no attribute '__name__'" not in str(error_stack)):
            raise
        # roll back the stream
        stream.seek(position)
        stream.truncate()
        # roll back memo
        for _ in range(len(self.memo) - memo_size):
            self.memo.popitem()  # LIFO order is guaranteed since 3.7
        try:
            self.save_global(obj, name)
        except (AttributeError, PicklingError) as error:
            if getattr(self, '_trace_stack', None) and id(obj) == self._trace_stack[-1]:
                # roll back trace state
                self._trace_stack.pop()
                self._size_stack.pop()
            raise error from error_stack
        logger.trace(self, "# X: fallback to save_global: <%s object at %#012x>",
                     type(obj).__name__, id(obj), obj=obj)

def _save_module_dict(self, obj):
    """
    Pickle a module's namespace dict, passing each entry's key as the
    fallback name for values that carry no name of their own.

    Using the object's name in the module namespace is a last resort to
    save it by reference when pickling fails.  When ``self._refonfail`` is
    false, this defers to the parent class's ``save_dict()``.
    Modified from Pickler.save_dict() and Pickler._batch_setitems().
    """
    if not self._refonfail:
        super().save_dict(obj)
        return

    # proto 0 can't use EMPTY_DICT, so it gets MARK + DICT instead
    self.write(EMPTY_DICT if self.bin else MARK + DICT)
    self.memoize(obj)

    for key, value in obj.items():
        self.save(key)
        has_own_name = (hasattr(value, '__name__')
                        or hasattr(value, '__qualname__'))
        if has_own_name:
            self.save(value)
        else:
            # No intrinsic name: offer the namespace key to save().
            self.save(value, name=key)
        self.write(SETITEM)

class Unpickler(StockUnpickler):
"""python's Unpickler extended to interpreter sessions and more types"""
from .settings import settings
Expand Down Expand Up @@ -1173,26 +1239,30 @@ def _repr_dict(obj):

@register(dict)
def save_module_dict(pickler, obj):
if is_dill(pickler, child=False) and obj == pickler._main.__dict__ and \
pickler_is_dill = is_dill(pickler, child=False)
if pickler_is_dill and obj == pickler._main.__dict__ and \
not (pickler._session and pickler._first_pass):
logger.trace(pickler, "D1: %s", _repr_dict(obj)) # obj
logger.trace(pickler, "D1: %s", _repr_dict(obj), obj=obj)
pickler.write(bytes('c__builtin__\n__main__\n', 'UTF-8'))
logger.trace(pickler, "# D1")
elif (not is_dill(pickler, child=False)) and (obj == _main_module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj)) # obj
elif (not pickler_is_dill) and (obj == _main_module.__dict__):
logger.trace(pickler, "D3: %s", _repr_dict(obj), obj=obj)
pickler.write(bytes('c__main__\n__dict__\n', 'UTF-8')) #XXX: works in general?
logger.trace(pickler, "# D3")
elif '__name__' in obj and obj != _main_module.__dict__ \
and type(obj['__name__']) is str \
and obj is getattr(_import_module(obj['__name__'],True), '__dict__', None):
logger.trace(pickler, "D4: %s", _repr_dict(obj)) # obj
logger.trace(pickler, "D4: %s", _repr_dict(obj), obj=obj)
pickler.write(bytes('c%s\n__dict__\n' % obj['__name__'], 'UTF-8'))
logger.trace(pickler, "# D4")
elif pickler_is_dill and pickler._session and pickler._first_pass:
# we only care about session the first pass thru
pickler._first_pass = False
logger.trace(pickler, "D5: %s", _repr_dict(obj), obj=obj)
pickler._save_module_dict(obj)
logger.trace(pickler, "# D5")
else:
logger.trace(pickler, "D2: %s", _repr_dict(obj)) # obj
if is_dill(pickler, child=False) and pickler._session:
# we only care about session the first pass thru
pickler._first_pass = False
logger.trace(pickler, "D2: %s", _repr_dict(obj), obj=obj)
StockPickler.save_dict(pickler, obj)
logger.trace(pickler, "# D2")
return
Expand Down Expand Up @@ -1491,15 +1561,15 @@ def save_cell(pickler, obj):
if MAPPING_PROXY_TRICK:
@register(DictProxyType)
def save_dictproxy(pickler, obj):
logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj
logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj)
mapping = obj | _dictproxy_helper_instance
pickler.save_reduce(DictProxyType, (mapping,), obj=obj)
logger.trace(pickler, "# Mp")
return
else:
@register(DictProxyType)
def save_dictproxy(pickler, obj):
logger.trace(pickler, "Mp: %s", _repr_dict(obj)) # obj
logger.trace(pickler, "Mp: %s", _repr_dict(obj), obj=obj)
pickler.save_reduce(DictProxyType, (obj.copy(),), obj=obj)
logger.trace(pickler, "# Mp")
return
Expand Down Expand Up @@ -1575,7 +1645,7 @@ def save_weakproxy(pickler, obj):
logger.trace(pickler, "%s: %s", _t, obj)
except ReferenceError:
_t = "R3"
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1])
logger.trace(pickler, "%s: %s", _t, sys.exc_info()[1], obj=obj)
#callable = bool(getattr(refobj, '__call__', None))
if type(obj) is CallableProxyType: callable = True
else: callable = False
Expand Down Expand Up @@ -1914,7 +1984,7 @@ def pickles(obj,exact=False,safe=False,**kwds):
"""
if safe: exceptions = (Exception,) # RuntimeError, ValueError
else:
exceptions = (TypeError, AssertionError, NotImplementedError, PicklingError, UnpicklingError)
exceptions = UNPICKLEABLE_ERRORS + (AssertionError, UnpicklingError)
try:
pik = copy(obj, **kwds)
#FIXME: should check types match first, then check content if "exact"
Expand Down
22 changes: 12 additions & 10 deletions dill/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,18 +129,22 @@ def trace_setup(self, pickler):
if not dill._dill.is_dill(pickler, child=False):
return
if self.isEnabledFor(logging.INFO):
pickler._trace_depth = 1
pickler._trace_stack = []
pickler._size_stack = []
else:
pickler._trace_depth = None
def trace(self, pickler, msg, *args, **kwargs):
if not hasattr(pickler, '_trace_depth'):
pickler._trace_stack = None
def trace(self, pickler, msg, *args, obj=None, **kwargs):
if not hasattr(pickler, '_trace_stack'):
logger.info(msg, *args, **kwargs)
return
if pickler._trace_depth is None:
if pickler._trace_stack is None:
return
extra = kwargs.get('extra', {})
pushed_obj = msg.startswith('#')
if not pushed_obj:
if obj is None:
obj = args[-1]
pickler._trace_stack.append(id(obj))
size = None
try:
# Streams are not required to be tellable.
Expand All @@ -159,13 +163,11 @@ def trace(self, pickler, msg, *args, **kwargs):
else:
size -= pickler._size_stack.pop()
extra['size'] = size
if pushed_obj:
pickler._trace_depth -= 1
extra['depth'] = pickler._trace_depth
extra['depth'] = len(pickler._trace_stack)
kwargs['extra'] = extra
self.info(msg, *args, **kwargs)
if not pushed_obj:
pickler._trace_depth += 1
if pushed_obj:
pickler._trace_stack.pop()

class TraceFormatter(logging.Formatter):
"""
Expand Down
22 changes: 21 additions & 1 deletion dill/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,10 @@ def dump_module(
filename = str(TEMPDIR/'session.pkl'),
module: Union[ModuleType, str] = None,
refimported: bool = False,
refonfail: bool = False,
**kwds
) -> None:
"""Pickle the current state of :py:mod:`__main__` or another module to a file.
R"""Pickle the current state of :py:mod:`__main__` or another module to a file.
Save the contents of :py:mod:`__main__` (e.g. from an interactive
interpreter session), an imported module, or a module-type object (e.g.
Expand All @@ -202,6 +203,10 @@ def dump_module(
similar but independent from ``dill.settings[`byref`]``, as
``refimported`` refers to virtually all imported objects, while
``byref`` only affects select objects.
refonfail: if `True`, objects that fail to be pickled by value are saved
by reference instead. If that also fails, saving their parent
objects by reference is attempted recursively. In the worst
case scenario, the module itself may be saved by reference.
**kwds: extra keyword arguments passed to :py:class:`Pickler()`.
Raises:
Expand Down Expand Up @@ -232,6 +237,15 @@ def dump_module(
>>> foo.sin = math.sin
>>> dill.dump_module('foo_session.pkl', module=foo, refimported=True)
- Save the state of a module with unpickleable objects:
>>> import dill
>>> import os
>>> os.altsep = '\\'
>>> dill.dump_module('os_session.pkl', module=os)
PicklingError: ...
>>> dill.dump_module('os_session.pkl', module=os, refonfail=True)
- Restore the state of the saved modules:
>>> import dill
Expand All @@ -244,6 +258,9 @@ def dump_module(
>>> foo = dill.load_module('foo_session.pkl')
>>> [foo.sin(x) for x in foo.values]
[0.8414709848078965, 0.9092974268256817, 0.1411200080598672]
>>> os = dill.load_module('os_session.pkl')
>>> print(os.altsep.join('path'))
p\a\t\h
*Changed in version 0.3.6:* Function ``dump_session()`` was renamed to
``dump_module()``. Parameters ``main`` and ``byref`` were renamed to
Expand All @@ -266,6 +283,8 @@ def dump_module(

from .settings import settings
protocol = settings['protocol']
if refimported is None: refimported = settings['dump_module']['refimported']
if refonfail is None: refonfail = settings['dump_module']['refonfail']
main = module
if main is None:
main = _main_module
Expand All @@ -283,6 +302,7 @@ def dump_module(
pickler._main = main #FIXME: dill.settings are disabled
pickler._byref = False # disable pickling by name reference
pickler._recurse = False # disable pickling recursion for globals
pickler._refonfail = refonfail
pickler._session = True # is best indicator of when pickling a session
pickler._first_pass = True
pickler.dump(main)
Expand Down
4 changes: 4 additions & 0 deletions dill/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
'fmode' : 0, #HANDLE_FMODE
'recurse' : False,
'ignore' : False,
'dump_module' : {
'refimported': False,
'refonfail' : False,
},
}

del DEFAULT_PROTOCOL
Expand Down
2 changes: 1 addition & 1 deletion dill/tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def test_runtime_module():
runtime = ModuleType(modname)
runtime.x = 42

mod = dill._dill._stash_modules(runtime)
mod = dill.session._stash_modules(runtime)
if mod is not runtime:
print("There are objects to save by referenece that shouldn't be:",
mod.__dill_imported, mod.__dill_imported_as, mod.__dill_imported_top_level,
Expand Down

0 comments on commit 6b55755

Please sign in to comment.