Skip to content

gh-87533: Expand pickle importing to support non-package C-modules #119152

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions Lib/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,25 @@ def whichmodule(obj, name):
pass
return '__main__'

def import_module_from_string(module_name):
"""Import a module from a string.

The last module in the dot-delimited module name is returned. The last
module need not be a package (see bpo-43367).

>>> import_module_from_string('pickle')
<module 'pickle' from '...'>
>>> import_module_from_string('collections.abc)
<module 'collections.abc' from '...'>
"""
parent_module_name, *fromlist = module_name.rsplit(".", 1)
parent_module = __import__(parent_module_name, fromlist=fromlist, level=0)
if not fromlist:
return parent_module
assert len(fromlist) == 1
return (sys.modules[module_name] if module_name in sys.modules
else getattr(parent_module, fromlist[0]))

def encode_long(x):
r"""Encode a long to a two's complement little-endian binary string.
Note that 0 is a special case, returning an empty string, to save a
Expand All @@ -364,7 +383,6 @@ def encode_long(x):
b'\x80'
>>> encode_long(127)
b'\x7f'
>>>
"""
if x == 0:
return b''
Expand Down Expand Up @@ -1049,7 +1067,6 @@ def save_frozenset(self, obj):

def save_global(self, obj, name=None):
write = self.write
memo = self.memo

if name is None:
name = getattr(obj, '__qualname__', None)
Expand All @@ -1058,8 +1075,7 @@ def save_global(self, obj, name=None):

module_name = whichmodule(obj, name)
try:
__import__(module_name, level=0)
module = sys.modules[module_name]
module = import_module_from_string(module_name)
obj2, parent = _getattribute(module, name)
except (ImportError, KeyError, AttributeError):
raise PicklingError(
Expand Down Expand Up @@ -1565,17 +1581,18 @@ def get_extension(self, code):

def find_class(self, module, name):
# Subclasses may override this.
sys.audit('pickle.find_class', module, name)
module_name = module
sys.audit('pickle.find_class', module_name, name)
if self.proto < 3 and self.fix_imports:
if (module, name) in _compat_pickle.NAME_MAPPING:
module, name = _compat_pickle.NAME_MAPPING[(module, name)]
elif module in _compat_pickle.IMPORT_MAPPING:
module = _compat_pickle.IMPORT_MAPPING[module]
__import__(module, level=0)
if (module_name, name) in _compat_pickle.NAME_MAPPING:
module_name, name = _compat_pickle.NAME_MAPPING[(module_name, name)]
elif module_name in _compat_pickle.IMPORT_MAPPING:
module_name = _compat_pickle.IMPORT_MAPPING[module_name]
module = import_module_from_string(module_name)
if self.proto >= 4:
return _getattribute(sys.modules[module], name)[0]
return _getattribute(module, name)[0]
else:
return getattr(sys.modules[module], name)
return getattr(module, name)

def load_reduce(self):
stack = self.stack
Expand Down
11 changes: 11 additions & 0 deletions Lib/test/pickletester.py
Original file line number Diff line number Diff line change
Expand Up @@ -2845,6 +2845,15 @@ class Subclass(tuple):
class Nested(str):
pass

# simulate a module created with PyModule_Create containing a function
global c_module
c_module = types.ModuleType("c_module")
def c_function():
return None
c_function.__qualname__ = c_function.__name__ = "c_function"
c_function.__module__ = f"{__name__}.{c_module.__name__}"
c_module.c_function = c_function

c_methods = (
# bound built-in method
("abcd".index, ("c",)),
Expand All @@ -2867,6 +2876,8 @@ class Nested(str):
(Subclass.count, (Subclass([1,2,2]), 2)),
(Subclass.Nested("sweet").count, ("e",)),
(Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
# bpo-43367: pickling C-module attributes
(c_module.c_function, ()),
)
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
for method, args in c_methods:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Extend :mod:`pickle` import behavior to support loading :c:func:`PyModule_Create`-generated ``module`` attributes.
81 changes: 67 additions & 14 deletions Modules/_pickle.c
Original file line number Diff line number Diff line change
Expand Up @@ -3584,6 +3584,69 @@ fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name)
return 0;
}

PyObject*
import_module_from_string(PickleState *st, PyObject *module_name)
{
PyObject *split_module_name = PyUnicode_RSplit(module_name,
PyUnicode_FromString("."),
1);
if (split_module_name == NULL) {
PyErr_Format(PyExc_RuntimeError, "Failed to split module name %R",
module_name);
return NULL;
}

PyObject *parent_module_name = PySequence_Fast_GET_ITEM(split_module_name,
0);
if (parent_module_name == NULL) {
PyErr_Format(PyExc_RuntimeError,
"Failed to get parent module name from %R",
split_module_name);
return NULL;
}

PyObject *fromlist = PySequence_GetSlice(split_module_name, 1,
PySequence_Fast_GET_SIZE(split_module_name));
if (fromlist == NULL) {
PyErr_Format(PyExc_RuntimeError, "Failed to get fromlist from %R",
split_module_name);
return NULL;
}

PyObject *parent_module = PyImport_ImportModuleLevelObject(parent_module_name,
NULL, NULL,
fromlist, 0);
if (parent_module == NULL) {
PyErr_Format(st->PicklingError, "Import of module %R failed",
parent_module_name);
return NULL;
}

Py_ssize_t fromlist_size = PySequence_Fast_GET_SIZE(fromlist);
if (fromlist_size == 0) {
return parent_module;
}
assert(fromlist_size == 1);

PyObject *module = PyImport_GetModule(module_name);
if (module == NULL) {
PyObject *child_module_name = PySequence_Fast_GET_ITEM(fromlist, 0);
if (child_module_name == NULL) {
PyErr_Format(PyExc_RuntimeError, "Failed to get item from %R",
fromlist);
return NULL;
}
module = PyObject_GetAttr(parent_module, child_module_name);
if (module == NULL) {
PyErr_Format(st->PicklingError, "Attribute lookup %R on %R failed",
child_module_name, parent_module_name);
return NULL;
}
}
return module;
}


static int
save_global(PickleState *st, PicklerObject *self, PyObject *obj,
PyObject *name)
Expand Down Expand Up @@ -3619,21 +3682,11 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
if (module_name == NULL)
goto error;

/* XXX: Change to use the import C API directly with level=0 to disallow
relative imports.

XXX: PyImport_ImportModuleLevel could be used. However, this bypasses
builtins.__import__. Therefore, _pickle, unlike pickle.py, will ignore
custom import functions (IMHO, this would be a nice security
feature). The import C API would need to be extended to support the
extra parameters of __import__ to fix that. */
module = PyImport_Import(module_name);
module = import_module_from_string(st, module_name);
if (module == NULL) {
PyErr_Format(st->PicklingError,
"Can't pickle %R: import of module %R failed",
obj, module_name);
goto error;
}

lastname = Py_NewRef(PyList_GET_ITEM(dotted_path,
PyList_GET_SIZE(dotted_path) - 1));
cls = get_deep_attribute(module, dotted_path, &parent);
Expand Down Expand Up @@ -6980,10 +7033,10 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
/* Try to map the old names used in Python 2.x to the new ones used in
Python 3.x. We do this only with old pickle protocols and when the
user has not disabled the feature. */
PickleState *st = _Pickle_GetStateByClass(cls);
if (self->proto < 3 && self->fix_imports) {
PyObject *key;
PyObject *item;
PickleState *st = _Pickle_GetStateByClass(cls);

/* Check if the global (i.e., a function or a class) was renamed
or moved to another module. */
Expand Down Expand Up @@ -7036,7 +7089,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
* we don't use PyImport_GetModule here, because it can return partially-
* initialised modules, which then cause the getattribute to fail.
*/
module = PyImport_Import(module_name);
module = import_module_from_string(st, module_name);
if (module == NULL) {
return NULL;
}
Expand Down