From 52a2a207dfb4fc16de76ae73ee1b5544f701bb87 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sat, 18 May 2024 12:17:49 -0700 Subject: [PATCH 1/9] gh-87533: Expand pickle importing to support non-package C-modules There have been recurring issues with PyModule_Create modules in PyTorch; when trying to serialize attributes of these C-modules, pickle fails to import the C-module because it is not a package. This is the current issue that brought this to my attention: https://github.com/pytorch/pytorch/issues/126154 The existing hack for this issue has been to insert the C-module into sys.modules in order to enable pickle to find it: https://github.com/pytorch/pytorch/pull/38136/files#diff-d7e90d0f94b43db763b44fba679a5c1b4cabe3668aaf34f2aee07de8e2d1b2faR524-R528 Instead of relying on this hack, we can change `pickle`'s approach to loading, which is currently equivalent to `import package.c_module`; instead, we could do `from package import c_module`, which 1) does not care if `c_module` is a package or not, 2) is fully backward compatible with the previous approach, and 3) slots nicely into the `fromlist` parameter of `__import__`, which we are already using to load modules in `pickle`. --- Lib/pickle.py | 40 +++++++++++++++------- Lib/test/pickletester.py | 8 +++++ Modules/_pickle.c | 73 +++++++++++++++++++++++++++++++++------- 3 files changed, 96 insertions(+), 25 deletions(-) diff --git a/Lib/pickle.py b/Lib/pickle.py index 33c97c8c5efb28..65d9f9c2fc490f 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -345,6 +345,24 @@ def whichmodule(obj, name): pass return '__main__' +def import_module_from_string(module_name): + """Import a module from a string. + + The last module in the dot-delimited module name is returned. The last + module need not be a package (see bpo-43367). 
+ + >>> import_module_from_string('pickle') + + >>> import_module_from_string('collections.abc) + + """ + module_name, *fromlist = module_name.rsplit(".", 1) + module = __import__(module_name, fromlist=fromlist, level=0) + if fromlist: + assert len(fromlist) == 1 + module = getattr(module, fromlist[0]) + return module + def encode_long(x): r"""Encode a long to a two's complement little-endian binary string. Note that 0 is a special case, returning an empty string, to save a @@ -364,7 +382,6 @@ def encode_long(x): b'\x80' >>> encode_long(127) b'\x7f' - >>> """ if x == 0: return b'' @@ -1049,7 +1066,6 @@ def save_frozenset(self, obj): def save_global(self, obj, name=None): write = self.write - memo = self.memo if name is None: name = getattr(obj, '__qualname__', None) @@ -1058,8 +1074,7 @@ def save_global(self, obj, name=None): module_name = whichmodule(obj, name) try: - __import__(module_name, level=0) - module = sys.modules[module_name] + module = import_module_from_string(module_name) obj2, parent = _getattribute(module, name) except (ImportError, KeyError, AttributeError): raise PicklingError( @@ -1565,17 +1580,18 @@ def get_extension(self, code): def find_class(self, module, name): # Subclasses may override this. 
- sys.audit('pickle.find_class', module, name) + module_name = module + sys.audit('pickle.find_class', module_name, name) if self.proto < 3 and self.fix_imports: - if (module, name) in _compat_pickle.NAME_MAPPING: - module, name = _compat_pickle.NAME_MAPPING[(module, name)] - elif module in _compat_pickle.IMPORT_MAPPING: - module = _compat_pickle.IMPORT_MAPPING[module] - __import__(module, level=0) + if (module_name, name) in _compat_pickle.NAME_MAPPING: + module_name, name = _compat_pickle.NAME_MAPPING[(module_name, name)] + elif module_name in _compat_pickle.IMPORT_MAPPING: + module_name = _compat_pickle.IMPORT_MAPPING[module_name] + module = import_module_from_string(module_name) if self.proto >= 4: - return _getattribute(sys.modules[module], name)[0] + return _getattribute(module, name)[0] else: - return getattr(sys.modules[module], name) + return getattr(module, name) def load_reduce(self): stack = self.stack diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 93e7dbbd103934..1c1c6b660c43b5 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -2845,6 +2845,12 @@ class Subclass(tuple): class Nested(str): pass + c_module = types.ModuleType("c_module") + def c_function(): + return None + c_function.__module__ = f"{__name__}.{c_module.__name__}" + c_module.c_function = c_function + c_methods = ( # bound built-in method ("abcd".index, ("c",)), @@ -2867,6 +2873,8 @@ class Nested(str): (Subclass.count, (Subclass([1,2,2]), 2)), (Subclass.Nested("sweet").count, ("e",)), (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")), + # bpo-43367: pickling C-module attributes + (c_module.c_function, ()), ) for proto in range(pickle.HIGHEST_PROTOCOL + 1): for method, args in c_methods: diff --git a/Modules/_pickle.c b/Modules/_pickle.c index 754a326822e0f0..baf599f303ff09 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -3584,6 +3584,63 @@ fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name) return 0; } 
+PyObject* +import_module_from_string(PyObject *module_name) +{ + PyObject *split_module_name = PyUnicode_RSplit(module_name, + PyUnicode_FromString("."), + 1); + if (split_module_name == NULL) { + PyErr_Format(PyExc_RuntimeError, + "Failed to split module name %R", module_name); + return NULL; + } + + module_name = PySequence_Fast_GET_ITEM(split_module_name, 0); + if (module_name == NULL) { + PyErr_Format(PyExc_RuntimeError, "Failed to get module name from %R", + split_module_name); + return NULL; + } + + PyObject *fromlist = PySequence_GetSlice(split_module_name, 1, + PySequence_Fast_GET_SIZE(split_module_name)); + if (fromlist == NULL) { + PyErr_Format(PyExc_RuntimeError, "Failed to get fromlist from %R", + split_module_name); + return NULL; + } + + PyObject *module = PyImport_ImportModuleLevelObject(module_name, NULL, + NULL, fromlist, 0); + if (module == NULL) { + PyErr_Format(PyExc_ModuleNotFoundError, "Import of module %R failed", + module_name); + return NULL; + } + + Py_ssize_t fromlist_size = PySequence_Fast_GET_SIZE(fromlist); + if (fromlist_size > 0) { + assert(fromlist_size == 1); + PyObject *submodule_name = PySequence_Fast_GET_ITEM(fromlist, 0); + if (submodule_name == NULL) { + PyErr_Format(PyExc_RuntimeError, "Failed to get submodule name from %R", + fromlist); + return NULL; + } + module = PyObject_GetAttr(module, submodule_name); + if (module == NULL) { + PyErr_Format(PyExc_ModuleNotFoundError, + "Attribute lookup %R on %R failed", submodule_name, + module_name); + return NULL; + } + } + + return module; +} + + static int save_global(PickleState *st, PicklerObject *self, PyObject *obj, PyObject *name) @@ -3619,21 +3676,11 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj, if (module_name == NULL) goto error; - /* XXX: Change to use the import C API directly with level=0 to disallow - relative imports. - - XXX: PyImport_ImportModuleLevel could be used. However, this bypasses - builtins.__import__. 
Therefore, _pickle, unlike pickle.py, will ignore - custom import functions (IMHO, this would be a nice security - feature). The import C API would need to be extended to support the - extra parameters of __import__ to fix that. */ - module = PyImport_Import(module_name); + module = import_module_from_string(module_name); if (module == NULL) { - PyErr_Format(st->PicklingError, - "Can't pickle %R: import of module %R failed", - obj, module_name); goto error; } + lastname = Py_NewRef(PyList_GET_ITEM(dotted_path, PyList_GET_SIZE(dotted_path) - 1)); cls = get_deep_attribute(module, dotted_path, &parent); @@ -7036,7 +7083,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls, * we don't use PyImport_GetModule here, because it can return partially- * initialised modules, which then cause the getattribute to fail. */ - module = PyImport_Import(module_name); + module = import_module_from_string(module_name); if (module == NULL) { return NULL; } From 18319074e4359f51d6c9e0833ddcdbd9f6b5b381 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sat, 18 May 2024 13:02:01 -0700 Subject: [PATCH 2/9] Update error type --- Lib/test/test_dataclasses/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index ea49596eaa4d96..1d51e69b7a5a23 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -4100,12 +4100,16 @@ def test_pickle_support(self): ) def test_cannot_be_pickled(self): - for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]: + klass_error_type_pairs = [ + (WrongNameMakeDataclass, pickle.PicklingError), + (WrongModuleMakeDataclass, ModuleNotFoundError), + ] + for klass, error_type in klass_error_type_pairs: for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto): - with self.assertRaises(pickle.PickleError): + with self.assertRaises(error_type): 
pickle.dumps(klass, proto) - with self.assertRaises(pickle.PickleError): + with self.assertRaises(error_type): pickle.dumps(klass(1), proto) def test_invalid_type_specification(self): From e8adfa76f98d3b34708024ac7393ed54a57b8a1a Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sat, 18 May 2024 14:02:22 -0700 Subject: [PATCH 3/9] Raise PickleError to maintain backward compatibility --- Lib/test/pickletester.py | 3 +++ Lib/test/test_dataclasses/__init__.py | 10 +++------- Modules/_pickle.c | 25 ++++++++++++------------- 3 files changed, 18 insertions(+), 20 deletions(-) diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py index 1c1c6b660c43b5..5e61be34f1e22c 100644 --- a/Lib/test/pickletester.py +++ b/Lib/test/pickletester.py @@ -2845,9 +2845,12 @@ class Subclass(tuple): class Nested(str): pass + # simulate a module created with PyModule_Create containing a function + global c_module c_module = types.ModuleType("c_module") def c_function(): return None + c_function.__qualname__ = c_function.__name__ = "c_function" c_function.__module__ = f"{__name__}.{c_module.__name__}" c_module.c_function = c_function diff --git a/Lib/test/test_dataclasses/__init__.py b/Lib/test/test_dataclasses/__init__.py index 1d51e69b7a5a23..ea49596eaa4d96 100644 --- a/Lib/test/test_dataclasses/__init__.py +++ b/Lib/test/test_dataclasses/__init__.py @@ -4100,16 +4100,12 @@ def test_pickle_support(self): ) def test_cannot_be_pickled(self): - klass_error_type_pairs = [ - (WrongNameMakeDataclass, pickle.PicklingError), - (WrongModuleMakeDataclass, ModuleNotFoundError), - ] - for klass, error_type in klass_error_type_pairs: + for klass in [WrongNameMakeDataclass, WrongModuleMakeDataclass]: for proto in range(pickle.HIGHEST_PROTOCOL + 1): with self.subTest(proto=proto): - with self.assertRaises(error_type): + with self.assertRaises(pickle.PickleError): pickle.dumps(klass, proto) - with self.assertRaises(error_type): + with self.assertRaises(pickle.PickleError): 
pickle.dumps(klass(1), proto) def test_invalid_type_specification(self): diff --git a/Modules/_pickle.c b/Modules/_pickle.c index baf599f303ff09..bb0c4d1aa1793c 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -3585,20 +3585,20 @@ fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name) } PyObject* -import_module_from_string(PyObject *module_name) +import_module_from_string(PickleState *st, PyObject *module_name) { PyObject *split_module_name = PyUnicode_RSplit(module_name, PyUnicode_FromString("."), 1); if (split_module_name == NULL) { - PyErr_Format(PyExc_RuntimeError, + PyErr_Format(st->PicklingError, "Failed to split module name %R", module_name); return NULL; } module_name = PySequence_Fast_GET_ITEM(split_module_name, 0); if (module_name == NULL) { - PyErr_Format(PyExc_RuntimeError, "Failed to get module name from %R", + PyErr_Format(st->PicklingError, "Failed to get module name from %R", split_module_name); return NULL; } @@ -3606,7 +3606,7 @@ import_module_from_string(PyObject *module_name) PyObject *fromlist = PySequence_GetSlice(split_module_name, 1, PySequence_Fast_GET_SIZE(split_module_name)); if (fromlist == NULL) { - PyErr_Format(PyExc_RuntimeError, "Failed to get fromlist from %R", + PyErr_Format(st->PicklingError, "Failed to get fromlist from %R", split_module_name); return NULL; } @@ -3614,7 +3614,7 @@ import_module_from_string(PyObject *module_name) PyObject *module = PyImport_ImportModuleLevelObject(module_name, NULL, NULL, fromlist, 0); if (module == NULL) { - PyErr_Format(PyExc_ModuleNotFoundError, "Import of module %R failed", + PyErr_Format(st->PicklingError, "Import of module %R failed", module_name); return NULL; } @@ -3624,15 +3624,14 @@ import_module_from_string(PyObject *module_name) assert(fromlist_size == 1); PyObject *submodule_name = PySequence_Fast_GET_ITEM(fromlist, 0); if (submodule_name == NULL) { - PyErr_Format(PyExc_RuntimeError, "Failed to get submodule name from %R", - fromlist); + 
PyErr_Format(st->PicklingError, + "Failed to get submodule name from %R", fromlist); return NULL; } module = PyObject_GetAttr(module, submodule_name); if (module == NULL) { - PyErr_Format(PyExc_ModuleNotFoundError, - "Attribute lookup %R on %R failed", submodule_name, - module_name); + PyErr_Format(st->PicklingError, "Attribute lookup %R on %R failed", + submodule_name, module_name); return NULL; } } @@ -3676,7 +3675,7 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj, if (module_name == NULL) goto error; - module = import_module_from_string(module_name); + module = import_module_from_string(st, module_name); if (module == NULL) { goto error; } @@ -7027,10 +7026,10 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls, /* Try to map the old names used in Python 2.x to the new ones used in Python 3.x. We do this only with old pickle protocols and when the user has not disabled the feature. */ + PickleState *st = _Pickle_GetStateByClass(cls); if (self->proto < 3 && self->fix_imports) { PyObject *key; PyObject *item; - PickleState *st = _Pickle_GetStateByClass(cls); /* Check if the global (i.e., a function or a class) was renamed or moved to another module. */ @@ -7083,7 +7082,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls, * we don't use PyImport_GetModule here, because it can return partially- * initialised modules, which then cause the getattribute to fail. 
*/ - module = import_module_from_string(module_name); + module = import_module_from_string(st, module_name); if (module == NULL) { return NULL; } From 0a7806921cdbf0d2db48a76b90f9496eb79c3394 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sat, 18 May 2024 15:29:29 -0700 Subject: [PATCH 4/9] Add news blurb --- .../2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst b/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst new file mode 100644 index 00000000000000..b9ae9ed1cceecc --- /dev/null +++ b/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst @@ -0,0 +1 @@ +Extend :mod:`pickle` import behavior to support loading :c:func:`PyModule_Create`-generated ``module`` attributes. From 94ec9a84ecc255bd188952523f59c119c393a68f Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sat, 18 May 2024 22:38:53 -0700 Subject: [PATCH 5/9] Fix datetime pickle tests Also remove incorrect notes and unnecessary steps --- Lib/test/test_datetime.py | 46 +++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/Lib/test/test_datetime.py b/Lib/test/test_datetime.py index 3859733a4fe65b..9d249b56715481 100644 --- a/Lib/test/test_datetime.py +++ b/Lib/test/test_datetime.py @@ -1,33 +1,27 @@ -import unittest import sys +import unittest from test.support.import_helper import import_fresh_module - TESTS = 'test.datetimetester' + def load_tests(loader, tests, pattern): - try: - pure_tests = import_fresh_module(TESTS, - fresh=['datetime', '_pydatetime', '_strptime'], - blocked=['_datetime']) - fast_tests = import_fresh_module(TESTS, - fresh=['datetime', '_strptime'], - blocked=['_pydatetime']) - finally: - # XXX: import_fresh_module() is supposed to leave sys.module 
cache untouched, - # XXX: but it does not, so we have to cleanup ourselves. - for modname in ['datetime', '_datetime', '_strptime']: - sys.modules.pop(modname, None) - - test_modules = [pure_tests, fast_tests] - test_suffixes = ["_Pure", "_Fast"] - # XXX(gb) First run all the _Pure tests, then all the _Fast tests. You might - # not believe this, but in spite of all the sys.modules trickery running a _Pure - # test last will leave a mix of pure and native datetime stuff lying around. - for module, suffix in zip(test_modules, test_suffixes): + python_impl_tests = import_fresh_module(TESTS, + fresh=['datetime', '_pydatetime', '_strptime'], + blocked=['_datetime']) + c_impl_tests = import_fresh_module(TESTS, fresh=['datetime', '_strptime'], + blocked=['_pydatetime']) + + test_module_map = { + '_Pure': python_impl_tests, + '_Fast': c_impl_tests, + } + + for suffix, module in test_module_map.items(): test_classes = [] - for name, cls in module.__dict__.items(): + + for cls in module.__dict__.values(): if not isinstance(cls, type): continue if issubclass(cls, unittest.TestCase): @@ -35,25 +29,29 @@ def load_tests(loader, tests, pattern): elif issubclass(cls, unittest.TestSuite): suit = cls() test_classes.extend(type(test) for test in suit) - test_classes = sorted(set(test_classes), key=lambda cls: cls.__qualname__) + for cls in test_classes: cls.__name__ += suffix cls.__qualname__ += suffix + @classmethod def setUpClass(cls_, module=module): cls_._save_sys_modules = sys.modules.copy() - sys.modules[TESTS] = module + sys.modules['test'].datetimetester = module # for pickle tests sys.modules['datetime'] = module.datetime_module if hasattr(module, '_pydatetime'): sys.modules['_pydatetime'] = module._pydatetime sys.modules['_strptime'] = module._strptime + @classmethod def tearDownClass(cls_): sys.modules.clear() sys.modules.update(cls_._save_sys_modules) + cls.setUpClass = setUpClass cls.tearDownClass = tearDownClass tests.addTests(loader.loadTestsFromTestCase(cls)) + return 
tests From 4cde8026d5792139141ea0fded3bd696ba318ea0 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 19 May 2024 01:16:01 -0700 Subject: [PATCH 6/9] Use sys.modules before attempting to get the child module from the parent Otherwise you might set a value for the module in sys.modules, but have it go unused and instead be accessed through the parent --- Lib/pickle.py | 13 +++++----- Lib/test/test_datetime.py | 2 +- Modules/_pickle.c | 50 ++++++++++++++++++++++----------------- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/Lib/pickle.py b/Lib/pickle.py index 65d9f9c2fc490f..4df4f083358cce 100644 --- a/Lib/pickle.py +++ b/Lib/pickle.py @@ -356,12 +356,13 @@ def import_module_from_string(module_name): >>> import_module_from_string('collections.abc) """ - module_name, *fromlist = module_name.rsplit(".", 1) - module = __import__(module_name, fromlist=fromlist, level=0) - if fromlist: - assert len(fromlist) == 1 - module = getattr(module, fromlist[0]) - return module + parent_module_name, *fromlist = module_name.rsplit(".", 1) + parent_module = __import__(parent_module_name, fromlist=fromlist, level=0) + if not fromlist: + return parent_module + assert len(fromlist) == 1 + return (sys.modules[module_name] if module_name in sys.modules + else getattr(parent_module, fromlist[0])) def encode_long(x): r"""Encode a long to a two's complement little-endian binary string. 
diff --git a/Lib/test/test_datetime.py b/Lib/test/test_datetime.py index 9d249b56715481..3748bcc0d9cee1 100644 --- a/Lib/test/test_datetime.py +++ b/Lib/test/test_datetime.py @@ -37,7 +37,7 @@ def load_tests(loader, tests, pattern): @classmethod def setUpClass(cls_, module=module): cls_._save_sys_modules = sys.modules.copy() - sys.modules['test'].datetimetester = module # for pickle tests + sys.modules[TESTS] = module # for pickle tests sys.modules['datetime'] = module.datetime_module if hasattr(module, '_pydatetime'): sys.modules['_pydatetime'] = module._pydatetime diff --git a/Modules/_pickle.c b/Modules/_pickle.c index bb0c4d1aa1793c..9b1351eb357d74 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -1923,10 +1923,9 @@ whichmodule(PyObject *global, PyObject *dotted_path) assert(module_name == NULL); /* Fallback on walking sys.modules */ - PyThreadState *tstate = _PyThreadState_GET(); - modules = _PySys_GetAttr(tstate, &_Py_ID(modules)); + modules = PyImport_GetModuleDict(); if (modules == NULL) { - PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); + PyErr_SetString(PyExc_RuntimeError, "Unable to get sys.modules"); return NULL; } if (PyDict_CheckExact(modules)) { @@ -3591,14 +3590,16 @@ import_module_from_string(PickleState *st, PyObject *module_name) PyUnicode_FromString("."), 1); if (split_module_name == NULL) { - PyErr_Format(st->PicklingError, - "Failed to split module name %R", module_name); + PyErr_Format(PyExc_RuntimeError, "Failed to split module name %R", + module_name); return NULL; } - module_name = PySequence_Fast_GET_ITEM(split_module_name, 0); - if (module_name == NULL) { - PyErr_Format(st->PicklingError, "Failed to get module name from %R", + PyObject *parent_module_name = PySequence_Fast_GET_ITEM(split_module_name, + 0); + if (parent_module_name == NULL) { + PyErr_Format(PyExc_RuntimeError, + "Failed to get parent module name from %R", split_module_name); return NULL; } @@ -3606,36 +3607,41 @@ 
import_module_from_string(PickleState *st, PyObject *module_name) PyObject *fromlist = PySequence_GetSlice(split_module_name, 1, PySequence_Fast_GET_SIZE(split_module_name)); if (fromlist == NULL) { - PyErr_Format(st->PicklingError, "Failed to get fromlist from %R", + PyErr_Format(PyExc_RuntimeError, "Failed to get fromlist from %R", split_module_name); return NULL; } - PyObject *module = PyImport_ImportModuleLevelObject(module_name, NULL, - NULL, fromlist, 0); - if (module == NULL) { + PyObject *parent_module = PyImport_ImportModuleLevelObject(parent_module_name, + NULL, NULL, + fromlist, 0); + if (parent_module == NULL) { PyErr_Format(st->PicklingError, "Import of module %R failed", - module_name); + parent_module_name); return NULL; } Py_ssize_t fromlist_size = PySequence_Fast_GET_SIZE(fromlist); - if (fromlist_size > 0) { - assert(fromlist_size == 1); - PyObject *submodule_name = PySequence_Fast_GET_ITEM(fromlist, 0); - if (submodule_name == NULL) { - PyErr_Format(st->PicklingError, - "Failed to get submodule name from %R", fromlist); + if (fromlist_size == 0) { + return parent_module; + } + assert(fromlist_size == 1); + + PyObject *module = PyImport_GetModule(module_name); + if (module == NULL) { + PyObject *child_module_name = PySequence_Fast_GET_ITEM(fromlist, 0); + if (child_module_name == NULL) { + PyErr_Format(PyExc_RuntimeError, "Failed to get item from %R", + fromlist); return NULL; } - module = PyObject_GetAttr(module, submodule_name); + module = PyObject_GetAttr(parent_module, child_module_name); if (module == NULL) { PyErr_Format(st->PicklingError, "Attribute lookup %R on %R failed", - submodule_name, module_name); + child_module_name, parent_module_name); return NULL; } } - return module; } From 556d62d49ec391dca02121954dae74ae5bf5e7f5 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 19 May 2024 01:23:54 -0700 Subject: [PATCH 7/9] make regen-all --- Include/internal/pycore_global_objects_fini_generated.h | 1 - 
Include/internal/pycore_global_strings.h | 1 - Include/internal/pycore_runtime_init_generated.h | 1 - Include/internal/pycore_unicodeobject_generated.h | 3 --- 4 files changed, 6 deletions(-) diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index ca7355b2b61aa7..3651aca48cd844 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -1084,7 +1084,6 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(mode)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(module)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(module_globals)); - _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(modules)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(month)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(mro)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(msg)); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index fbb25285f0f282..4481f69d2af90f 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -573,7 +573,6 @@ struct _Py_global_strings { STRUCT_FOR_ID(mode) STRUCT_FOR_ID(module) STRUCT_FOR_ID(module_globals) - STRUCT_FOR_ID(modules) STRUCT_FOR_ID(month) STRUCT_FOR_ID(mro) STRUCT_FOR_ID(msg) diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index 508da40c53422d..2279a333e39e9b 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -1082,7 +1082,6 @@ extern "C" { INIT_ID(mode), \ INIT_ID(module), \ INIT_ID(module_globals), \ - INIT_ID(modules), \ INIT_ID(month), \ INIT_ID(mro), \ INIT_ID(msg), \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index 
cc2fc15ac5cabf..2651a2212aa6cf 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ b/Include/internal/pycore_unicodeobject_generated.h @@ -1560,9 +1560,6 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { string = &_Py_ID(module_globals); assert(_PyUnicode_CheckConsistency(string, 1)); _PyUnicode_InternInPlace(interp, &string); - string = &_Py_ID(modules); - assert(_PyUnicode_CheckConsistency(string, 1)); - _PyUnicode_InternInPlace(interp, &string); string = &_Py_ID(month); assert(_PyUnicode_CheckConsistency(string, 1)); _PyUnicode_InternInPlace(interp, &string); From fafddb7d6ac93639eabb9792c10ce71565e23bc3 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 19 May 2024 01:34:36 -0700 Subject: [PATCH 8/9] Undo datetime test changes --- Lib/test/test_datetime.py | 46 ++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 22 deletions(-) diff --git a/Lib/test/test_datetime.py b/Lib/test/test_datetime.py index 3748bcc0d9cee1..3859733a4fe65b 100644 --- a/Lib/test/test_datetime.py +++ b/Lib/test/test_datetime.py @@ -1,27 +1,33 @@ -import sys import unittest +import sys from test.support.import_helper import import_fresh_module -TESTS = 'test.datetimetester' +TESTS = 'test.datetimetester' def load_tests(loader, tests, pattern): - python_impl_tests = import_fresh_module(TESTS, - fresh=['datetime', '_pydatetime', '_strptime'], - blocked=['_datetime']) - c_impl_tests = import_fresh_module(TESTS, fresh=['datetime', '_strptime'], - blocked=['_pydatetime']) - - test_module_map = { - '_Pure': python_impl_tests, - '_Fast': c_impl_tests, - } - - for suffix, module in test_module_map.items(): + try: + pure_tests = import_fresh_module(TESTS, + fresh=['datetime', '_pydatetime', '_strptime'], + blocked=['_datetime']) + fast_tests = import_fresh_module(TESTS, + fresh=['datetime', '_strptime'], + blocked=['_pydatetime']) + finally: + # XXX: import_fresh_module() is supposed to leave sys.module cache untouched, + # XXX: but it does 
not, so we have to cleanup ourselves. + for modname in ['datetime', '_datetime', '_strptime']: + sys.modules.pop(modname, None) + + test_modules = [pure_tests, fast_tests] + test_suffixes = ["_Pure", "_Fast"] + # XXX(gb) First run all the _Pure tests, then all the _Fast tests. You might + # not believe this, but in spite of all the sys.modules trickery running a _Pure + # test last will leave a mix of pure and native datetime stuff lying around. + for module, suffix in zip(test_modules, test_suffixes): test_classes = [] - - for cls in module.__dict__.values(): + for name, cls in module.__dict__.items(): if not isinstance(cls, type): continue if issubclass(cls, unittest.TestCase): @@ -29,29 +35,25 @@ def load_tests(loader, tests, pattern): elif issubclass(cls, unittest.TestSuite): suit = cls() test_classes.extend(type(test) for test in suit) - + test_classes = sorted(set(test_classes), key=lambda cls: cls.__qualname__) for cls in test_classes: cls.__name__ += suffix cls.__qualname__ += suffix - @classmethod def setUpClass(cls_, module=module): cls_._save_sys_modules = sys.modules.copy() - sys.modules[TESTS] = module # for pickle tests + sys.modules[TESTS] = module sys.modules['datetime'] = module.datetime_module if hasattr(module, '_pydatetime'): sys.modules['_pydatetime'] = module._pydatetime sys.modules['_strptime'] = module._strptime - @classmethod def tearDownClass(cls_): sys.modules.clear() sys.modules.update(cls_._save_sys_modules) - cls.setUpClass = setUpClass cls.tearDownClass = tearDownClass tests.addTests(loader.loadTestsFromTestCase(cls)) - return tests From da5aff278a362222bb82945b6a98859133216f92 Mon Sep 17 00:00:00 2001 From: Matthew Hoffman Date: Sun, 19 May 2024 02:01:30 -0700 Subject: [PATCH 9/9] Undo whichmodule changes --- Include/internal/pycore_global_objects_fini_generated.h | 1 + Include/internal/pycore_global_strings.h | 1 + Include/internal/pycore_runtime_init_generated.h | 1 + Include/internal/pycore_unicodeobject_generated.h | 3 +++ 
Modules/_pickle.c | 5 +++-- 5 files changed, 9 insertions(+), 2 deletions(-) diff --git a/Include/internal/pycore_global_objects_fini_generated.h b/Include/internal/pycore_global_objects_fini_generated.h index 3651aca48cd844..ca7355b2b61aa7 100644 --- a/Include/internal/pycore_global_objects_fini_generated.h +++ b/Include/internal/pycore_global_objects_fini_generated.h @@ -1084,6 +1084,7 @@ _PyStaticObjects_CheckRefcnt(PyInterpreterState *interp) { _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(mode)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(module)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(module_globals)); + _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(modules)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(month)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(mro)); _PyStaticObject_CheckRefcnt((PyObject *)&_Py_ID(msg)); diff --git a/Include/internal/pycore_global_strings.h b/Include/internal/pycore_global_strings.h index 4481f69d2af90f..fbb25285f0f282 100644 --- a/Include/internal/pycore_global_strings.h +++ b/Include/internal/pycore_global_strings.h @@ -573,6 +573,7 @@ struct _Py_global_strings { STRUCT_FOR_ID(mode) STRUCT_FOR_ID(module) STRUCT_FOR_ID(module_globals) + STRUCT_FOR_ID(modules) STRUCT_FOR_ID(month) STRUCT_FOR_ID(mro) STRUCT_FOR_ID(msg) diff --git a/Include/internal/pycore_runtime_init_generated.h b/Include/internal/pycore_runtime_init_generated.h index 2279a333e39e9b..508da40c53422d 100644 --- a/Include/internal/pycore_runtime_init_generated.h +++ b/Include/internal/pycore_runtime_init_generated.h @@ -1082,6 +1082,7 @@ extern "C" { INIT_ID(mode), \ INIT_ID(module), \ INIT_ID(module_globals), \ + INIT_ID(modules), \ INIT_ID(month), \ INIT_ID(mro), \ INIT_ID(msg), \ diff --git a/Include/internal/pycore_unicodeobject_generated.h b/Include/internal/pycore_unicodeobject_generated.h index 2651a2212aa6cf..cc2fc15ac5cabf 100644 --- a/Include/internal/pycore_unicodeobject_generated.h +++ 
b/Include/internal/pycore_unicodeobject_generated.h @@ -1560,6 +1560,9 @@ _PyUnicode_InitStaticStrings(PyInterpreterState *interp) { string = &_Py_ID(module_globals); assert(_PyUnicode_CheckConsistency(string, 1)); _PyUnicode_InternInPlace(interp, &string); + string = &_Py_ID(modules); + assert(_PyUnicode_CheckConsistency(string, 1)); + _PyUnicode_InternInPlace(interp, &string); string = &_Py_ID(month); assert(_PyUnicode_CheckConsistency(string, 1)); _PyUnicode_InternInPlace(interp, &string); diff --git a/Modules/_pickle.c b/Modules/_pickle.c index 9b1351eb357d74..e936d3b4f20050 100644 --- a/Modules/_pickle.c +++ b/Modules/_pickle.c @@ -1923,9 +1923,10 @@ whichmodule(PyObject *global, PyObject *dotted_path) assert(module_name == NULL); /* Fallback on walking sys.modules */ - modules = PyImport_GetModuleDict(); + PyThreadState *tstate = _PyThreadState_GET(); + modules = _PySys_GetAttr(tstate, &_Py_ID(modules)); if (modules == NULL) { - PyErr_SetString(PyExc_RuntimeError, "Unable to get sys.modules"); + PyErr_SetString(PyExc_RuntimeError, "unable to get sys.modules"); return NULL; } if (PyDict_CheckExact(modules)) {