diff --git a/Lib/pickle.py b/Lib/pickle.py
index 33c97c8c5efb28..4df4f083358cce 100644
--- a/Lib/pickle.py
+++ b/Lib/pickle.py
@@ -345,6 +345,29 @@ def whichmodule(obj, name):
             pass
     return '__main__'
 
+def import_module_from_string(module_name):
+    """Import and return the module named by a dotted string.
+
+    The module named by the last component of *module_name* is returned.
+    That component need not be an importable submodule: if the full name
+    is not in sys.modules once the parent has been imported, it is looked
+    up as an attribute of the parent module instead (see bpo-43367).
+
+    >>> import_module_from_string('pickle').__name__
+    'pickle'
+    >>> import_module_from_string('collections.abc').__name__
+    'collections.abc'
+    """
+    parent_module_name, *fromlist = module_name.rsplit(".", 1)
+    parent_module = __import__(parent_module_name, fromlist=fromlist, level=0)
+    if not fromlist:
+        return parent_module
+    try:
+        return sys.modules[module_name]
+    except KeyError:
+        # Not registered as a module; fall back to the parent's attribute.
+        return getattr(parent_module, fromlist[0])
+
 def encode_long(x):
     r"""Encode a long to a two's complement little-endian binary string.
     Note that 0 is a special case, returning an empty string, to save a
@@ -364,7 +387,6 @@ def encode_long(x):
     b'\x80'
     >>> encode_long(127)
     b'\x7f'
-    >>>
     """
     if x == 0:
         return b''
@@ -1049,7 +1071,6 @@ def save_frozenset(self, obj):
 
     def save_global(self, obj, name=None):
         write = self.write
-        memo = self.memo
 
         if name is None:
             name = getattr(obj, '__qualname__', None)
@@ -1058,8 +1079,7 @@ def save_global(self, obj, name=None):
 
         module_name = whichmodule(obj, name)
         try:
-            __import__(module_name, level=0)
-            module = sys.modules[module_name]
+            module = import_module_from_string(module_name)
             obj2, parent = _getattribute(module, name)
         except (ImportError, KeyError, AttributeError):
             raise PicklingError(
@@ -1565,17 +1585,18 @@ def get_extension(self, code):
 
     def find_class(self, module, name):
         # Subclasses may override this.
-        sys.audit('pickle.find_class', module, name)
+        module_name = module
+        sys.audit('pickle.find_class', module_name, name)
         if self.proto < 3 and self.fix_imports:
-            if (module, name) in _compat_pickle.NAME_MAPPING:
-                module, name = _compat_pickle.NAME_MAPPING[(module, name)]
-            elif module in _compat_pickle.IMPORT_MAPPING:
-                module = _compat_pickle.IMPORT_MAPPING[module]
-        __import__(module, level=0)
+            if (module_name, name) in _compat_pickle.NAME_MAPPING:
+                module_name, name = _compat_pickle.NAME_MAPPING[(module_name, name)]
+            elif module_name in _compat_pickle.IMPORT_MAPPING:
+                module_name = _compat_pickle.IMPORT_MAPPING[module_name]
+        module = import_module_from_string(module_name)
         if self.proto >= 4:
-            return _getattribute(sys.modules[module], name)[0]
+            return _getattribute(module, name)[0]
         else:
-            return getattr(sys.modules[module], name)
+            return getattr(module, name)
 
     def load_reduce(self):
         stack = self.stack
diff --git a/Lib/test/pickletester.py b/Lib/test/pickletester.py
index 93e7dbbd103934..5e61be34f1e22c 100644
--- a/Lib/test/pickletester.py
+++ b/Lib/test/pickletester.py
@@ -2845,6 +2845,15 @@ class Subclass(tuple):
             class Nested(str):
                 pass
 
+        # simulate a module created with PyModule_Create containing a function
+        global c_module
+        c_module = types.ModuleType("c_module")
+        def c_function():
+            return None
+        c_function.__qualname__ = c_function.__name__ = "c_function"
+        c_function.__module__ = f"{__name__}.{c_module.__name__}"
+        c_module.c_function = c_function
+
         c_methods = (
             # bound built-in method
             ("abcd".index, ("c",)),
@@ -2867,6 +2876,8 @@ class Nested(str):
             (Subclass.count, (Subclass([1,2,2]), 2)),
             (Subclass.Nested("sweet").count, ("e",)),
             (Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
+            # bpo-43367: pickling C-module attributes
+            (c_module.c_function, ()),
         )
         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
             for method, args in c_methods:
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst b/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst
new file mode 100644
index 00000000000000..b9ae9ed1cceecc
--- /dev/null
+++ b/Misc/NEWS.d/next/Core and Builtins/2024-05-18-15-19-15.gh-issue-87533.GyYvpT.rst
@@ -0,0 +1 @@
+Extend :mod:`pickle` import behavior to support loading :c:func:`PyModule_Create`-generated ``module`` attributes.
diff --git a/Modules/_pickle.c b/Modules/_pickle.c
index 754a326822e0f0..e936d3b4f20050 100644
--- a/Modules/_pickle.c
+++ b/Modules/_pickle.c
@@ -3584,6 +3584,83 @@ fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name)
     return 0;
 }
 
+/* Import and return the module named by the dotted string `module_name`.
+ *
+ * C counterpart of import_module_from_string() in Lib/pickle.py.  The name
+ * is split on its last dot; the left part is imported with the right part
+ * as fromlist, and the result is sys.modules[module_name] if present,
+ * otherwise the attribute of the parent module named by the right part.
+ * The last component therefore need not be a package (see bpo-43367).
+ *
+ * Returns a new reference, or NULL with an exception set.  Exceptions from
+ * the import machinery and the attribute lookup are propagated unchanged so
+ * that pickling and unpickling callers can report them in their own way.
+ * `st` is unused; it is kept so that call sites need not change.
+ *
+ * NOTE: PyImport_ImportModuleLevelObject() does not go through
+ * builtins.__import__, unlike the PyImport_Import() call this replaces.
+ *
+ * NOTE(review): find_class avoids PyImport_GetModule() because it can
+ * return partially-initialised modules.  Here it is only consulted after
+ * the import above has returned -- confirm that this is sufficient.
+ */
+static PyObject *
+import_module_from_string(PickleState *Py_UNUSED(st), PyObject *module_name)
+{
+    PyObject *fromlist = NULL;
+    PyObject *parent_module = NULL;
+    PyObject *module = NULL;
+
+    PyObject *dot = PyUnicode_FromString(".");
+    if (dot == NULL) {
+        return NULL;
+    }
+    /* parts is [name] or [parent, child]; rsplit never returns []. */
+    PyObject *parts = PyUnicode_RSplit(module_name, dot, 1);
+    Py_DECREF(dot);
+    if (parts == NULL) {
+        return NULL;
+    }
+
+    Py_ssize_t nparts = PyList_GET_SIZE(parts);
+    /* Borrowed reference, kept alive by `parts`. */
+    PyObject *parent_module_name = PyList_GET_ITEM(parts, 0);
+
+    fromlist = PyList_GetSlice(parts, 1, nparts);
+    if (fromlist == NULL) {
+        goto done;
+    }
+
+    parent_module = PyImport_ImportModuleLevelObject(parent_module_name,
+                                                     NULL, NULL,
+                                                     fromlist, 0);
+    if (parent_module == NULL) {
+        goto done;
+    }
+
+    if (nparts == 1) {
+        /* No dot in the name: the module just imported is the result. */
+        module = Py_NewRef(parent_module);
+        goto done;
+    }
+    assert(nparts == 2);
+
+    /* PyImport_GetModule() returns NULL without an exception set when the
+       name is merely absent from sys.modules; only then fall back to an
+       attribute lookup on the parent. */
+    module = PyImport_GetModule(module_name);
+    if (module == NULL && !PyErr_Occurred()) {
+        module = PyObject_GetAttr(parent_module, PyList_GET_ITEM(parts, 1));
+    }
+
+done:
+    Py_XDECREF(parent_module);
+    Py_XDECREF(fromlist);
+    Py_DECREF(parts);
+    return module;
+}
+
+
 static int
 save_global(PickleState *st, PicklerObject *self, PyObject *obj,
             PyObject *name)
@@ -3619,15 +3696,7 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
     if (module_name == NULL)
         goto error;
 
-    /* XXX: Change to use the import C API directly with level=0 to disallow
-       relative imports.
-
-       XXX: PyImport_ImportModuleLevel could be used. However, this bypasses
-       builtins.__import__. Therefore, _pickle, unlike pickle.py, will ignore
-       custom import functions (IMHO, this would be a nice security
-       feature). The import C API would need to be extended to support the
-       extra parameters of __import__ to fix that. */
-    module = PyImport_Import(module_name);
+    module = import_module_from_string(st, module_name);
     if (module == NULL) {
         PyErr_Format(st->PicklingError,
                      "Can't pickle %R: import of module %R failed",
@@ -6980,10 +7049,10 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
     /* Try to map the old names used in Python 2.x to the new ones used in
        Python 3.x.  We do this only with old pickle protocols and when the
        user has not disabled the feature. */
+    PickleState *st = _Pickle_GetStateByClass(cls);
     if (self->proto < 3 && self->fix_imports) {
         PyObject *key;
         PyObject *item;
-        PickleState *st = _Pickle_GetStateByClass(cls);
 
         /* Check if the global (i.e., a function or a class) was renamed
            or moved to another module. */
@@ -7036,7 +7105,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
      * we don't use PyImport_GetModule here, because it can return partially-
      * initialised modules, which then cause the getattribute to fail.
      */
-    module = PyImport_Import(module_name);
+    module = import_module_from_string(st, module_name);
     if (module == NULL) {
         return NULL;
     }