diff --git a/CHANGES.md b/CHANGES.md index c4efa23dd..6fb058b8a 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -11,6 +11,10 @@ and expand the support for pickling `TypeVar` instances (dynamic or non-dynamic) to Python 3.5-3.6 ([PR #350](https://github.com/cloudpipe/cloudpickle/pull/350)) +- Add support for pickling dynamic classes subclassing `typing.Generic` + instances on Python 3.7+ + ([PR #351](https://github.com/cloudpipe/cloudpickle/pull/351)) + 1.3.0 ===== diff --git a/cloudpickle/cloudpickle.py b/cloudpickle/cloudpickle.py index 5cde86cfa..537a00cc5 100644 --- a/cloudpickle/cloudpickle.py +++ b/cloudpickle/cloudpickle.py @@ -446,7 +446,7 @@ def dump(self, obj): raise def save_typevar(self, obj): - self.save_reduce(*_typevar_reduce(obj)) + self.save_reduce(*_typevar_reduce(obj), obj=obj) dispatch[typing.TypeVar] = save_typevar @@ -645,7 +645,7 @@ def save_dynamic_class(self, obj): # "Regular" class definition: tp = type(obj) self.save_reduce(_make_skeleton_class, - (tp, obj.__name__, obj.__bases__, type_kwargs, + (tp, obj.__name__, _get_bases(obj), type_kwargs, _ensure_tracking(obj), None), obj=obj) @@ -1163,7 +1163,10 @@ class id will also reuse this class definition. The "extra" variable is meant to be a dict (or None) that can be used for forward compatibility shall the need arise. """ - skeleton_class = type_constructor(name, bases, type_kwargs) + skeleton_class = types.new_class( + name, bases, {'metaclass': type_constructor}, + lambda ns: ns.update(type_kwargs) + ) return _lookup_class_or_track(class_tracker_id, skeleton_class) @@ -1268,3 +1271,13 @@ def _typevar_reduce(obj): if module_and_name is None: return (_make_typevar, _decompose_typevar(obj)) return (getattr, module_and_name) + + +def _get_bases(typ): + if hasattr(typ, '__orig_bases__'): + # For generic types (see PEP 560) + bases_attr = '__orig_bases__' + else: + # For regular class objects + bases_attr = '__bases__' + return getattr(typ, bases_attr) diff --git a/cloudpickle/cloudpickle_fast.py b/cloudpickle/cloudpickle_fast.py index 47e70de94..49453d5b1 100644 --- a/cloudpickle/cloudpickle_fast.py +++ b/cloudpickle/cloudpickle_fast.py @@ -28,7 +28,7 @@ _is_dynamic, _extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL, _find_imported_submodules, _get_cell_contents, _is_importable_by_name, _builtin_type, Enum, _ensure_tracking, _make_skeleton_class, _make_skeleton_enum, - _extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, + _extract_class_dict, dynamic_subimport, subimport, _typevar_reduce, _get_bases, ) load, loads = _pickle.load, _pickle.loads @@ -76,7 +76,7 @@ def _class_getnewargs(obj): if isinstance(__dict__, property): type_kwargs['__dict__'] = __dict__ - return (type(obj), obj.__name__, obj.__bases__, type_kwargs, + return (type(obj), obj.__name__, _get_bases(obj), type_kwargs, _ensure_tracking(obj), None) diff --git a/tests/cloudpickle_test.py b/tests/cloudpickle_test.py index aa64f5740..eab45763b 100644 --- a/tests/cloudpickle_test.py +++ b/tests/cloudpickle_test.py @@ -50,6 +50,7 @@ from .testutils import subprocess_pickle_echo from .testutils import assert_run_python_script +from .testutils import subprocess_worker _TEST_GLOBAL_VARIABLE = "default_value" @@ -2121,6 +2122,12 @@ def test_pickle_dynamic_typevar(self): for attr in attr_list: assert getattr(T, attr) == getattr(depickled_T, attr) + def test_pickle_dynamic_typevar_memoization(self): + T = typing.TypeVar('T') + depickled_T1, depickled_T2 = pickle_depickle((T, T), + protocol=self.protocol) + assert depickled_T1 is depickled_T2 + def test_pickle_importable_typevar(self): from .mypkg import T T1 = pickle_depickle(T, protocol=self.protocol) @@ -2130,6 +2137,61 @@ def test_pickle_importable_typevar(self): from typing import AnyStr assert AnyStr is pickle_depickle(AnyStr, protocol=self.protocol) + @unittest.skipIf(sys.version_info < (3, 7), + "Pickling generics not supported below py37") + def test_generic_type(self): + T = typing.TypeVar('T') + + class C(typing.Generic[T]): + pass + + assert pickle_depickle(C, protocol=self.protocol) is C + assert pickle_depickle(C[int], protocol=self.protocol) is C[int] + + with subprocess_worker(protocol=self.protocol) as worker: + + def check_generic(generic, origin, type_value): + assert generic.__origin__ is origin + assert len(generic.__args__) == 1 + assert generic.__args__[0] is type_value + + assert len(origin.__orig_bases__) == 1 + ob = origin.__orig_bases__[0] + assert ob.__origin__ is typing.Generic + assert len(ob.__parameters__) == 1 + + return "ok" + + assert check_generic(C[int], C, int) == "ok" + assert worker.run(check_generic, C[int], C, int) == "ok" + + @unittest.skipIf(sys.version_info < (3, 7), + "Pickling type hints not supported below py37") + def test_locally_defined_class_with_type_hints(self): + with subprocess_worker(protocol=self.protocol) as worker: + for type_ in _all_types_to_test(): + # The type annotation syntax causes a SyntaxError on Python 3.5 + code = textwrap.dedent("""\ + class MyClass: + attribute: type_ + + def method(self, arg: type_) -> type_: + return arg + """) + ns = {"type_": type_} + exec(code, ns) + MyClass = ns["MyClass"] + + def check_annotations(obj, expected_type): + assert obj.__annotations__["attribute"] is expected_type + assert obj.method.__annotations__["arg"] is expected_type + assert obj.method.__annotations__["return"] is expected_type + return "ok" + + obj = MyClass() + assert check_annotations(obj, type_) == "ok" + assert worker.run(check_annotations, obj, type_) == "ok" + class Protocol2CloudPickleTest(CloudPickleTest): @@ -2161,5 +2223,28 @@ def test_lookup_module_and_qualname_stdlib_typevar(): assert name == 'AnyStr' +def _all_types_to_test(): + T = typing.TypeVar('T') + + class C(typing.Generic[T]): + pass + + return [ + C, C[int], + T, typing.Any, typing.NoReturn, typing.Optional, + typing.Generic, typing.Union, typing.ClassVar, + typing.Optional[int], + typing.Generic[T], + typing.Callable[[int], typing.Any], + typing.Callable[..., typing.Any], + typing.Callable[[], typing.Any], + typing.Tuple[int, ...], + typing.Tuple[int, C[int]], + typing.ClassVar[C[int]], + typing.List[int], + typing.Dict[int, str], + ] + + if __name__ == '__main__': unittest.main()