From 290ff2ccfa20c6cd6578744cf71f06fe2272c5ed Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 4 Feb 2019 16:10:20 -0800 Subject: [PATCH 01/10] fix pickling dataclasses --- cloudpickle/cloudpickle.py | 7 +++++++ tests/cloudpickle_test.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 4530c137f..42b26df24 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -266,6 +266,11 @@ def _walk_global_ops(code): yield op, instr.arg +def _normalize_dataclass_dict(clsdict): + dataclass_fields = clsdict.get('__dataclass_fields__', []) + for key in dataclass_fields: + dataclass_fields[key].metadata = dict(dataclass_fields[key].metadata) + class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -543,6 +548,8 @@ def save_dynamic_class(self, obj): tp = type(obj) self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) + _normalize_dataclass_dict(clsdict) + # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. save(clsdict) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 6e922f10f..c17339fda 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1344,6 +1344,15 @@ def __init__(self): with pytest.raises(AttributeError): obj.non_registered_attribute = 1 + def test_dataclass(self): + from dataclasses import dataclass + + @dataclass + class DataClass: + field: int + + pickle_depickle(DataClass, protocol=self.protocol) + class Protocol2CloudPickleTest(CloudPickleTest): From f31a9355143efdcb1f943861cc73262f1a3d6965 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 4 Feb 2019 16:23:03 -0800 Subject: [PATCH 02/10] fix --- tests/cloudpickle_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index c17339fda..435bb0dda 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1344,12 +1344,12 @@ def __init__(self): with pytest.raises(AttributeError): obj.non_registered_attribute = 1 + @pytest.mark.skipif(sys.version_info < (3, 7), + reason="dataclasses not implemented before 3.7") def test_dataclass(self): - from dataclasses import dataclass + from dataclasses import make_dataclass - @dataclass - class DataClass: - field: int + DataClass = make_dataclass('DataClass', [('x', int)]) pickle_depickle(DataClass, protocol=self.protocol) From 436c06f1836bd918f96110ed6e8aa1aaf385d035 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 4 Feb 2019 16:35:15 -0800 Subject: [PATCH 03/10] add doc --- cloudpickle/cloudpickle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 42b26df24..22a82fe59 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -267,6 +267,10 @@ def _walk_global_ops(code): def _normalize_dataclass_dict(clsdict): + """ + Replace mappingproxy fields (non pickleable) in dataclass + __dict__ with dict. + """ dataclass_fields = clsdict.get('__dataclass_fields__', []) for key in dataclass_fields: dataclass_fields[key].metadata = dict(dataclass_fields[key].metadata) From eb4cbbb1e965dbe2272a274d9e4d25c13e94808d Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 4 Feb 2019 16:36:09 -0800 Subject: [PATCH 04/10] fix --- cloudpickle/cloudpickle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 22a82fe59..25621ab8b 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -271,7 +271,7 @@ def _normalize_dataclass_dict(clsdict): Replace mappingproxy fields (non pickleable) in dataclass __dict__ with dict. """ - dataclass_fields = clsdict.get('__dataclass_fields__', []) + dataclass_fields = clsdict.get('__dataclass_fields__', {}) for key in dataclass_fields: dataclass_fields[key].metadata = dict(dataclass_fields[key].metadata) From 288fd09c48c1879646efce036c5ac97e512209d9 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 5 Feb 2019 13:57:18 +0100 Subject: [PATCH 05/10] Make the test introspect the presence of the dataclasses module instead of relying on Python version --- tests/cloudpickle_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 435bb0dda..f3a03cbe8 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1344,13 +1344,10 @@ def __init__(self): with pytest.raises(AttributeError): obj.non_registered_attribute = 1 - @pytest.mark.skipif(sys.version_info < (3, 7), - reason="dataclasses not implemented before 3.7") def test_dataclass(self): - from dataclasses import make_dataclass - - DataClass = make_dataclass('DataClass', [('x', int)]) + dataclasses = pytest.importorskip("dataclasses") + DataClass = dataclasses.make_dataclass('DataClass', [('x', int)]) pickle_depickle(DataClass, protocol=self.protocol) From 90befb48af0e9fc868173ec64e57c7c30ef4f184 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 5 Feb 2019 13:58:04 +0100 Subject: [PATCH 06/10] Use a reducer for instances of MappingProxyTypes --- cloudpickle/cloudpickle.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 25621ab8b..ec84ce313 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -266,15 +266,6 @@ def _walk_global_ops(code): yield op, instr.arg -def _normalize_dataclass_dict(clsdict): - """ - Replace mappingproxy fields (non pickleable) in dataclass - __dict__ with dict. - """ - dataclass_fields = clsdict.get('__dataclass_fields__', {}) - for key in dataclass_fields: - dataclass_fields[key].metadata = dict(dataclass_fields[key].metadata) - class CloudPickler(Pickler): dispatch = Pickler.dispatch.copy() @@ -552,8 +543,6 @@ def save_dynamic_class(self, obj): tp = type(obj) self.save_reduce(tp, (obj.__name__, obj.__bases__, type_kwargs), obj=obj) - _normalize_dataclass_dict(clsdict) - # Now save the rest of obj's __dict__. Any references to obj # encountered while saving will point to the skeleton class. save(clsdict) @@ -915,6 +904,12 @@ def save_root_logger(self, obj): dispatch[logging.RootLogger] = save_root_logger + if hasattr(types, "MappingProxyType"): + def _save_mappingproxy(self, obj): + self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj) + + dispatch[types.MappingProxyType] = _save_mappingproxy + """Special functions for Add-on libraries""" def inject_addons(self): """Plug in system. Register additional pickling functions if modules already loaded""" From e10b58b54638a7bef70ad3a0ccc06ed080c1476a Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 5 Feb 2019 14:36:21 +0100 Subject: [PATCH 07/10] Improve tests --- tests/cloudpickle_test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index f3a03cbe8..08597e757 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1344,11 +1344,20 @@ def __init__(self): with pytest.raises(AttributeError): obj.non_registered_attribute = 1 + @unittest.skipIf(not hasattr(types, "MappingProxyType"), + "Old versions of Python do not have this type.") + def test_mappingproxy(self): + mp = types.MappingProxyType({"some_key": "some value"}) + assert mp == pickle_depickle(mp, protocol=self.protocol) + def test_dataclass(self): dataclasses = pytest.importorskip("dataclasses") DataClass = dataclasses.make_dataclass('DataClass', [('x', int)]) + data = DataClass(x=42) + pickle_depickle(DataClass, protocol=self.protocol) + assert data.x == pickle_depickle(data).x == 42 class Protocol2CloudPickleTest(CloudPickleTest): From eb1cc6c4e708271835c380e759b4b77771939c82 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 5 Feb 2019 14:36:37 +0100 Subject: [PATCH 08/10] Add an entry to the changelog --- CHANGES.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGES.md b/CHANGES.md index 1e6252458..76d380915 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,10 @@ +0.8.0 +===== + +- Add support for pickling interactively defined dataclasses. + ([issue #245](https://github.com/cloudpipe/cloudpickle/pull/245)) + + 0.7.0 ===== From 236ad9482eef87b387af27ce131dd06da61ca30d Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 5 Feb 2019 14:58:29 +0100 Subject: [PATCH 09/10] Fix test to test with different protocol versions --- tests/cloudpickle_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index 08597e757..6091fb8bf 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -1357,7 +1357,7 @@ def test_dataclass(self): data = DataClass(x=42) pickle_depickle(DataClass, protocol=self.protocol) - assert data.x == pickle_depickle(data).x == 42 + assert data.x == pickle_depickle(data, protocol=self.protocol).x == 42 class Protocol2CloudPickleTest(CloudPickleTest): From 9ef232ca9fc7c0d15d286a909fbef315dfde1c84 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 6 Feb 2019 11:19:44 +0100 Subject: [PATCH 10/10] Locally disable branch-wise coverage measurements when conditional branches are system dependent --- cloudpickle/cloudpickle.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index ec84ce313..f94e4c109 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -63,7 +63,7 @@ DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL -if sys.version < '3': +if sys.version_info[0] < 3: # pragma: no branch from pickle import Pickler try: from cStringIO import StringIO @@ -128,7 +128,7 @@ def inner(value): # NOTE: we are marking the cell variable as a free variable intentionally # so that we simulate an inner function instead of the outer function. This # is what gives us the ``nonlocal`` behavior in a Python 2 compatible way. - if not PY3: + if not PY3: # pragma: no branch return types.CodeType( co.co_argcount, co.co_nlocals, @@ -229,14 +229,14 @@ def _factory(): } -if sys.version_info < (3, 4): +if sys.version_info < (3, 4): # pragma: no branch def _walk_global_ops(code): """ Yield (opcode, argument number) tuples for all global-referencing instructions in *code*. """ code = getattr(code, 'co_code', b'') - if not PY3: + if not PY3: # pragma: no branch code = map(ord, code) n = len(code) @@ -293,7 +293,7 @@ def save_memoryview(self, obj): dispatch[memoryview] = save_memoryview - if not PY3: + if not PY3: # pragma: no branch def save_buffer(self, obj): self.save(str(obj)) @@ -315,7 +315,7 @@ def save_codeobject(self, obj): """ Save a code object """ - if PY3: + if PY3: # pragma: no branch args = ( obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts, obj.co_names, obj.co_varnames, @@ -393,7 +393,7 @@ def save_function(self, obj, name=None): # So we pickle them here using save_reduce; have to do it differently # for different python versions. if not hasattr(obj, '__code__'): - if PY3: + if PY3: # pragma: no branch rv = obj.__reduce_ex__(self.proto) else: if hasattr(obj, '__self__'): @@ -730,7 +730,7 @@ def save_instancemethod(self, obj): if obj.__self__ is None: self.save_reduce(getattr, (obj.im_class, obj.__name__)) else: - if PY3: + if PY3: # pragma: no branch self.save_reduce(types.MethodType, (obj.__func__, obj.__self__), obj=obj) else: self.save_reduce(types.MethodType, (obj.__func__, obj.__self__, obj.__self__.__class__), @@ -783,7 +783,7 @@ def save_inst(self, obj): save(stuff) write(pickle.BUILD) - if not PY3: + if not PY3: # pragma: no branch dispatch[types.InstanceType] = save_inst def save_property(self, obj): @@ -883,7 +883,7 @@ def save_not_implemented(self, obj): try: # Python 2 dispatch[file] = save_file - except NameError: # Python 3 + except NameError: # Python 3 # pragma: no branch dispatch[io.TextIOWrapper] = save_file dispatch[type(Ellipsis)] = save_ellipsis @@ -904,11 +904,11 @@ def save_root_logger(self, obj): dispatch[logging.RootLogger] = save_root_logger - if hasattr(types, "MappingProxyType"): - def _save_mappingproxy(self, obj): + if hasattr(types, "MappingProxyType"): # pragma: no branch + def save_mappingproxy(self, obj): self.save_reduce(types.MappingProxyType, (dict(obj),), obj=obj) - dispatch[types.MappingProxyType] = _save_mappingproxy + dispatch[types.MappingProxyType] = save_mappingproxy """Special functions for Add-on libraries""" def inject_addons(self): @@ -1219,7 +1219,7 @@ def _getobject(modname, attribute): """ Use copy_reg to extend global pickle definitions """ -if sys.version_info < (3, 4): +if sys.version_info < (3, 4): # pragma: no branch method_descriptor = type(str.upper) def _reduce_method_descriptor(obj):