Skip to content

Commit 52a2a20

Browse files
author
Matthew Hoffman
committed
gh-87533: Expand pickle importing to support non-package C-modules
There have been recurring issues with PyModule_Create modules in PyTorch: when trying to serialize attributes of these C-modules, pickle fails to import the C-module because it is not a package. This is the current issue that brought this to my attention: pytorch/pytorch#126154. The existing workaround has been to insert the C-module into sys.modules so that pickle can find it: https://github.com/pytorch/pytorch/pull/38136/files#diff-d7e90d0f94b43db763b44fba679a5c1b4cabe3668aaf34f2aee07de8e2d1b2faR524-R528. Instead of relying on this hack, we can change `pickle`'s approach to loading, which is currently equivalent to `import package.c_module`; instead, we can do `from package import c_module`, which 1) does not care whether `c_module` is a package or not, 2) is fully backward compatible with the previous approach, and 3) slots in nicely with the `fromlist` parameter of `__import__`, which we are already using to load modules in `pickle`.
1 parent 31a28cb commit 52a2a20

File tree

3 files changed

+96
-25
lines changed

3 files changed

+96
-25
lines changed

Lib/pickle.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,24 @@ def whichmodule(obj, name):
345345
pass
346346
return '__main__'
347347

348+
def import_module_from_string(module_name):
    """Import and return the module named by a dot-delimited string.

    A dotted name ``a.b.c`` is resolved like ``from a.b import c`` rather
    than ``import a.b.c``, so the last component need not live in a package:
    it only has to be importable as, or already be an attribute of, its
    parent module (see bpo-43367).

    module_name -- absolute, dot-delimited module name (never relative).

    Returns the module object for the last component of *module_name*.

    Raises ImportError (usually ModuleNotFoundError) when a module cannot be
    imported or when the parent has no attribute named after the last
    component, and ValueError when the name to import is empty.

    >>> import_module_from_string('pickle')
    <module 'pickle' from '...'>
    >>> import_module_from_string('collections.abc')
    <module 'collections.abc' from '...'>
    """
    parent_name, dot, child_name = module_name.rpartition(".")
    if not dot:
        # Plain top-level name: __import__ returns that very module.
        return __import__(module_name, level=0)

    # A non-empty fromlist makes __import__ return the parent itself (not
    # the top-level package) and try to import the child as a submodule; a
    # child that cannot be found that way is silently skipped.
    parent = __import__(parent_name, fromlist=[child_name], level=0)
    try:
        return getattr(parent, child_name)
    except AttributeError as exc:
        # Report a missing child as an import failure, which is what
        # unpickling a reference to a missing module has always raised.
        raise ModuleNotFoundError(
            f"No module named {module_name!r}", name=module_name
        ) from exc
365+
348366
def encode_long(x):
349367
r"""Encode a long to a two's complement little-endian binary string.
350368
Note that 0 is a special case, returning an empty string, to save a
@@ -364,7 +382,6 @@ def encode_long(x):
364382
b'\x80'
365383
>>> encode_long(127)
366384
b'\x7f'
367-
>>>
368385
"""
369386
if x == 0:
370387
return b''
@@ -1049,7 +1066,6 @@ def save_frozenset(self, obj):
10491066

10501067
def save_global(self, obj, name=None):
10511068
write = self.write
1052-
memo = self.memo
10531069

10541070
if name is None:
10551071
name = getattr(obj, '__qualname__', None)
@@ -1058,8 +1074,7 @@ def save_global(self, obj, name=None):
10581074

10591075
module_name = whichmodule(obj, name)
10601076
try:
1061-
__import__(module_name, level=0)
1062-
module = sys.modules[module_name]
1077+
module = import_module_from_string(module_name)
10631078
obj2, parent = _getattribute(module, name)
10641079
except (ImportError, KeyError, AttributeError):
10651080
raise PicklingError(
@@ -1565,17 +1580,18 @@ def get_extension(self, code):
15651580

15661581
def find_class(self, module, name):
15671582
# Subclasses may override this.
1568-
sys.audit('pickle.find_class', module, name)
1583+
module_name = module
1584+
sys.audit('pickle.find_class', module_name, name)
15691585
if self.proto < 3 and self.fix_imports:
1570-
if (module, name) in _compat_pickle.NAME_MAPPING:
1571-
module, name = _compat_pickle.NAME_MAPPING[(module, name)]
1572-
elif module in _compat_pickle.IMPORT_MAPPING:
1573-
module = _compat_pickle.IMPORT_MAPPING[module]
1574-
__import__(module, level=0)
1586+
if (module_name, name) in _compat_pickle.NAME_MAPPING:
1587+
module_name, name = _compat_pickle.NAME_MAPPING[(module_name, name)]
1588+
elif module_name in _compat_pickle.IMPORT_MAPPING:
1589+
module_name = _compat_pickle.IMPORT_MAPPING[module_name]
1590+
module = import_module_from_string(module_name)
15751591
if self.proto >= 4:
1576-
return _getattribute(sys.modules[module], name)[0]
1592+
return _getattribute(module, name)[0]
15771593
else:
1578-
return getattr(sys.modules[module], name)
1594+
return getattr(module, name)
15791595

15801596
def load_reduce(self):
15811597
stack = self.stack

Lib/test/pickletester.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2845,6 +2845,12 @@ class Subclass(tuple):
28452845
class Nested(str):
28462846
pass
28472847

2848+
c_module = types.ModuleType("c_module")
2849+
def c_function():
2850+
return None
2851+
c_function.__module__ = f"{__name__}.{c_module.__name__}"
2852+
c_module.c_function = c_function
2853+
28482854
c_methods = (
28492855
# bound built-in method
28502856
("abcd".index, ("c",)),
@@ -2867,6 +2873,8 @@ class Nested(str):
28672873
(Subclass.count, (Subclass([1,2,2]), 2)),
28682874
(Subclass.Nested("sweet").count, ("e",)),
28692875
(Subclass.Nested.count, (Subclass.Nested("sweet"), "e")),
2876+
# bpo-43367: pickling C-module attributes
2877+
(c_module.c_function, ()),
28702878
)
28712879
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
28722880
for method, args in c_methods:

Modules/_pickle.c

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3584,6 +3584,63 @@ fix_imports(PickleState *st, PyObject **module_name, PyObject **global_name)
35843584
return 0;
35853585
}
35863586

3587+
/* Import and return the module named by the dot-delimited string
 * `module_name`.
 *
 * A dotted name "a.b.c" is resolved like `from a.b import c` rather than
 * `import a.b.c`, so the last component need not live in a package: it only
 * has to be importable as, or already be an attribute of, its parent module.
 *
 * Note: PyImport_ImportModuleLevelObject() is called directly, so a
 * replaced builtins.__import__ is not consulted.
 *
 * Returns a new reference, or NULL with an exception set.  Errors raised by
 * the import machinery are propagated unchanged; a missing attribute on the
 * parent module is reported as ModuleNotFoundError.
 */
static PyObject *
import_module_from_string(PyObject *module_name)
{
    PyObject *dot = NULL;
    PyObject *parts = NULL;
    PyObject *fromlist = NULL;
    PyObject *parent = NULL;
    PyObject *module = NULL;
    PyObject *parent_name;
    PyObject *child_name;

    dot = PyUnicode_FromString(".");
    if (dot == NULL) {
        return NULL;
    }
    /* Split off only the last component: "a.b.c" -> ["a.b", "c"]. */
    parts = PyUnicode_RSplit(module_name, dot, 1);
    Py_DECREF(dot);
    if (parts == NULL) {
        return NULL;
    }

    if (PyList_GET_SIZE(parts) < 2) {
        /* Plain top-level name: the import returns that very module. */
        module = PyImport_ImportModuleLevelObject(module_name, NULL, NULL,
                                                  NULL, 0);
        goto done;
    }

    /* Both items are borrowed from `parts`, which stays alive until `done`. */
    parent_name = PyList_GET_ITEM(parts, 0);
    child_name = PyList_GET_ITEM(parts, 1);

    fromlist = PyTuple_Pack(1, child_name);
    if (fromlist == NULL) {
        goto done;
    }

    /* A non-empty fromlist makes the import return the parent module itself
       (not the top-level package) and try to import the child as a
       submodule. */
    parent = PyImport_ImportModuleLevelObject(parent_name, NULL, NULL,
                                              fromlist, 0);
    if (parent == NULL) {
        goto done;
    }

    module = PyObject_GetAttr(parent, child_name);
    if (module == NULL && PyErr_ExceptionMatches(PyExc_AttributeError)) {
        /* Replace only the AttributeError; any other failure (e.g. an
           error raised by a module-level __getattr__) is left intact. */
        PyErr_Format(PyExc_ModuleNotFoundError,
                     "Attribute lookup %R on %R failed", child_name,
                     parent_name);
    }

done:
    Py_XDECREF(parent);
    Py_XDECREF(fromlist);
    Py_DECREF(parts);
    return module;
}
3642+
3643+
35873644
static int
35883645
save_global(PickleState *st, PicklerObject *self, PyObject *obj,
35893646
PyObject *name)
@@ -3619,21 +3676,11 @@ save_global(PickleState *st, PicklerObject *self, PyObject *obj,
36193676
if (module_name == NULL)
36203677
goto error;
36213678

3622-
/* XXX: Change to use the import C API directly with level=0 to disallow
3623-
relative imports.
3624-
3625-
XXX: PyImport_ImportModuleLevel could be used. However, this bypasses
3626-
builtins.__import__. Therefore, _pickle, unlike pickle.py, will ignore
3627-
custom import functions (IMHO, this would be a nice security
3628-
feature). The import C API would need to be extended to support the
3629-
extra parameters of __import__ to fix that. */
3630-
module = PyImport_Import(module_name);
3679+
module = import_module_from_string(module_name);
36313680
if (module == NULL) {
3632-
PyErr_Format(st->PicklingError,
3633-
"Can't pickle %R: import of module %R failed",
3634-
obj, module_name);
36353681
goto error;
36363682
}
3683+
36373684
lastname = Py_NewRef(PyList_GET_ITEM(dotted_path,
36383685
PyList_GET_SIZE(dotted_path) - 1));
36393686
cls = get_deep_attribute(module, dotted_path, &parent);
@@ -7036,7 +7083,7 @@ _pickle_Unpickler_find_class_impl(UnpicklerObject *self, PyTypeObject *cls,
70367083
* we don't use PyImport_GetModule here, because it can return partially-
70377084
* initialised modules, which then cause the getattribute to fail.
70387085
*/
7039-
module = PyImport_Import(module_name);
7086+
module = import_module_from_string(module_name);
70407087
if (module == NULL) {
70417088
return NULL;
70427089
}

0 commit comments

Comments
 (0)